feat(pytorch): add alpha/lambda learning rate customization (#412)
This commit enhances user control over the training process by allowing
direct customization of the alpha/lambda learning rates and their decay
rates. Users can now fine-tune these parameters to better suit their
specific training requirements.
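
As a usage sketch (not part of the commit), the new keyword arguments exposed by this change could be passed to the lac training function roughly as follows. The lr_* names come from the diff below; the import path, the env_fn keyword, and the Pendulum-v1 environment are assumptions for illustration only:

import gymnasium as gym

from stable_learning_control.algos.pytorch.lac.lac import lac

lac(
    env_fn=lambda: gym.make("Pendulum-v1"),  # assumed environment factory argument
    lr_a=1e-4,                 # actor learning rate (unchanged behaviour)
    lr_c=3e-4,                 # critic learning rate (unchanged behaviour)
    lr_alpha=1e-4,             # NEW: entropy temperature learning rate
    lr_labda=3e-4,             # NEW: Lyapunov Lagrange multiplier learning rate
    lr_alpha_final=1e-10,      # NEW: final alpha learning rate
    lr_labda_final=1e-10,      # NEW: final labda learning rate
    lr_decay_type="linear",            # global decay type
    lr_alpha_decay_type="constant",    # NEW: per-parameter override
    lr_labda_decay_type="exponential", # NEW: per-parameter override
    lr_decay_ref="epoch",
)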
rickstaa authored Feb 19, 2024
1 parent 2b3693e commit 6feb749
Showing 2 changed files with 342 additions and 78 deletions.
241 changes: 196 additions & 45 deletions stable_learning_control/algos/pytorch/lac/lac.py
@@ -79,6 +79,10 @@
"AverageLossPi",
"AverageEntropy",
]
VALID_DECAY_TYPES = ["linear", "exponential", "constant"]
VALID_DECAY_REFERENCES = ["step", "epoch"]
DEFAULT_DECAY_TYPE = "linear"
DEFAULT_DECAY_REFERENCE = "epoch"


class LAC(nn.Module):
@@ -110,6 +114,8 @@ def __init__(
adaptive_temperature=True,
lr_a=1e-4,
lr_c=3e-4,
lr_alpha=1e-4,
lr_labda=3e-4,
device="cpu",
):
"""Initialise the LAC algorithm.
@@ -197,6 +203,10 @@ def __init__(
``1e-4``.
lr_c (float, optional): Learning rate used for the (lyapunov) critic.
Defaults to ``3e-4``.
lr_alpha (float, optional): Learning rate used for the entropy temperature.
Defaults to ``1e-4``.
lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange
multiplier. Defaults to ``3e-4``.
device (str, optional): The device the networks are placed on (options:
``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
@@ -258,8 +268,8 @@ def __init__(
self._alpha3 = alpha3
self._lr_a = lr_a
if self._adaptive_temperature:
self._lr_alpha = lr_a
self._lr_lag = lr_a
self._lr_alpha = lr_alpha
self._lr_lag = lr_labda
self._lr_c = lr_c
if not isinstance(target_entropy, (float, int)):
self._target_entropy = heuristic_target_entropy(env.action_space)
@@ -870,10 +880,18 @@ def lac(
adaptive_temperature=True,
lr_a=1e-4,
lr_c=3e-4,
lr_alpha=1e-4,
lr_labda=3e-4,
lr_a_final=1e-10,
lr_c_final=1e-10,
lr_decay_type="linear",
lr_decay_ref="epoch",
lr_alpha_final=1e-10,
lr_labda_final=1e-10,
lr_decay_type=DEFAULT_DECAY_TYPE,
lr_a_decay_type=DEFAULT_DECAY_TYPE,
lr_c_decay_type=DEFAULT_DECAY_TYPE,
lr_alpha_decay_type=DEFAULT_DECAY_TYPE,
lr_labda_decay_type=DEFAULT_DECAY_TYPE,
lr_decay_ref=DEFAULT_DECAY_REFERENCE,
batch_size=256,
replay_size=int(1e6),
horizon_length=0,
@@ -988,13 +1006,33 @@ def lac(
``1e-4``.
lr_c (float, optional): Learning rate used for the (lyapunov) critic.
Defaults to ``3e-4``.
lr_alpha (float, optional): Learning rate used for the entropy temperature.
Defaults to ``1e-4``.
lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange
multiplier. Defaults to ``3e-4``.
lr_a_final (float, optional): The final actor learning rate that is achieved
at the end of the training. Defaults to ``1e-10``.
lr_c_final (float, optional): The final critic learning rate that is achieved
at the end of the training. Defaults to ``1e-10``.
lr_decay_type (str, optional): The learning rate decay type that is used (
options are: ``linear`` and ``exponential`` and ``constant``). Defaults to
``linear``.
lr_alpha_final (float, optional): The final alpha learning rate that is
achieved at the end of the training. Defaults to ``1e-10``.
lr_labda_final (float, optional): The final labda learning rate that is
achieved at the end of the training. Defaults to ``1e-10``.
lr_decay_type (str, optional): The learning rate decay type that is used (options
are: ``linear`` and ``exponential`` and ``constant``). Defaults to
``linear``. Can be overridden by the specific learning rate decay types.
lr_a_decay_type (str, optional): The learning rate decay type that is used for
the actor learning rate (options are: ``linear`` and ``exponential`` and
``constant``). If not specified, the general learning rate decay type is used.
lr_c_decay_type (str, optional): The learning rate decay type that is used for
the critic learning rate (options are: ``linear`` and ``exponential`` and
``constant``). If not specified, the general learning rate decay type is used.
lr_alpha_decay_type (str, optional): The learning rate decay type that is used
for the alpha learning rate (options are: ``linear`` and ``exponential``
and ``constant``). If not specified, the general learning rate decay type is used.
lr_labda_decay_type (str, optional): The learning rate decay type that is used
for the labda learning rate (options are: ``linear`` and ``exponential``
and ``constant``). If not specified, the general learning rate decay type is used.
lr_decay_ref (str, optional): The reference variable that is used for decaying
the learning rate (options: ``epoch`` and ``step``). Defaults to ``epoch``.
batch_size (int, optional): Minibatch size for SGD. Defaults to ``256``.
@@ -1134,20 +1172,22 @@ def lac(
# torch.backends.cudnn.benchmark = False # Disable for reproducibility.

policy = LAC(
env,
actor_critic,
ac_kwargs,
opt_type,
alpha,
alpha3,
labda,
gamma,
polyak,
target_entropy,
adaptive_temperature,
lr_a,
lr_c,
device,
env=env,
actor_critic=actor_critic,
ac_kwargs=ac_kwargs,
opt_type=opt_type,
alpha=alpha,
alpha3=alpha3,
labda=labda,
gamma=gamma,
polyak=polyak,
target_entropy=target_entropy,
adaptive_temperature=adaptive_temperature,
lr_a=lr_a,
lr_c=lr_c,
lr_alpha=lr_alpha,
lr_labda=lr_labda,
device=device,
)

# Restore policy if supplied.
@@ -1199,19 +1239,51 @@ def lac(
logger.log("Network structure:\n", type="info")
logger.log(policy.ac, end="\n\n")

# Parse learning rate decay type.
valid_lr_decay_options = ["step", "epoch"]
# Parse learning rate decay reference.
lr_decay_ref = lr_decay_ref.lower()
if lr_decay_ref not in valid_lr_decay_options:
options = [f"'{option}'" for option in valid_lr_decay_options]
if lr_decay_ref not in VALID_DECAY_REFERENCES:
options = [f"'{option}'" for option in VALID_DECAY_REFERENCES]
logger.log(
f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
"which is not a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay reference "
"variable has been set to 'epoch'.",
f"variable has been set to '{DEFAULT_DECAY_REFERENCE}'.",
type="warning",
)
lr_decay_ref = "epoch"
lr_decay_ref = DEFAULT_DECAY_REFERENCE

# Parse learning rate decay types.
lr_decay_type = lr_decay_type.lower()
if lr_decay_type not in VALID_DECAY_TYPES:
options = [f"'{option}'" for option in VALID_DECAY_TYPES]
logger.log(
f"The learning rate decay type was set to '{lr_decay_type}', which is not "
"a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay type has been set to "
f"'{DEFAULT_DECAY_TYPE}'.",
type="warning",
)
lr_decay_type = DEFAULT_DECAY_TYPE
decay_types = {
"actor": lr_a_decay_type.lower() if lr_a_decay_type else None,
"critic": lr_c_decay_type.lower() if lr_c_decay_type else None,
"alpha": lr_alpha_decay_type.lower() if lr_alpha_decay_type else None,
"labda": lr_labda_decay_type.lower() if lr_labda_decay_type else None,
}
for name, decay_type in decay_types.items():
if decay_type is None:
decay_types[name] = lr_decay_type
else:
if decay_type not in VALID_DECAY_TYPES:
logger.log(
f"Invalid {name} learning rate decay type: '{decay_type}'. Using "
f"global learning rate decay type: '{lr_decay_type}' instead.",
type="warning",
)
decay_types[name] = lr_decay_type
lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type, lr_labda_decay_type = (
decay_types.values()
)

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
@@ -1223,25 +1295,34 @@
lr_decay_steps = epochs

# Setup learning rate schedulers.
lr_a_init, lr_c_init, lr_alpha_init, lr_labda_init = lr_a, lr_c, lr_alpha, lr_labda
opt_schedulers = {
"pi": get_lr_scheduler(
policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps
policy._pi_optimizer,
lr_a_decay_type,
lr_a_init,
lr_a_final,
lr_decay_steps,
),
"c": get_lr_scheduler(
policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps
policy._c_optimizer,
lr_c_decay_type,
lr_c_init,
lr_c_final,
lr_decay_steps,
),
"alpha": get_lr_scheduler(
policy._log_alpha_optimizer,
lr_decay_type,
lr_a,
lr_a_final,
lr_alpha_decay_type,
lr_alpha_init,
lr_alpha_final,
lr_decay_steps,
),
"lambda": get_lr_scheduler(
"labda": get_lr_scheduler(
policy._log_labda_optimizer,
lr_decay_type,
lr_a,
lr_a_final,
lr_labda_decay_type,
lr_labda_init,
lr_labda_final,
lr_decay_steps,
),
}
@@ -1351,7 +1432,7 @@ def lac(
for scheduler in opt_schedulers.values():
scheduler.step()
policy.bound_lr(
lr_a_final, lr_c_final, lr_a_final, lr_a_final
lr_a_final, lr_c_final, lr_alpha_final, lr_labda_final
) # Make sure lr is bounded above the final lr.

# SGD batch tb logging.
@@ -1382,32 +1463,32 @@ def lac(
# NOTE: Estimate since 'step' decay is applied at policy update.
lr_actor = estimate_step_learning_rate(
opt_schedulers["pi"],
lr_a,
lr_a_init,
lr_a_final,
update_after,
total_steps,
t + 1,
)
lr_critic = estimate_step_learning_rate(
opt_schedulers["c"],
lr_c,
lr_c_init,
lr_c_final,
update_after,
total_steps,
t + 1,
)
lr_alpha = estimate_step_learning_rate(
opt_schedulers["alpha"],
lr_a,
lr_a_final,
lr_alpha_init,
lr_alpha_final,
update_after,
total_steps,
t + 1,
)
lr_labda = estimate_step_learning_rate(
opt_schedulers["lambda"],
lr_a,
lr_a_final,
opt_schedulers["labda"],
lr_labda_init,
lr_labda_final,
update_after,
total_steps,
t + 1,
@@ -1508,7 +1589,7 @@ def lac(
for scheduler in opt_schedulers.values():
scheduler.step()
policy.bound_lr(
lr_a_final, lr_c_final, lr_a_final, lr_a_final
lr_a_final, lr_c_final, lr_alpha_final, lr_labda_final
) # Make sure lr is bounded above the final lr.

# Export model to 'TorchScript'
@@ -1684,6 +1765,18 @@ def lac(
parser.add_argument(
"--lr_c", type=float, default=3e-4, help="critic learning rate (default: 1e-4)"
)
parser.add_argument(
"--lr_alpha",
type=float,
default=1e-4,
help="entropy temperature learning rate (default: 1e-4)",
)
parser.add_argument(
"--lr_labda",
type=float,
default=3e-4,
help="lyapunov Lagrance multiplier learning rate (default: 3e-4)",
)
parser.add_argument(
"--lr_a_final",
type=float,
@@ -1696,12 +1789,62 @@
default=1e-10,
help="the finalcritic learning rate (default: 1e-10)",
)
parser.add_argument(
"--lr_alpha_final",
type=float,
default=1e-10,
help="the final entropy temperature learning rate (default: 1e-10)",
)
parser.add_argument(
"--lr_labda_final",
type=float,
default=1e-10,
help="the final lyapunov Lagrance multiplier learning rate (default: 1e-10)",
)
parser.add_argument(
"--lr_decay_type",
type=str,
default="linear",
help="the learning rate decay type (default: linear)",
)
parser.add_argument(
"--lr_a_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the actor learning rate. "
"If not specified, the general learning rate decay type is used."
),
)
parser.add_argument(
"--lr_c_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the critic learning rate. "
"If not specified, the general learning rate decay type is used."
),
)
parser.add_argument(
"--lr_alpha_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the entropy temperature "
"learning rate. If not specified, the general learning rate decay type is "
"used."
),
)
parser.add_argument(
"--lr_labda_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the lyapunov Lagrance "
"multiplier learning rate. If not specified, the general learning rate "
"decay type is used."
),
)
parser.add_argument(
"--lr_decay_ref",
type=str,
@@ -1914,9 +2057,17 @@ def lac(
adaptive_temperature=args.adaptive_temperature,
lr_a=args.lr_a,
lr_c=args.lr_c,
lr_alpha=args.lr_alpha,
lr_labda=args.lr_labda,
lr_a_final=args.lr_a_final,
lr_c_final=args.lr_c_final,
lr_alpha_final=args.lr_alpha_final,
lr_labda_final=args.lr_labda_final,
lr_decay_type=args.lr_decay_type,
lr_a_decay_type=args.lr_a_decay_type,
lr_c_decay_type=args.lr_c_decay_type,
lr_alpha_decay_type=args.lr_alpha_decay_type,
lr_labda_decay_type=args.lr_labda_decay_type,
lr_decay_ref=args.lr_decay_ref,
batch_size=args.batch_size,
replay_size=args.replay_size,
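
The per-parameter decay types fall back to the global lr_decay_type when left unspecified or invalid, as implemented in the parsing block above. A minimal standalone sketch of that fallback behaviour (it mirrors, rather than imports, the logic in lac(); the values are illustrative):

VALID_DECAY_TYPES = ["linear", "exponential", "constant"]

lr_decay_type = "linear"             # global default
lr_alpha_decay_type = None           # not specified -> inherits the global type
lr_labda_decay_type = "exponential"  # explicit per-parameter override

decay_types = {
    "alpha": lr_alpha_decay_type.lower() if lr_alpha_decay_type else None,
    "labda": lr_labda_decay_type.lower() if lr_labda_decay_type else None,
}
for name, decay_type in decay_types.items():
    if decay_type is None or decay_type not in VALID_DECAY_TYPES:
        decay_types[name] = lr_decay_type  # fall back to the global decay type

print(decay_types)  # {'alpha': 'linear', 'labda': 'exponential'}

From the command line the same behaviour is reached through the new flags added above, such as --lr_alpha, --lr_labda, --lr_alpha_final, --lr_labda_final, and --lr_alpha_decay_type / --lr_labda_decay_type, which default to the global --lr_decay_type when omitted.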