Skip to content

Commit

Permalink
[CLI] refactored the launch CLI and fixed bugs in multi-node launching (
Browse files Browse the repository at this point in the history
#844)

* [cli] fixed multi-node job launching

* [cli] fixed a bug in version comparison

* [cli] support launching with env var

* [cli] fixed multi-node job launching

* [cli] fixed a bug in version comparison

* [cli] support launching with env var

* added docstring

* [cli] added extra launch arguments

* [cli] added default launch rdzv args

* [cli] fixed version comparison

* [cli] added docstring examples and requirement

* polish docstring

* polish code

* polish code
  • Loading branch information
FrankLeeeee authored Apr 24, 2022
1 parent e5ea3fd commit cf6d1c9
Show file tree
Hide file tree
Showing 5 changed files with 492 additions and 264 deletions.
68 changes: 42 additions & 26 deletions colossalai/cli/launcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,34 @@

@click.command(help="Launch distributed training on a single node or multiple nodes",
context_settings=dict(ignore_unknown_options=True))
@click.option("-H", "-host", "--host", type=str, default=None, help="the list of machines to launch")
@click.option("--hostfile",
@click.option("-H",
"-host",
"--host",
type=str,
default=None,
help="Hostfile path that defines the device pool available to the job (e.g. worker-name:number of slots)")
help="the list of hostnames to launch in the format <host1>,<host2>")
@click.option(
"--include",
"--hostfile",
type=str,
default=None,
help=
"Specify computing devices to use during execution. String format is NODE_SPEC@NODE_SPEC where NODE_SPEC=<worker-name>:<list-of-slots>"
)
help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname")
@click.option("--include",
type=str,
default=None,
help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
" only effective when used with --hostfile.")
@click.option(
"--exclude",
type=str,
default=None,
help=
"Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include."
)
@click.option("--num_nodes", type=int, default=-1, help="Total number of worker nodes to use.")
@click.option("--nproc_per_node", type=int, default=-1, help="Number of GPUs to use on each node.")
"Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --includ,"
" only effective when used with --hostfile.")
@click.option("--num_nodes",
type=int,
default=-1,
help="Total number of worker nodes to use, only effective when used with --hostfile.")
@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
@click.option("--master_port",
type=int,
default=29500,
Expand All @@ -35,34 +42,43 @@
default="127.0.0.1",
help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
@click.option(
"--launcher",
type=click.Choice(['torch', 'openmpi', 'slurm'], case_sensitive=False),
default="torch",
help="(optional) choose launcher backend for multi-node training. Options currently include PDSH, OpenMPI, SLURM.")
@click.option("--launcher_args",
type=str,
default=None,
help="(optional) pass launcher specific arguments as a single quoted argument.")
"--extra_launch_args",
type=str,
default=None,
help=
"Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
"This will be converted to --arg1=1 --arg2=2 during execution")
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str)
@click.argument('user_args', nargs=-1)
def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
master_port: int, launcher: str, launcher_args: str, user_script: str, user_args: str):
master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None:
"""
To launch multiple processes on a single node or multiple nodes via command line.
Usage::
# run on the current node with all available GPUs
colossalai run train.py
# run with 4 GPUs on the current node use default port 29500
colossalai run --nprocs_per_node 4 train.py
# run with only 2 GPUs on the current node
colossalai run --nprocs_per_node 2 train.py
# run with 2 GPUs on the current node at port 29550
colossalai run --nprocs_per_node 4 --master_port 29550 train.py
# run on two nodes
colossalai run --host <host1>,<host2> train.py
colossalai run --host <host1>,<host2> --master_addr host1 --nprocs_per_node 4 train.py
# run with hostfile
colossalai run --hostfile <file_path> train.py
colossalai run --hostfile <file_path> --master_addr <host> --nprocs_per_node 4 train.py
# run with hostfile with only included hosts
colossalai run --hostfile <file_path> --master_addr host1 --include host1,host2 --nprocs_per_node 4 train.py
# run with hostfile excluding the hosts selected
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
"""
if not user_script.endswith('.py'):
click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help')
exit()

args_dict = locals()
args = Config(args_dict)
args.user_args = list(args.user_args)
Expand Down
122 changes: 122 additions & 0 deletions colossalai/cli/launcher/hostinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from typing import List
import socket


class HostInfo:
    """
    A data class to store host connection-related data.

    Args:
        hostname (str): name or IP address of the host
        port (str): the port for ssh connection
    """

    # Literal names/addresses that are always treated as the local machine.
    _LOCAL_NAMES = ("localhost", "127.0.0.1", "0.0.0.0")

    def __init__(
        self,
        hostname: str,
        port: str = None,
    ):
        self.hostname = hostname
        self.port = port
        # Evaluated once at construction time and cached on the instance.
        self.is_local_host = HostInfo.is_host_localhost(hostname, port)

    @staticmethod
    def is_host_localhost(hostname: str, port: str = None) -> bool:
        """
        Check if the host refers to the local machine.

        Args:
            hostname (str): name or IP address of the host
            port (str): the port for ssh connection, port 22 is used when it is None

        Returns:
            bool: True if it is local, False otherwise

        Raises:
            socket.gaierror: if the local or the target host name cannot be resolved
        """

        if port is None:
            port = 22    # no port specified, lets just use the ssh port

        # Test the literal given by the caller first: socket.getfqdn may rewrite
        # the name (e.g. to a canonical alias), which would make the check below
        # miss an obvious loopback name and fall through to name resolution.
        if hostname in HostInfo._LOCAL_NAMES:
            return True

        fqdn = socket.getfqdn(hostname)
        if fqdn in HostInfo._LOCAL_NAMES:
            return True

        # The host is local when any address it resolves to is also an address
        # of this machine. Index 4 of each getaddrinfo entry is the sockaddr
        # tuple, whose first element is the address string.
        local_addrs = {entry[4][0] for entry in socket.getaddrinfo(socket.gethostname(), port)}
        return any(entry[4][0] in local_addrs for entry in socket.getaddrinfo(fqdn, port))

    def __str__(self):
        return f'hostname: {self.hostname}, port: {self.port}'

    def __repr__(self):
        return self.__str__()


class HostInfoList:
    """
    A data class to store a list of HostInfo objects.
    """

    def __init__(self):
        self.hostinfo_list: List["HostInfo"] = []

    def append(self, hostinfo: "HostInfo") -> None:
        """
        Add an HostInfo object to the list.

        Args:
            hostinfo (HostInfo): host information
        """

        self.hostinfo_list.append(hostinfo)

    def remove(self, hostname: str) -> None:
        """
        Remove the HostInfo object which matches with the hostname from the list.

        Args:
            hostname (str): the name of the host

        Raises:
            LookupError: if no stored host has this hostname
        """

        self.hostinfo_list.remove(self.get_hostinfo(hostname))

    def get_hostinfo(self, hostname: str) -> "HostInfo":
        """
        Return the HostInfo object which matches with the hostname.

        Args:
            hostname (str): the name of the host

        Returns:
            hostinfo (HostInfo): the HostInfo object which matches with the hostname

        Raises:
            LookupError: if no stored host has this hostname
        """

        hostinfo = self._find(hostname)
        if hostinfo is None:
            # LookupError is a subclass of Exception, so callers which catch
            # Exception keep working while new callers can be more specific.
            raise LookupError(f"Hostname {hostname} is not found")
        return hostinfo

    def has(self, hostname: str) -> bool:
        """
        Check if the hostname has been added.

        Args:
            hostname (str): the name of the host

        Returns:
            bool: True if added, False otherwise
        """

        return self._find(hostname) is not None

    def _find(self, hostname: str):
        """
        Return the first stored HostInfo whose hostname matches, or None.

        Args:
            hostname (str): the name of the host
        """

        for hostinfo in self.hostinfo_list:
            if hostinfo.hostname == hostname:
                return hostinfo
        return None

    def __iter__(self):
        return iter(self.hostinfo_list)

    def __len__(self):
        return len(self.hostinfo_list)
Loading

0 comments on commit cf6d1c9

Please sign in to comment.