Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manual Network Try 3. #340

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions exo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.udp.udp_discovery import UDPDiscovery
from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
from exo.networking.manual.read_config import ReadManualConfig
from exo.networking.manual.manual_discovery import ManualDiscovery
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
Expand All @@ -35,8 +37,9 @@
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale"], default="udp", help="Discovery module to use")
parser.add_argument("--discovery-module", type=str, choices=["udp", "manual", "tailscale"], default="udp", help="Discovery module to use")
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
parser.add_argument("--discovery-config", type=str, default="topology.yml", help="Config file for manual discovery")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
Expand All @@ -59,9 +62,28 @@
inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")

if args.node_port is None:
args.node_port = find_available_port(args.node_host)
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
try:
if args.discovery_module == "manual":
if args.discovery_config is None:
raise ValueError("--discovery-config is necessary when using --discovery-module manual")
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redundant else

# Read for the current instance
list_device = ReadManualConfig(discovery_config=args.discovery_config)
list_device.device_capabilities((str((list_device.whoami))))
# Initialisation of the current instance with the first ReadManualConfig, and directly into main.py
args.node_id = list_device.node_id
args.node_host = list_device.node_host
args.node_port = list_device.node_port
args.wait_for_peers = 1
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why else? can't the code here also run for manual discovery?

if args.node_port is None:
args.node_port = find_available_port(args.node_host)
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
except ValueError as e:
if DEBUG >= 2:
print(f"Error: {e}")
traceback.print_exc()
exit()

args.node_id = args.node_id or get_or_create_node_id()
chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
Expand All @@ -76,6 +98,8 @@

if args.discovery_module == "udp":
discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
elif args.discovery_module == "manual":
discovery = ManualDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_config=args.discovery_config, discovery_timeout=args.discovery_timeout)
elif args.discovery_module == "tailscale":
discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
Expand All @@ -99,6 +123,7 @@
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)

def preemptively_start_download(request_id: str, opaque_status: str):
try:
status = json.loads(opaque_status)
Expand Down
Empty file.
96 changes: 96 additions & 0 deletions exo/networking/manual/manual_discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import asyncio
import time
import traceback
from typing import List, Dict, Callable, Tuple
from exo.networking.discovery import Discovery
from exo.networking.peer_handle import PeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
from exo.helpers import DEBUG
from exo.networking.manual.read_config import ReadManualConfig
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle

class ManualDiscovery(Discovery):
def __init__(
self,
node_id: str,
node_port: int,
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
discovery_config: str = "topology.yml",
discovery_interval: int = 5,
discovery_timeout: int = 30,
update_interval: int = 15,
):
self.node_id = node_id
self.node_port = node_port
self.discovery_config = discovery_config
# Read for discovering every other devices except this instance
self.list_device = ReadManualConfig(discovery_config=self.discovery_config)
self.create_peer_handle = create_peer_handle
self.discovery_interval = discovery_interval
self.discovery_timeout = discovery_timeout
self.update_interval = update_interval
self.device_capabilities = device_capabilities
self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this just be Dict[str, PeerHandle] now? I don't think we need the times any more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know about that, for now I simplified the discovery with one function that input all the other devices on startup.
But I remove disconnection, ... etc for debug purpose.

Frankly it is your call, same for the GRPC problem.
Do you have some discord time to help me on :
is_connected = await new_peer_handle.is_connected()
health_ok = await new_peer_handle.health_check()
that are not working, please ?

Thanks in advance.
Best Regards.
Benjamin.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure I can help. it's quite simple tho - only healthy peers should be in the topology

self.discovery_task = None
self.cleanup_task = None
self._device_id = None
self.update_task = None

async def start(self):
self.device_capabilities = self.list_device.device_capabilities(self.list_device.whoami)
self.discovery_task = asyncio.create_task(self.task_discover_peers())

async def stop(self):
if self.discovery_task:
self.discovery_task.cancel()
await asyncio.gather(self.discovery_task, return_exceptions=True)

async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if wait_for_peers > 0:
while len(self.known_peers) < wait_for_peers:
if DEBUG >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
await asyncio.sleep(0.1)
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]

async def task_discover_peers(self):
while True:
try:
if DEBUG >= 2: print("task_discover_peers")
current_time = time.time()

for device in self.list_device.config_devices:
if f"{self.list_device.whoami}" != f"{device['server']}":
if DEBUG >= 2: print(f"Getting Id {device['id']} == Adresse: {device['address']} {device['port']}")
peer_id = device['id']
peer_host = device['address']
peer_port = device['port']
if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", self.list_device.device_capabilities((str((f"{device['server']}")))))
try:
is_connected = await new_peer_handle.is_connected()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to check is_connected. just health check

health_ok = await new_peer_handle.health_check()
if is_connected == True:
if health_ok == True:
if DEBUG >= 2: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
self.known_peers[peer_id] = (
new_peer_handle,
current_time,
current_time,
)
else:
if DEBUG >= 2: print(f"{peer_id=} at {peer_host}:{peer_port} not healthy.")
else:
if DEBUG >= 2: print(f"{peer_id=} at {peer_host}:{peer_port} not connected.")
except Exception as e:
if DEBUG >= 2: print(f"Error checking peer {peer_id}: {e}")

if DEBUG >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}, health_check={await peer_handle.health_check()}" for peer_handle, connected_at, last_seen in self.known_peers.values()})

except Exception as e:
print(f"Error in discover peers: {e}")
print(traceback.format_exc())
finally:
await asyncio.sleep(self.discovery_interval)


114 changes: 114 additions & 0 deletions exo/networking/manual/read_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.helpers import DEBUG
import socket
import yaml

class ReadManualConfig():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this a class? Just make a single function that reads the device capabilities from the file

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Frankly, I think it is more flexible.
For exemple, for now, I set args.wait_for_peers = 1 in main.py.
But after we finish debug, I want the number of peer to wait to comme from ReadManualConfig.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont understand what you mean

def __init__(
self,
discovery_config: str = "topology.yml",
):
self._discovery_config = discovery_config
self._config_devices = None
self._whoami = f"{socket.gethostname()}"
self._node_id = "NONE"
self._node_host = "NONE"
self._node_port = "NONE"
self._model = "NONE"
self._chip = "NONE"
self._memory = "NONE"
self._fp32 = "NONE"
self._fp16 = "NONE"
self._int8 = "NONE"

def device_capabilities(self, gethostname) -> DeviceCapabilities:
with open(self._discovery_config, 'r') as f:
self._config_devices = yaml.safe_load(f)
f.close()

for device in self._config_devices:
if (str(f"{gethostname}")) == (str(f"{device['server']}")):
if DEBUG >= 2: print(f"Read Id {device['id']} == Adresse: {device['address']} {device['port']}")
self._node_id = (str(f"{device['id']}"))
self._node_host = (str(f"{device['address']}"))
self._node_port = (int(f"{device['port']}"))
if DEBUG >= 2: print(f"Capabilities:")
for capability, value in device['device_capabilities'].items():
if DEBUG >= 2: print(f"{capability}: {value}")
if f"{capability}" == "model":
self._model = (str(f"{value}"))
if f"{capability}" == "chip":
self._chip = (str(f"{value}"))
if f"{capability}" == "memory":
self._memory = (float(f"{value}"))
if f"{capability}" == "flops":
for flopstr, flopvalue in device['device_capabilities']['flops'].items():
if f"{flopstr}" == "fp32":
self._fp32 = (float(f"{flopvalue}"))
if f"{flopstr}" == "fp16":
self._fp16 = (float(f"{flopvalue}"))
if f"{flopstr}" == "int8":
self._int8 = (float(f"{flopvalue}"))

return DeviceCapabilities(
model=self._model,
chip=self._chip,
memory=self._memory // 2**20,
flops=DeviceFlops(fp32=self._fp32, fp16=self._fp16, int8=self._int8),
)

return DeviceCapabilities(
model="Unknown Device",
chip="Unknown Chip",
memory=0 // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)

@property
def discovery_config(self):
return self._discovery_config

@property
def config_devices(self):
return self._config_devices

@property
def whoami(self):
return self._whoami

@property
def node_id(self):
return self._node_id

@property
def node_host(self):
return self._node_host

@property
def node_port(self):
return self._node_port

@property
def model(self):
return self._model

@property
def chip(self):
return self._chip

@property
def memory(self):
return self._memory

@property
def fp32(self):
return self._fp32

@property
def fp16(self):
return self._fp16

@property
def int8(self):
return self._int8

6 changes: 6 additions & 0 deletions test/test_topology.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash
if test ! -e /var/www/exo/python/bin/python3; then
/usr/bin/python3 -m venv --system-site-packages /var/www/exo/python/
fi
/var/www/exo/python/bin/pip install pyyaml
/var/www/exo/python/bin/python3 /var/www/exo/validate_topology.py
24 changes: 24 additions & 0 deletions test/topology.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
- server: 'edgenode2'
id: '5e9bdfc4-3a8f-47d2-b1ec-dc7ee3c11a85'
address: 192.168.45.22
port: 50051
device_capabilities:
model: 'NVIDIA RTX 4000 ADA GENERATION'
chip: 'NVIDIA'
memory: 20
flops:
fp32: 26.7
fp16: 26.7
int8: 258.0
- server: 'edgenode3'
id: 'bbf0d2fe-a37c-46ce-be23-4a14f6c52c74'
address: 192.168.45.32
port: 50051
device_capabilities:
model: 'NVIDIA RTX 4000 ADA GENERATION'
chip: 'NVIDIA'
memory: 20
flops:
fp32: 26.7
fp16: 26.7
int8: 258.0
15 changes: 15 additions & 0 deletions test/validate_topology.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import yaml

with open('topology.yml', 'r') as f:
edgenodes = yaml.safe_load(f)

for node in edgenodes:
print(f"{node['server']} {node['id']}:")
print(f" Adresse: {node['address']} {node['port']}")
print(f" Capabilities:")
for capability, value in node['device_capabilities'].items():
print(f" {capability}: {value}")
if f"{capability}" == "flops":
for flopstr, flopvalue in node['device_capabilities']['flops'].items():
print(f" {flopstr}: {flopvalue}")