-
Notifications
You must be signed in to change notification settings - Fork 593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Manual Network Try 3. #340
base: main
Are you sure you want to change the base?
Changes from all commits
4243d80
9ba61c4
c67085f
7607ed2
99604da
31366c6
f57bd8a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,8 @@ | |
from exo.networking.grpc.grpc_server import GRPCServer | ||
from exo.networking.udp.udp_discovery import UDPDiscovery | ||
from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery | ||
from exo.networking.manual.read_config import ReadManualConfig | ||
from exo.networking.manual.manual_discovery import ManualDiscovery | ||
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle | ||
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy | ||
from exo.api import ChatGPTAPI | ||
|
@@ -35,8 +37,9 @@ | |
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download") | ||
parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port") | ||
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery") | ||
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale"], default="udp", help="Discovery module to use") | ||
parser.add_argument("--discovery-module", type=str, choices=["udp", "manual", "tailscale"], default="udp", help="Discovery module to use") | ||
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds") | ||
parser.add_argument("--discovery-config", type=str, default="topology.yml", help="Config file for manual discovery") | ||
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting") | ||
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port") | ||
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds") | ||
|
@@ -59,9 +62,28 @@ | |
inference_engine = get_inference_engine(inference_engine_name, shard_downloader) | ||
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}") | ||
|
||
if args.node_port is None: | ||
args.node_port = find_available_port(args.node_host) | ||
if DEBUG >= 1: print(f"Using available port: {args.node_port}") | ||
try: | ||
if args.discovery_module == "manual": | ||
if args.discovery_config is None: | ||
raise ValueError("--discovery-config is necessary when using --discovery-module manual") | ||
else: | ||
# Read for the current instance | ||
list_device = ReadManualConfig(discovery_config=args.discovery_config) | ||
list_device.device_capabilities((str((list_device.whoami)))) | ||
# Initialisation of the current instance with the first ReadManualConfig, and directly into main.py | ||
args.node_id = list_device.node_id | ||
args.node_host = list_device.node_host | ||
args.node_port = list_device.node_port | ||
args.wait_for_peers = 1 | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why else? can't the code here also run for manual discovery? |
||
if args.node_port is None: | ||
args.node_port = find_available_port(args.node_host) | ||
if DEBUG >= 1: print(f"Using available port: {args.node_port}") | ||
except ValueError as e: | ||
if DEBUG >= 2: | ||
print(f"Error: {e}") | ||
traceback.print_exc() | ||
exit() | ||
|
||
args.node_id = args.node_id or get_or_create_node_id() | ||
chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()] | ||
|
@@ -76,6 +98,8 @@ | |
|
||
if args.discovery_module == "udp": | ||
discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout) | ||
elif args.discovery_module == "manual": | ||
discovery = ManualDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_config=args.discovery_config, discovery_timeout=args.discovery_timeout) | ||
elif args.discovery_module == "tailscale": | ||
discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name) | ||
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None | ||
|
@@ -99,6 +123,7 @@ | |
node.on_token.register("update_topology_viz").on_next( | ||
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None | ||
) | ||
|
||
def preemptively_start_download(request_id: str, opaque_status: str): | ||
try: | ||
status = json.loads(opaque_status) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import asyncio | ||
import time | ||
import traceback | ||
from typing import List, Dict, Callable, Tuple | ||
from exo.networking.discovery import Discovery | ||
from exo.networking.peer_handle import PeerHandle | ||
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES | ||
from exo.helpers import DEBUG | ||
from exo.networking.manual.read_config import ReadManualConfig | ||
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle | ||
|
||
class ManualDiscovery(Discovery): | ||
def __init__( | ||
self, | ||
node_id: str, | ||
node_port: int, | ||
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle], | ||
device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES, | ||
discovery_config: str = "topology.yml", | ||
discovery_interval: int = 5, | ||
discovery_timeout: int = 30, | ||
update_interval: int = 15, | ||
): | ||
self.node_id = node_id | ||
self.node_port = node_port | ||
self.discovery_config = discovery_config | ||
# Read for discovering every other devices except this instance | ||
self.list_device = ReadManualConfig(discovery_config=self.discovery_config) | ||
self.create_peer_handle = create_peer_handle | ||
self.discovery_interval = discovery_interval | ||
self.discovery_timeout = discovery_timeout | ||
self.update_interval = update_interval | ||
self.device_capabilities = device_capabilities | ||
self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this just be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know about that, for now I simplified the discovery with one function that input all the other devices on startup. Frankly it is your call, same for the GRPC problem. Thanks in advance. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure I can help. it's quite simple tho - only healthy peers should be in the topology |
||
self.discovery_task = None | ||
self.cleanup_task = None | ||
self._device_id = None | ||
self.update_task = None | ||
|
||
async def start(self): | ||
self.device_capabilities = self.list_device.device_capabilities(self.list_device.whoami) | ||
self.discovery_task = asyncio.create_task(self.task_discover_peers()) | ||
|
||
async def stop(self): | ||
if self.discovery_task: | ||
self.discovery_task.cancel() | ||
await asyncio.gather(self.discovery_task, return_exceptions=True) | ||
|
||
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: | ||
if wait_for_peers > 0: | ||
while len(self.known_peers) < wait_for_peers: | ||
if DEBUG >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...") | ||
await asyncio.sleep(0.1) | ||
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()] | ||
|
||
async def task_discover_peers(self): | ||
while True: | ||
try: | ||
if DEBUG >= 2: print("task_discover_peers") | ||
current_time = time.time() | ||
|
||
for device in self.list_device.config_devices: | ||
if f"{self.list_device.whoami}" != f"{device['server']}": | ||
if DEBUG >= 2: print(f"Getting Id {device['id']} == Adresse: {device['address']} {device['port']}") | ||
peer_id = device['id'] | ||
peer_host = device['address'] | ||
peer_port = device['port'] | ||
if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}": | ||
new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", self.list_device.device_capabilities((str((f"{device['server']}"))))) | ||
try: | ||
is_connected = await new_peer_handle.is_connected() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to check is_connected. just health check |
||
health_ok = await new_peer_handle.health_check() | ||
if is_connected == True: | ||
if health_ok == True: | ||
if DEBUG >= 2: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}") | ||
self.known_peers[peer_id] = ( | ||
new_peer_handle, | ||
current_time, | ||
current_time, | ||
) | ||
else: | ||
if DEBUG >= 2: print(f"{peer_id=} at {peer_host}:{peer_port} not healthy.") | ||
else: | ||
if DEBUG >= 2: print(f"{peer_id=} at {peer_host}:{peer_port} not connected.") | ||
except Exception as e: | ||
if DEBUG >= 2: print(f"Error checking peer {peer_id}: {e}") | ||
|
||
if DEBUG >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}, health_check={await peer_handle.health_check()}" for peer_handle, connected_at, last_seen in self.known_peers.values()}) | ||
|
||
except Exception as e: | ||
print(f"Error in discover peers: {e}") | ||
print(traceback.format_exc()) | ||
finally: | ||
await asyncio.sleep(self.discovery_interval) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops | ||
from exo.helpers import DEBUG | ||
import socket | ||
import yaml | ||
|
||
class ReadManualConfig(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this a class? Just make a single function that reads the device capabilities from the file There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Frankly, I think it is more flexible. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I dont understand what you mean |
||
def __init__( | ||
self, | ||
discovery_config: str = "topology.yml", | ||
): | ||
self._discovery_config = discovery_config | ||
self._config_devices = None | ||
self._whoami = f"{socket.gethostname()}" | ||
self._node_id = "NONE" | ||
self._node_host = "NONE" | ||
self._node_port = "NONE" | ||
self._model = "NONE" | ||
self._chip = "NONE" | ||
self._memory = "NONE" | ||
self._fp32 = "NONE" | ||
self._fp16 = "NONE" | ||
self._int8 = "NONE" | ||
|
||
def device_capabilities(self, gethostname) -> DeviceCapabilities: | ||
with open(self._discovery_config, 'r') as f: | ||
self._config_devices = yaml.safe_load(f) | ||
f.close() | ||
|
||
for device in self._config_devices: | ||
if (str(f"{gethostname}")) == (str(f"{device['server']}")): | ||
if DEBUG >= 2: print(f"Read Id {device['id']} == Adresse: {device['address']} {device['port']}") | ||
self._node_id = (str(f"{device['id']}")) | ||
self._node_host = (str(f"{device['address']}")) | ||
self._node_port = (int(f"{device['port']}")) | ||
if DEBUG >= 2: print(f"Capabilities:") | ||
for capability, value in device['device_capabilities'].items(): | ||
if DEBUG >= 2: print(f"{capability}: {value}") | ||
if f"{capability}" == "model": | ||
self._model = (str(f"{value}")) | ||
if f"{capability}" == "chip": | ||
self._chip = (str(f"{value}")) | ||
if f"{capability}" == "memory": | ||
self._memory = (float(f"{value}")) | ||
if f"{capability}" == "flops": | ||
for flopstr, flopvalue in device['device_capabilities']['flops'].items(): | ||
if f"{flopstr}" == "fp32": | ||
self._fp32 = (float(f"{flopvalue}")) | ||
if f"{flopstr}" == "fp16": | ||
self._fp16 = (float(f"{flopvalue}")) | ||
if f"{flopstr}" == "int8": | ||
self._int8 = (float(f"{flopvalue}")) | ||
|
||
return DeviceCapabilities( | ||
model=self._model, | ||
chip=self._chip, | ||
memory=self._memory // 2**20, | ||
flops=DeviceFlops(fp32=self._fp32, fp16=self._fp16, int8=self._int8), | ||
) | ||
|
||
return DeviceCapabilities( | ||
model="Unknown Device", | ||
chip="Unknown Chip", | ||
memory=0 // 2**20, | ||
flops=DeviceFlops(fp32=0, fp16=0, int8=0), | ||
) | ||
|
||
@property | ||
def discovery_config(self): | ||
return self._discovery_config | ||
|
||
@property | ||
def config_devices(self): | ||
return self._config_devices | ||
|
||
@property | ||
def whoami(self): | ||
return self._whoami | ||
|
||
@property | ||
def node_id(self): | ||
return self._node_id | ||
|
||
@property | ||
def node_host(self): | ||
return self._node_host | ||
|
||
@property | ||
def node_port(self): | ||
return self._node_port | ||
|
||
@property | ||
def model(self): | ||
return self._model | ||
|
||
@property | ||
def chip(self): | ||
return self._chip | ||
|
||
@property | ||
def memory(self): | ||
return self._memory | ||
|
||
@property | ||
def fp32(self): | ||
return self._fp32 | ||
|
||
@property | ||
def fp16(self): | ||
return self._fp16 | ||
|
||
@property | ||
def int8(self): | ||
return self._int8 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/bash | ||
if test ! -e /var/www/exo/python/bin/python3; then | ||
/usr/bin/python3 -m venv --system-site-packages /var/www/exo/python/ | ||
fi | ||
/var/www/exo/python/bin/pip install pyyaml | ||
/var/www/exo/python/bin/python3 /var/www/exo/validate_topology.py |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
- server: 'edgenode2' | ||
id: '5e9bdfc4-3a8f-47d2-b1ec-dc7ee3c11a85' | ||
address: 192.168.45.22 | ||
port: 50051 | ||
device_capabilities: | ||
model: 'NVIDIA RTX 4000 ADA GENERATION' | ||
chip: 'NVIDIA' | ||
memory: 20 | ||
flops: | ||
fp32: 26.7 | ||
fp16: 26.7 | ||
int8: 258.0 | ||
- server: 'edgenode3' | ||
id: 'bbf0d2fe-a37c-46ce-be23-4a14f6c52c74' | ||
address: 192.168.45.32 | ||
port: 50051 | ||
device_capabilities: | ||
model: 'NVIDIA RTX 4000 ADA GENERATION' | ||
chip: 'NVIDIA' | ||
memory: 20 | ||
flops: | ||
fp32: 26.7 | ||
fp16: 26.7 | ||
int8: 258.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import yaml | ||
|
||
with open('topology.yml', 'r') as f: | ||
edgenodes = yaml.safe_load(f) | ||
|
||
for node in edgenodes: | ||
print(f"{node['server']} {node['id']}:") | ||
print(f" Adresse: {node['address']} {node['port']}") | ||
print(f" Capabilities:") | ||
for capability, value in node['device_capabilities'].items(): | ||
print(f" {capability}: {value}") | ||
if f"{capability}" == "flops": | ||
for flopstr, flopvalue in node['device_capabilities']['flops'].items(): | ||
print(f" {flopstr}: {flopvalue}") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
redundant else