diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py index c63008af7659..4ada68b4b68f 100644 --- a/colossalai/cli/launcher/__init__.py +++ b/colossalai/cli/launcher/__init__.py @@ -5,27 +5,34 @@ @click.command(help="Launch distributed training on a single node or multiple nodes", context_settings=dict(ignore_unknown_options=True)) -@click.option("-H", "-host", "--host", type=str, default=None, help="the list of machines to launch") -@click.option("--hostfile", +@click.option("-H", + "-host", + "--host", type=str, default=None, - help="Hostfile path that defines the device pool available to the job (e.g. worker-name:number of slots)") + help="the list of hostnames to launch in the format ,") @click.option( - "--include", + "--hostfile", type=str, default=None, - help= - "Specify computing devices to use during execution. String format is NODE_SPEC@NODE_SPEC where NODE_SPEC=:" -) + help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname") +@click.option("--include", + type=str, + default=None, + help="Specify computing devices to use during execution. String format is ,," + " only effective when used with --hostfile.") @click.option( "--exclude", type=str, default=None, help= - "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include." -) -@click.option("--num_nodes", type=int, default=-1, help="Total number of worker nodes to use.") -@click.option("--nproc_per_node", type=int, default=-1, help="Number of GPUs to use on each node.") + "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --includ," + " only effective when used with --hostfile.") +@click.option("--num_nodes", + type=int, + default=-1, + help="Total number of worker nodes to use, only effective when used with --hostfile.") +@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.") @click.option("--master_port", type=int, default=29500, @@ -35,34 +42,43 @@ default="127.0.0.1", help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.") @click.option( - "--launcher", - type=click.Choice(['torch', 'openmpi', 'slurm'], case_sensitive=False), - default="torch", - help="(optional) choose launcher backend for multi-node training. Options currently include PDSH, OpenMPI, SLURM.") -@click.option("--launcher_args", - type=str, - default=None, - help="(optional) pass launcher specific arguments as a single quoted argument.") + "--extra_launch_args", + type=str, + default=None, + help= + "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. " + "This will be converted to --arg1=1 --arg2=2 during execution") +@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection") @click.argument("user_script", type=str) @click.argument('user_args', nargs=-1) def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str, - master_port: int, launcher: str, launcher_args: str, user_script: str, user_args: str): + master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None: """ To launch multiple processes on a single node or multiple nodes via command line. Usage:: - # run on the current node with all available GPUs - colossalai run train.py + # run with 4 GPUs on the current node use default port 29500 + colossalai run --nprocs_per_node 4 train.py - # run with only 2 GPUs on the current node - colossalai run --nprocs_per_node 2 train.py + # run with 2 GPUs on the current node at port 29550 + colossalai run --nprocs_per_node 4 --master_port 29550 train.py # run on two nodes - colossalai run --host , train.py + colossalai run --host , --master_addr host1 --nprocs_per_node 4 train.py # run with hostfile - colossalai run --hostfile train.py + colossalai run --hostfile --master_addr --nprocs_per_node 4 train.py + + # run with hostfile with only included hosts + colossalai run --hostfile --master_addr host1 --include host1,host2 --nprocs_per_node 4 train.py + + # run with hostfile excluding the hosts selected + colossalai run --hostfile --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py """ + if not user_script.endswith('.py'): + click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help') + exit() + args_dict = locals() args = Config(args_dict) args.user_args = list(args.user_args) diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py new file mode 100644 index 000000000000..2f0830c5880d --- /dev/null +++ b/colossalai/cli/launcher/hostinfo.py @@ -0,0 +1,122 @@ +from typing import List +import socket + + +class HostInfo: + """ + A data class to store host connection-related data. + + Args: + hostname (str): name or IP address of the host + port (str): the port for ssh connection + """ + + def __init__( + self, + hostname: str, + port: str = None, + ): + self.hostname = hostname + self.port = port + self.is_local_host = HostInfo.is_host_localhost(hostname, port) + + @staticmethod + def is_host_localhost(hostname: str, port: str = None) -> None: + """ + Check if the host refers to the local machine. + + Args: + hostname (str): name or IP address of the host + port (str): the port for ssh connection + + Returns: + bool: True if it is local, False otherwise + """ + + if port is None: + port = 22 # no port specified, lets just use the ssh port + hostname = socket.getfqdn(hostname) + if hostname in ("localhost", "127.0.0.1", "0.0.0.0"): + return True + localhost = socket.gethostname() + localaddrs = socket.getaddrinfo(localhost, port) + targetaddrs = socket.getaddrinfo(hostname, port) + for (family, socktype, proto, canonname, sockaddr) in localaddrs: + for (rfamily, rsocktype, rproto, rcanonname, rsockaddr) in targetaddrs: + if rsockaddr[0] == sockaddr[0]: + return True + return False + + def __str__(self): + return f'hostname: {self.hostname}, port: {self.port}' + + def __repr__(self): + return self.__str__() + + +class HostInfoList: + """ + A data class to store a list of HostInfo objects. + """ + + def __init__(self): + self.hostinfo_list = [] + + def append(self, hostinfo: HostInfo) -> None: + """ + Add an HostInfo object to the list. + + Args: + hostinfo (HostInfo): host information + """ + + self.hostinfo_list.append(hostinfo) + + def remove(self, hostname: str) -> None: + """ + Add an HostInfo object to the list. + + Args: + hostname (str): the name of the host + """ + + hostinfo = self.get_hostinfo(hostname) + self.hostinfo_list.remove(hostinfo) + + def get_hostinfo(self, hostname: str) -> HostInfo: + """ + Return the HostInfo object which matches with the hostname. + + Args: + hostname (str): the name of the host + + Returns: + hostinfo (HostInfo): the HostInfo object which matches with the hostname + """ + + for hostinfo in self.hostinfo_list: + if hostinfo.hostname == hostname: + return hostinfo + + raise Exception(f"Hostname {hostname} is not found") + + def has(self, hostname: str) -> bool: + """ + Check if the hostname has been added. + + Args: + hostname (str): the name of the host + + Returns: + bool: True if added, False otherwise + """ + for hostinfo in self.hostinfo_list: + if hostinfo.hostname == hostname: + return True + return False + + def __iter__(self): + return iter(self.hostinfo_list) + + def __len__(self): + return len(self.hostinfo_list) diff --git a/colossalai/cli/launcher/multinode_runner.py b/colossalai/cli/launcher/multinode_runner.py index 9a80cc2954d9..b3cabcbf2ae5 100644 --- a/colossalai/cli/launcher/multinode_runner.py +++ b/colossalai/cli/launcher/multinode_runner.py @@ -1,69 +1,120 @@ -import os -import sys -import shutil -from shlex import quote -from abc import ABC, abstractmethod - -from colossalai.logging import get_dist_logger - - -class MultiNodeRunner(ABC): - - def __init__(self, args): - self.args = args - self.user_arguments = self.args.user_args - self.user_script = args.user_script - self.exports = {} - - @abstractmethod - def backend_exists(self): - """Return whether the corresponding backend exists""" - - @abstractmethod - def get_cmd(self, environment, active_devices): - """Return the command to execute on node""" - - def add_export(self, key, var): - self.exports[key.strip()] = var.strip() - - @property - def name(self): - """Return the name of the backend""" - return self.__class__.__name__ - - -class PDSHRunner(MultiNodeRunner): - - def __init__(self, args): - super().__init__(args) - - def backend_exists(self): - return shutil.which('pdsh') - - @property - def name(self): - return "pdsh" - - def parse_user_args(self): - return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args)) - - def get_cmd(self, environment, active_devices, args): - environment['PDSH_RCMD_TYPE'] = 'ssh' - - active_workers = ",".join(active_devices.keys()) - print("Running on the following workers: %s" % active_workers) - - pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers] - - exports = "" - for key, val in self.exports.items(): - exports += f"export {key}={quote(val)}; " - - # https://linux.die.net/man/1/pdsh - # %n will be replaced by pdsh command - colossal_launch = [ - exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "torch.distributed.launch", - f"--nproc_per_node={args.nproc_per_node}", f"--master_addr={args.master_addr}", - f"--master_port={args.master_port}" - ] - return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments +import fabric +from fabric import Connection +from .hostinfo import HostInfo, HostInfoList +from multiprocessing import Pipe, Process +from multiprocessing import connection as mp_connection +import click + + +def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection, + send_conn: mp_connection.Connection, env: dict) -> None: + """ + Use fabric connection to execute command on local or remote hosts. + + Args: + hostinfo (HostInfo): host information + workdir (str): the directory to execute the command + recv_conn (multiprocessing.connection.Connection): receive messages from the master sender + send_conn (multiprocessing.connection.Connection): send messages to the master receiver + env (dict): a dictionary for environment variables + """ + + fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port) + finish = False + env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()]) + + # keep listening until exit + while not finish: + # receive cmd + cmds = recv_conn.recv() + + if cmds == 'exit': + # exit from the loop + finish = True + break + else: + # execute the commands + try: + # cd to execute directory + with fab_conn.cd(workdir): + # propagate the runtime environment + with fab_conn.prefix(f"export {env_msg}"): + if hostinfo.is_local_host: + # execute on the local machine + fab_conn.local(cmds, hide=False) + else: + # execute on the remote machine + fab_conn.run(cmds, hide=False) + send_conn.send('success') + except: + click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}") + send_conn.send('failure') + + # shutdown + send_conn.send("finish") + fab_conn.close() + + +class MultiNodeRunner: + """ + A runner to execute commands on an array of machines. This runner + is inspired by Nezha (https://github.com/zhuzilin/NeZha). + """ + + def __init__(self): + self.processes = {} + self.master_send_conns = {} + self.master_recv_conns = {} + + def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None: + """ + Establish connections to a list of hosts + + Args: + host_info_list (HostInfoList): a list of HostInfo objects + workdir (str): the directory where command is executed + env (dict): environment variables to propagate to hosts + """ + for hostinfo in host_info_list: + master_send_conn, worker_recv_conn = Pipe() + master_recv_conn, worker_send_conn = Pipe() + p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env)) + p.start() + self.processes[hostinfo.hostname] = p + self.master_recv_conns[hostinfo.hostname] = master_recv_conn + self.master_send_conns[hostinfo.hostname] = master_send_conn + + def send(self, hostinfo: HostInfo, cmd: str) -> None: + """ + Send a command to a local/remote host. + + Args: + hostinfo (HostInfo): host information + cmd (str): the command to execute + """ + + assert hostinfo.hostname in self.master_send_conns, \ + f'{hostinfo} is not found in the current connections' + conn = self.master_send_conns[hostinfo.hostname] + conn.send(cmd) + + def stop_all(self) -> None: + """ + Stop connections to all hosts. + """ + + for hostname, conn in self.master_send_conns.items(): + conn.send('exit') + + def recv_from_all(self) -> dict: + """ + Receive messages from all hosts + + Returns: + msg_from_node (dict): a dictionry which contains messages from each node + """ + + msg_from_node = dict() + for hostname, conn in self.master_recv_conns.items(): + msg_from_node[hostname] = conn.recv() + return msg_from_node diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index 5bb88fd3e7f0..e078a57c15c9 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -1,65 +1,72 @@ import click -import subprocess -import collections import sys import os import torch from colossalai.context import Config -from .multinode_runner import PDSHRunner -from copy import deepcopy +from .multinode_runner import MultiNodeRunner +from .hostinfo import HostInfo, HostInfoList +from typing import List +from packaging import version +# Constants that define our syntax +NODE_SEP = ',' + + +def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: + """ + Parse the hostfile to obtain a list of hosts. + + A hostfile should look like: + worker-0 + worker-1 + worker-2 + ... + + Args: + hostfile_path (str): the path to the hostfile + ssh_port (int): the port to connect to the host + """ -def fetch_hostfile(hostfile_path): if not os.path.isfile(hostfile_path): click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}") exit() - # e.g., worker-0:16 with open(hostfile_path, 'r') as fd: - device_pool = collections.OrderedDict() + device_pool = HostInfoList() + for line in fd.readlines(): line = line.strip() if line == '': # skip empty lines continue - try: - hostname, slot_count = line.split(":") - slot_count = int(slot_count) - except ValueError as err: - click.echo(f"Error: Hostfile is not formatted correctly, expected :, but found {line}") - exit() - if hostname in device_pool: + # build the HostInfo object + hostname = line.strip() + hostinfo = HostInfo(hostname=hostname, port=ssh_port) + + if device_pool.has(hostname): click.echo(f"Error: found duplicate host {hostname} in the hostfile") exit() - device_pool[hostname] = slot_count - return device_pool - -def _stable_remove_duplicates(data): - # Create a new list in the same order as original but with duplicates - # removed, should never be more than ~16 elements so simple is best - new_list = [] - for x in data: - if x not in new_list: - new_list.append(x) - return new_list + device_pool.append(hostinfo) + return device_pool -def parse_device_filter(host_info, include_str=None, exclude_str=None): +def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList: '''Parse an inclusion or exclusion string and filter a hostfile dictionary. Examples: - include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and - slots [0, 2] on worker-1. - exclude_str="worker-1:0" will use all available devices except - slot 0 on worker-1. - ''' + include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1. + exclude_str="worker-1" will use all available devices except worker-1. - # Constants that define our syntax - NODE_SEP = '@' - SLOT_LIST_START = ':' - SLOT_SEP = ',' + Args: + device_pool (HostInfoList): a list of HostInfo objects + include_str (str): --include option passed by user, default None + exclude_str (str): --exclude option passed by user, default None + + Returns: + filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion + ''' # Ensure include/exclude are mutually exclusive if include_str and exclude_str: @@ -68,176 +75,207 @@ def parse_device_filter(host_info, include_str=None, exclude_str=None): # no-op if include_str is None and exclude_str is None: - return host_info + return device_pool # Either build from scratch or remove items - filtered_hosts = dict() if include_str: parse_str = include_str + filtered_hosts = HostInfoList() elif exclude_str: - filtered_hosts = deepcopy(host_info) parse_str = exclude_str + filtered_hosts = device_pool # foreach node in the list for node_config in parse_str.split(NODE_SEP): - # Node can either be alone or node:slot,slot,slot - if SLOT_LIST_START in node_config: - hostname, slots = node_config.split(SLOT_LIST_START) - slots = [int(x) for x in slots.split(SLOT_SEP)] - - # sanity checks - if hostname not in host_info: - click.echo(f"Hostname '{hostname}' not found in hostfile") - exit() - for slot in slots: - if slot not in host_info[hostname]: - click.echo(f"No slot '{slot}' specified on host '{hostname}'") - - # If include string, build the list from here - if include_str: - filtered_hosts[hostname] = slots - elif exclude_str: - for slot in slots: - click.echo(f'- removing {slot} from {hostname}') - filtered_hosts[hostname].remove(slot) - - # User just specified the whole node - else: - hostname = node_config - # sanity check hostname - if hostname not in host_info: - click.echo(f"Hostname '{hostname}' not found in hostfile") - exit() - - if include_str: - filtered_hosts[hostname] = host_info[hostname] - elif exclude_str: - filtered_hosts[hostname] = [] - - # Post-processing to remove duplicates and empty nodes - del_keys = [] - for hostname in filtered_hosts: - # Remove duplicates - filtered_hosts[hostname] = _stable_remove_duplicates(filtered_hosts[hostname]) - # Remove empty hosts - if len(filtered_hosts[hostname]) == 0: - del_keys.append(hostname) - - # remove unneeded hosts - for name in del_keys: - del filtered_hosts[name] - - # Lastly, go over filtered_hosts and convert to a OrderedDict() to ensure - # we map ranks to nodes correctly by maintaining host_info ordering. - ordered_hosts = collections.OrderedDict() - for host in host_info: - if host in filtered_hosts: - ordered_hosts[host] = filtered_hosts[host] + hostname = node_config + hostinfo = device_pool.get_hostinfo(hostname) + # sanity check hostname + if not device_pool.has(hostname): + click.echo(f"Error: Hostname '{hostname}' not found in hostfile") + exit() + + if include_str: + filtered_hosts.append(hostinfo) + elif exclude_str: + filtered_hosts.remove(hostname) + + return filtered_hosts + + +def get_launch_command( + master_addr: str, + master_port: int, + nproc_per_node: int, + user_script: str, + user_args: List[str], + node_rank: int, + num_nodes: int, + extra_launch_args: str = None, +) -> str: + """ + Generate a command for distributed training. + + Args: + master_addr (str): the host of the master node + master_port (str): the port of the master node + nproc_per_node (str): the number of processes to launch on each node + user_script (str): the user Python file + user_args (str): the arguments for the user script + node_rank (int): the unique ID for the node + num_nodes (int): the number of nodes to execute jobs + + Returns: + cmd (str): the command the start distributed training + """ - return ordered_hosts + def _arg_dict_to_list(arg_dict): + ret = [] + + for k, v in arg_dict.items(): + if v: + ret.append(f'--{k}={v}') + else: + ret.append(f'--{k}') + return ret + + if extra_launch_args: + extra_launch_args_dict = dict() + for arg in extra_launch_args.split(','): + if '=' in arg: + k, v = arg.split('=') + extra_launch_args_dict[k] = v + else: + extra_launch_args_dict[arg] = None + extra_launch_args = extra_launch_args_dict + else: + extra_launch_args = dict() + torch_version = version.parse(torch.__version__) + assert torch_version.major == 1 -def parse_inclusion_exclusion(device_pool, inclusion, exclusion): - active_devices = collections.OrderedDict() - for hostname, slots in device_pool.items(): - active_devices[hostname] = list(range(slots)) + if torch_version.minor < 9: + cmd = [ + sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", + f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}", + f"--node_rank={node_rank}" + ] + else: + # extra launch args for torch distributed launcher with torch >= 1.9 + default_torchrun_rdzv_args = dict(rdzv_backend="c10d", + rdzv_endpoint=f"{master_addr}:{master_port}", + rdzv_id="colossalai-default-job") + + # update rdzv arguments + for key in default_torchrun_rdzv_args.keys(): + if key in extra_launch_args: + value = extra_launch_args.pop(key) + default_torchrun_rdzv_args[key] = value + + if torch_version.minor < 10: + cmd = [ + sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}", + f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + ] + else: + cmd = [ + "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + ] + cmd += _arg_dict_to_list(default_torchrun_rdzv_args) - return parse_device_filter(active_devices, include_str=inclusion, exclude_str=exclusion) + cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args + cmd = ' '.join(cmd) + return cmd -def launch_multi_processes(args): +def launch_multi_processes(args: Config) -> None: """ Launch multiple processes on a single node or multiple nodes. The overall logic can be summarized as the pseudo code below: - if hostfile given: - hostinfo = parse_hostfile(hostfile) - hostinfo = include_or_exclude_hosts(hostinfo) - launch_on_multi_nodes(hostinfo) - elif hosts given: - hostinfo = parse_hosts(hosts) - launch_on_multi_nodes(hostinfo) - else: - launch_on_current_node() + if hostfile given: + hostinfo = parse_hostfile(hostfile) + hostinfo = include_or_exclude_hosts(hostinfo) + launch_on_multi_nodes(hostinfo) + elif hosts given: + hostinfo = parse_hosts(hosts) + launch_on_multi_nodes(hostinfo) + else: + launch_on_current_node() + + Args: + args (Config): the arguments taken from command line + """ assert isinstance(args, Config) + if args.nproc_per_node is None: + click.echo("--nproc_per_node did not receive any value") + exit() + # cannot accept hosts and hostfile at the same time if args.host and args.hostfile: click.echo("Error: hostfile and hosts are mutually exclusive, only one is required") # check if hostfile is given if args.hostfile: - device_pool = fetch_hostfile(args.hostfile) - else: - device_pool = None - - # filter and only keep the ones needed - active_devices = None - if device_pool: - active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude) + device_pool = fetch_hostfile(args.hostfile, ssh_port=args.ssh_port) + active_device_pool = parse_device_filter(device_pool, args.include, args.exclude) if args.num_nodes > 0: # only keep the first num_nodes to execute jobs - updated_active_devices = collections.OrderedDict() - for count, hostname in enumerate(active_devices.keys()): + updated_active_device_pool = HostInfoList() + for count, hostinfo in enumerate(active_device_pool): if args.num_nodes == count: break - updated_active_devices[hostname] = active_devices[hostname] - active_devices = updated_active_devices - - if args.nproc_per_node > 0: - # only keep the first - updated_active_devices = collections.OrderedDict() - for hostname, active_devices in active_devices.items(): - if len(active_devices) < args.nproc_per_node: - click.echo( - f"Error: The number of available GPUs on {hostname} is smaller than the argument nproc_per_node" - ) - exit() - updated_active_devices[hostname] = active_devices[args.nproc_per_node] - active_devices = updated_active_devices + updated_active_device_pool.append(hostinfo) + active_device_pool = updated_active_device_pool + else: + active_device_pool = None env = os.environ.copy() # use hosts if hostfile is not given - if args.host and active_devices is None: - hostinfo = collections.OrderedDict() - host_list = args.host.strip().split(',') + if args.host and active_device_pool is None: + active_device_pool = HostInfoList() + host_list = args.host.strip().split(NODE_SEP) for hostname in host_list: - hostinfo[hostname] = args.nproc_per_node - active_devices = hostinfo - - # run on local node if not hosts or hostfile is given - if not active_devices: - if args.nproc_per_node == -1 or args.nproc_per_node > torch.cuda.device_count(): - nproc_per_node = torch.cuda.device_count() - else: - nproc_per_node = args.nproc_per_node - - if torch.__version__ <= "1.9": - cmd = [ - sys.executable, "-u", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", - f"--master_addr={args.master_addr}", f"--master_port={args.master_port}" - ] + [args.user_script] + args.user_args - else: - cmd = [ - "torchrun", f"--nproc_per_node={nproc_per_node}", f"--master_addr={args.master_addr}", - f"--master_port={args.master_port}" - ] + [args.user_script] + args.user_args - else: - runner = PDSHRunner(args) - - curr_path = os.path.abspath('.') - if 'PYTHONPATH' in env: - env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH'] - else: - env['PYTHONPATH'] = curr_path - - cmd = runner.get_cmd(env, active_devices, args) - - result = subprocess.Popen(cmd, env=env) - result.wait() - if result.returncode > 0: - sys.exit(result.returncode) + hostinfo = HostInfo(hostname=hostname, port=args.ssh_port) + active_device_pool.append(hostinfo) + + if not active_device_pool: + # run on local node if not hosts or hostfile is given + # add local node to host info list + active_device_pool = HostInfoList() + localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port) + active_device_pool.append(localhost_info) + + # launch distributed processes + runner = MultiNodeRunner() + curr_path = os.path.abspath('.') + + # collect current path env + env = dict() + for k, v in os.environ.items(): + # do not support multi-line env var + if v and '\n' not in v: + env[k] = v + + # establish remote connection + runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env) + + # execute distributed launching command + for node_id, hostinfo in enumerate(active_device_pool): + cmd = get_launch_command(master_addr=args.master_addr, + master_port=args.master_port, + nproc_per_node=args.nproc_per_node, + user_script=args.user_script, + user_args=args.user_args, + node_rank=node_id, + num_nodes=len(active_device_pool), + extra_launch_args=args.extra_launch_args) + runner.send(hostinfo=hostinfo, cmd=cmd) + + runner.recv_from_all() + runner.stop_all() + runner.recv_from_all() diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 02a907f09dee..572ee77ddda9 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -6,3 +6,4 @@ packaging pre-commit rich click +fabric