Skip to content

Commit

Permalink
feat: enable GPU device selection (#406)
Browse files Browse the repository at this point in the history
This commit enhances the functionality by allowing users to specify the
GPU device they wish to use for training. This feature supports multiple
GPUs, enabling more flexible and efficient resource utilization.
  • Loading branch information
rickstaa authored Feb 12, 2024
1 parent 7d7ac76 commit 73c1374
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 63 deletions.
18 changes: 12 additions & 6 deletions stable_learning_control/algos/pytorch/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@ class ReplayBuffer(CommonReplayBuffer):
sure a :obj:`torch.tensor` is returned when sampling.
Attributes:
device (str): The device the experiences are placed on (CPU or GPU).
device (str): The device the experiences are placed on (options: ``cpu``,
``gpu``, ``gpu:0``, ``gpu:1``, etc.).
"""

def __init__(self, device="cpu", *args, **kwargs):
"""Initialise the ReplayBuffer object.
Args:
device (str, optional): The computational device to put the sampled
experiences on.
experiences on (options: ``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``,
etc.). Defaults to ``cpu``.
*args: All args to pass to the :class:`ReplayBuffer` parent class.
**kwargs: All kwargs to pass to the :class:`ReplayBuffer` parent class.
"""
Expand Down Expand Up @@ -60,15 +62,17 @@ class FiniteHorizonReplayBuffer(CommonFiniteHorizonReplayBuffer):
sure a :obj:`torch.tensor` is returned when sampling.
Attributes:
device (str): The device the experiences are placed on (CPU or GPU).
device (str): The device the experiences are placed on (options: ``cpu``,
``gpu``, ``gpu:0``, ``gpu:1``, etc.).
"""

def __init__(self, device="cpu", *args, **kwargs):
"""Initialise the FiniteHorizonReplayBuffer object.
Args:
device (str, optional): The computational device to put the sampled
experiences on.
experiences on (options: ``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``,
etc.). Defaults to ``cpu``.
*args: All args to pass to the :class:`FiniteHorizonReplayBuffer` parent
class.
**kwargs: All kwargs to pass to the :class:`FiniteHorizonReplayBuffer` parent
Expand Down Expand Up @@ -102,15 +106,17 @@ class TrajectoryBuffer(CommonTrajectoryBuffer):
makes sure a :obj:`torch.tensor` is returned when sampling.
Attributes:
device (str): The device the experiences are placed on (CPU or GPU).
device (str): The device the experiences are placed on (options: ``cpu``,
``gpu``, ``gpu:0``, ``gpu:1``, etc.).
"""

def __init__(self, device="cpu", *args, **kwargs):
"""Initialise the TrajectoryBuffer object.
Args:
device (str, optional): The computational device to put the sampled
experiences on.
experiences on (options: ``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.).
Defaults to ``cpu``.
*args: All args to pass to the :class:`TrajectoryBuffer` parent class.
**kwargs: All kwargs to pass to the :class:`TrajectoryBuffer` parent class.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def estimate_step_learning_rate(
update_after (int): The step number after which the learning rate should start
decreasing.
lr_final (float): The final learning rate.
total_steps (int): The total number of steps/epochs in the training process.
total_steps (int): The total number of steps/epochs in the training process.
Excludes the initial step.
step (int): The current step number. Excludes the initial step.
Expand Down
43 changes: 27 additions & 16 deletions stable_learning_control/algos/pytorch/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,40 @@ def retrieve_device(device_type="cpu"):
"""Retrieves the available computational device given a device type.
Args:
device_type (str): The device type (options: ``cpu`` and
``gpu``). Defaults to ``cpu``.
device_type (str): The device type (options: ``cpu``, ``gpu``, ``gpu:0``,
``gpu:1``, etc.). Defaults to ``cpu``.
Returns:
:obj:`torch.device`: The Pytorch device object.
"""
device_type = (
"cpu" if device_type.lower() not in ["gpu", "cpu"] else device_type.lower()
)
if torch.cuda.is_available() and device_type == "gpu":
device = torch.device("cuda")
elif not torch.cuda.is_available() and device_type == "gpu":
log_to_std_out(
(
device_type = device_type.lower()
if "gpu" in device_type:
if not torch.cuda.is_available():
log_to_std_out(
"GPU computing was enabled but the GPU can not be reached. "
"Reverting back to using CPU.",
"yellow",
),
type="warning",
)
device = torch.device("cpu")
type="warning",
)
device = torch.device("cpu")
else:
device_id = int(device_type.split(":")[1]) if ":" in device_type else 0
if device_id < torch.cuda.device_count():
device = torch.device(f"cuda:{device_id}")
else:
log_to_std_out(
f"GPU with ID {device_id} not found. Reverting back to the first "
"available GPU.",
"yellow",
type="warning",
)
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
log_to_std_out(f"Torch is using the {device_type.upper()}.", type="info")
log_to_std_out(
f"Torch is using the {device}.",
type="info",
)
return device


Expand Down Expand Up @@ -147,7 +157,8 @@ def np_to_torch(input_object, dtype=None, device=None):
dtype (type, optional): The type you want to use for storing the data in the
tensor. Defaults to ``None`` (i.e. torch default will be used).
device (str, optional): The computational device on which the tensors should be
stored. Defaults to ``None`` (i.e. torch default device will be used).
stored. (options: ``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults
to ``None`` (i.e. torch default device will be used).
Returns:
object: The output python object in which numpy arrays have been converted to
Expand Down
16 changes: 8 additions & 8 deletions stable_learning_control/algos/pytorch/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ def __init__(
``1e-4``.
lr_c (float, optional): Learning rate used for the (lyapunov) critic.
Defaults to ``1e-4``.
device (str, optional): The device the networks are placed on (``cpu``
or ``gpu``). Defaults to ``cpu``.
device (str, optional): The device the networks are placed on (options:
``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
.. attention::
This class will behave differently when the ``actor_critic`` argument
Expand Down Expand Up @@ -818,8 +818,8 @@ def target_entropy(self, set_val):

@property
def device(self):
"""The device the networks are placed on (``cpu`` or ``gpu``). Defaults to
``cpu``.
"""The device the networks are placed on (options: ``cpu``, ``gpu``, ``gpu:0``,
``gpu:1``, etc.).
"""
return self._device

Expand Down Expand Up @@ -1004,8 +1004,8 @@ def lac(
Lyapunov Critic target. Defaults to ``0`` meaning the infinite-horizon
bellman backup is used.
seed (int): Seed for random number generators. Defaults to ``None``.
device (str, optional): The device the networks are placed on (``cpu``
or ``gpu``). Defaults to ``cpu``.
device (str, optional): The device the networks are placed on (options: ``cpu``,
``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
logger_kwargs (dict, optional): Keyword args for EpochLogger.
save_freq (int, optional): How often (in terms of gap between epochs) to save
the current policy and value function.
Expand Down Expand Up @@ -1738,8 +1738,8 @@ def lac(
type=str,
default="cpu",
help=(
"The device the networks are placed on: 'cpu' or 'gpu' (options: "
"default: cpu)"
"The device the networks are placed on. Options: 'cpu', 'gpu', 'gpu:0', "
"'gpu:1', etc. Defaults to 'cpu'."
),
)
parser.add_argument(
Expand Down
4 changes: 2 additions & 2 deletions stable_learning_control/algos/pytorch/latc/latc.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ def latc(env_fn, actor_critic=None, *args, **kwargs):
type=str,
default="cpu",
help=(
"The device the networks are placed on: 'cpu' or 'gpu' (options: "
"default: cpu)"
"The device the networks are placed on. Options: 'cpu', 'gpu', 'gpu:0', "
"'gpu:1', etc. Defaults to 'cpu'."
),
)
parser.add_argument(
Expand Down
16 changes: 8 additions & 8 deletions stable_learning_control/algos/pytorch/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ def __init__(
``1e-4``.
lr_c (float, optional): Learning rate used for the (Soft) critic.
Defaults to ``1e-4``.
device (str, optional): The device the networks are placed on (``cpu``
or ``gpu``). Defaults to ``cpu``.
device (str, optional): The device the networks are placed on (options:
``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
""" # noqa: E501, D301
super().__init__()
self._setup_kwargs = {
Expand Down Expand Up @@ -696,8 +696,8 @@ def target_entropy(self, set_val):

@property
def device(self):
"""The device the networks are placed on (``cpu`` or ``gpu``). Defaults to
``cpu``.
"""The device the networks are placed on (options: ``cpu``, ``gpu``, ``gpu:0``,
``gpu:1``, etc.).
"""
return self._device

Expand Down Expand Up @@ -872,8 +872,8 @@ def sac(
replay_size (int, optional): Maximum length of replay buffer. Defaults to
``1e6``.
seed (int): Seed for random number generators. Defaults to ``None``.
device (str, optional): The device the networks are placed on (``cpu``
or ``gpu``). Defaults to ``cpu``.
device (str, optional): The device the networks are placed on (options: ``cpu``,
``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
logger_kwargs (dict, optional): Keyword args for EpochLogger.
save_freq (int, optional): How often (in terms of gap between epochs) to save
the current policy and value function.
Expand Down Expand Up @@ -1524,8 +1524,8 @@ def sac(
type=str,
default="cpu",
help=(
"The device the networks are placed on: 'cpu' or 'gpu' (options: "
"default: cpu)"
"The device the networks are placed on. Options: 'cpu', 'gpu', 'gpu:0', "
"'gpu:1', etc. Defaults to 'cpu'."
),
)
parser.add_argument(
Expand Down
31 changes: 27 additions & 4 deletions stable_learning_control/algos/tf2/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,39 @@ def set_device(device_type="cpu"):
"""Sets the computational device given a device type.
Args:
device_type (str): The device type (options: ``cpu`` and
``gpu``). Defaults to ``cpu``.
device_type (str): The device type (options: ``cpu``, ``gpu``, ``gpu:0``,
``gpu:1``, etc.). Defaults to ``cpu``.
Returns:
str: The type of device that is used.
"""
if device_type.lower() == "cpu":
device_type = device_type.lower()
if "gpu" in device_type:
if not tf.config.list_physical_devices("GPU"):
log_to_std_out(
"GPU computing was enabled but the GPU can not be reached. "
"Reverting back to using CPU.",
"yellow",
type="warning",
)
device_type = "cpu"
else:
device_id = int(device_type.split(":")[1]) if ":" in device_type else 0
gpus = tf.config.experimental.list_physical_devices("GPU")
if device_id < len(gpus):
tf.config.experimental.set_visible_devices(gpus[device_id], "GPU")
else:
log_to_std_out(
f"GPU with ID {device_id} not found. Reverting back to the first "
"available GPU.",
"yellow",
type="warning",
)
tf.config.experimental.set_visible_devices(gpus[0], "GPU")
else:
tf.config.set_visible_devices([], "GPU") # Force disable GPU.
log_to_std_out(f"TensorFlow is using the {device_type.upper()}.", type="info")
return device_type.lower()
return device_type


def mlp(sizes, activation, output_activation=None, name=""):
Expand Down
16 changes: 8 additions & 8 deletions stable_learning_control/algos/tf2/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ def __init__(
``1e-4``.
lr_c (float, optional): Learning rate used for the (lyapunov) critic.
Defaults to ``1e-4``.
device (str, optional): The device the networks are placed on (``cpu``
or ``gpu``). Defaults to ``cpu``.
device (str, optional): The device the networks are placed on (options:
``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
.. attention::
This class will behave differently when the ``actor_critic`` argument
Expand Down Expand Up @@ -749,8 +749,8 @@ def target_entropy(self, set_val):

@property
def device(self):
"""The device the networks are placed on (``cpu`` or ``gpu``). Defaults to
``cpu``.
"""The device the networks are placed on (options: ``cpu``, ``gpu``, ``gpu:0``,
``gpu:1``, etc.).
"""
return self._device

Expand Down Expand Up @@ -935,8 +935,8 @@ def lac(
Lyapunov Critic target. Defaults to ``0`` meaning the infinite-horizon
bellman backup is used.
seed (int): Seed for random number generators. Defaults to ``None``.
device (str, optional): The device the networks are placed on (``cpu``
or ``gpu``). Defaults to ``cpu``.
device (str, optional): The device the networks are placed on (options: ``cpu``,
``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
logger_kwargs (dict, optional): Keyword args for EpochLogger.
save_freq (int, optional): How often (in terms of gap between epochs) to save
the current policy and value function.
Expand Down Expand Up @@ -1590,8 +1590,8 @@ def lac(
type=str,
default="cpu",
help=(
"The device the networks are placed on: 'cpu' or 'gpu' (options: "
"default: cpu)"
"The device the networks are placed on. Options: 'cpu', 'gpu', 'gpu:0', "
"'gpu:1', etc. Defaults to 'cpu'."
),
)
parser.add_argument(
Expand Down
4 changes: 2 additions & 2 deletions stable_learning_control/algos/tf2/latc/latc.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ def latc(env_fn, actor_critic=None, *args, **kwargs):
type=str,
default="cpu",
help=(
"The device the networks are placed on: 'cpu' or 'gpu' (options: "
"default: cpu)"
"The device the networks are placed on. Options: 'cpu', 'gpu', 'gpu:0', "
"'gpu:1', etc. Defaults to 'cpu'."
),
)
parser.add_argument(
Expand Down
16 changes: 8 additions & 8 deletions stable_learning_control/algos/tf2/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def __init__(
``1e-4``.
lr_c (float, optional): Learning rate used for the (soft) critic.
Defaults to ``1e-4``.
device (str, optional): The device the networks are placed on (``cpu``
or ``gpu``). Defaults to ``cpu``.
device (str, optional): The device the networks are placed on (options:
``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
""" # noqa: E501, D301
self._device = set_device(
device
Expand Down Expand Up @@ -628,8 +628,8 @@ def target_entropy(self, set_val):

@property
def device(self):
"""The device the networks are placed on (``cpu`` or ``gpu``). Defaults to
``cpu``.
"""The device the networks are placed on (options: ``cpu``, ``gpu``, ``gpu:0``,
``gpu:1``, etc.).
"""
return self._device

Expand Down Expand Up @@ -804,8 +804,8 @@ def sac(
replay_size (int, optional): Maximum length of replay buffer. Defaults to
``1e6``.
seed (int): Seed for random number generators. Defaults to ``None``.
device (str, optional): The device the networks are placed on (``cpu``
or ``gpu``). Defaults to ``cpu``.
device (str, optional): The device the networks are placed on (options: ``cpu``,
``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
logger_kwargs (dict, optional): Keyword args for EpochLogger.
save_freq (int, optional): How often (in terms of gap between epochs) to save
the current policy and value function.
Expand Down Expand Up @@ -1395,8 +1395,8 @@ def sac(
type=str,
default="cpu",
help=(
"The device the networks are placed on: 'cpu' or 'gpu' (options: "
"default: cpu)"
"The device the networks are placed on. Options: 'cpu', 'gpu', 'gpu:0', "
"'gpu:1', etc. Defaults to 'cpu'."
),
)
parser.add_argument(
Expand Down

0 comments on commit 73c1374

Please sign in to comment.