fix(tf2): correct step-based learning rate decay #407

Merged: 1 commit, Feb 12, 2024

4 changes: 2 additions & 2 deletions stable_learning_control/algos/tf2/common/get_lr_scheduler.py
@@ -11,8 +11,8 @@ def get_lr_scheduler(decaying_lr_type, lr_start, lr_final, steps):
"""Creates a learning rate scheduler.
Args:
decaying_lr_type (str): The learning rate decay type that is used (
options are: ``linear`` and ``exponential`` and ``constant``).
decaying_lr_type (str): The learning rate decay type that is used (options are:
``linear`` and ``exponential`` and ``constant``).
lr_start (float): Initial learning rate.
lr_final (float): Final learning rate.
steps (int, optional): Number of steps/epochs used in the training. This
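For reference, the schedulers returned by get_lr_scheduler are called as plain functions of a step or epoch index in the diffs below (e.g. lr_a_scheduler(n_update + 1)). The snippet below is a minimal stand-in for the ``linear`` option, not the library implementation, to show how the ``steps`` argument determines when ``lr_final`` is reached:

```python
def linear_lr_scheduler(lr_start, lr_final, steps):
    """Minimal stand-in for get_lr_scheduler(..., decaying_lr_type="linear").

    Returns a callable that maps a scheduler step index onto a learning rate
    decaying linearly from `lr_start` to `lr_final` over `steps` steps.
    """

    def scheduler(step):
        frac = min(step / steps, 1.0)  # Clamp so the rate never drops below lr_final.
        return lr_start + frac * (lr_final - lr_start)

    return scheduler


# With steps=4 the final learning rate is reached at the 4th scheduler step.
lr_a_scheduler = linear_lr_scheduler(1e-4, 1e-10, steps=4)
print([lr_a_scheduler(i) for i in range(5)])  # 1e-4 ... 1e-10
```

Passing the wrong ``steps`` value stretches or compresses the whole decay curve, which is exactly what the changes below correct.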
94 changes: 71 additions & 23 deletions stable_learning_control/algos/tf2/lac/lac.py
@@ -952,6 +952,7 @@ def lac(
- replay_buffer (union[:class:`~stable_learning_control.algos.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.common.buffers.FiniteHorizonReplayBuffer`]):
The replay buffer used during training.
""" # noqa: E501, D301
update_after = max(1, update_after) # You can not update before the first step.
validate_args(**locals())

# Retrieve hyperparameters while filtering out the logger_kwargs.
@@ -1083,11 +1084,41 @@ def lac(
device,
)

# Parse learning rate decay type.
valid_lr_decay_options = ["step", "epoch"]
lr_decay_ref = lr_decay_ref.lower()
if lr_decay_ref not in valid_lr_decay_options:
options = [f"'{option}'" for option in valid_lr_decay_options]
logger.log(
f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
"which is not a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay reference "
"variable has been set to 'epoch'.",
type="warning",
)
lr_decay_ref = "epoch"

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
# NOTE: Decay applied at policy update to improve performance.
lr_decay_steps = (total_steps - update_after) / update_every
else:
lr_decay_steps = epochs

# Create learning rate schedulers.
# NOTE: Alpha and labda currently use the same scheduler as the actor.
lr_decay_ref_var = total_steps if lr_decay_ref.lower() == "steps" else epochs
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_ref_var)
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)

# Create step based learning rate schedulers.
# NOTE: Used to estimate the learning rate at each step.
if lr_decay_ref == "step":
lr_a_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
)
lr_c_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
)

# Restore policy if supplied.
if start_policy is not None:
@@ -1165,6 +1196,7 @@ def lac(
"Entropy",
]
if use_tensorboard:
# NOTE: TensorBoard counts from 0.
logger.log_to_tb(
"Lr_a",
policy._pi_optimizer.lr.numpy(),
@@ -1227,6 +1259,7 @@ def lac(
# NOTE: Improved compared to Han et al. 2020. Previously, updates were based on
# memory size, which only changed at terminal states.
if (t + 1) >= update_after and ((t + 1) - update_after) % update_every == 0:
n_update = ((t + 1) - update_after) // update_every
for _ in range(steps_per_update):
batch = replay_buffer.sample_batch(batch_size)
update_diagnostics = policy.update(data=batch)
@@ -1235,18 +1268,20 @@ def lac(
# Step based learning rate decay.
if lr_decay_ref.lower() == "step":
lr_a_now = max(
lr_a_scheduler(t + 1), lr_a_final
lr_a_scheduler(n_update + 1), lr_a_final
) # Make sure lr is bounded above final lr.
lr_c_now = max(
lr_c_scheduler(t + 1), lr_c_final
lr_c_scheduler(n_update + 1), lr_c_final
) # Make sure lr is bounded above final lr.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
)

# SGD batch tb logging.
if use_tensorboard and not tb_low_log_freq:
logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
logger.log_to_tb(
keys=diag_tb_log_list, global_step=t
) # NOTE: TensorBoard counts from 0.

# End of epoch handling (Save model, test performance and log data)
if (t + 1) % steps_per_epoch == 0:
@@ -1265,21 +1300,22 @@ def lac(
extend=True,
)

# Epoch based learning rate decay.
if lr_decay_ref.lower() != "step":
lr_a_now = max(
lr_a_scheduler(epoch), lr_a_final
) # Make sure lr is bounded above final.
lr_c_now = max(
lr_c_scheduler(epoch), lr_c_final
) # Make sure lr is bounded above final.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
)
# Retrieve current learning rates.
if lr_decay_ref == "step":
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_step_scheduler(progress)
lr_critic = lr_c_step_scheduler(progress)
lr_alpha = lr_a_step_scheduler(progress)
lr_labda = lr_a_step_scheduler(progress)
else:
lr_actor = policy._pi_optimizer.lr.numpy()
lr_critic = policy._c_optimizer.lr.numpy()
lr_alpha = policy._log_alpha_optimizer.lr.numpy()
lr_labda = policy._log_labda_optimizer.lr.numpy()

# Log info about epoch.
logger.log_tabular("Epoch", epoch)
logger.log_tabular("TotalEnvInteracts", t)
logger.log_tabular("TotalEnvInteracts", t + 1)
logger.log_tabular(
"EpRet",
with_min_and_max=True,
@@ -1299,25 +1335,25 @@ def lac(
)
logger.log_tabular(
"Lr_a",
policy._pi_optimizer.lr.numpy(),
lr_actor,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_c",
policy._c_optimizer.lr.numpy(),
lr_critic,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_alpha",
policy._log_alpha_optimizer.lr.numpy(),
lr_alpha,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_labda",
policy._log_labda_optimizer.lr.numpy(),
lr_labda,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
@@ -1360,7 +1396,19 @@ def lac(
tb_write=(use_tensorboard and tb_low_log_freq),
)
logger.log_tabular("Time", time.time() - start_time)
logger.dump_tabular(global_step=t)
logger.dump_tabular(global_step=t) # NOTE: TensorBoard counts from 0.

# Epoch based learning rate decay.
if lr_decay_ref.lower() != "step":
lr_a_now = max(
lr_a_scheduler(epoch), lr_a_final
) # Make sure lr is bounded above final.
lr_c_now = max(
lr_c_scheduler(epoch), lr_c_final
) # Make sure lr is bounded above final.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
)

# Export model to 'SavedModel'
if export:
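The removed lines show why the step-based branch was broken: the old code compared lr_decay_ref against "steps" when sizing the scheduler but against "step" when applying it, so the schedule length and the stepping variable never matched. With "step", the scheduler spanned only ``epochs`` steps yet was advanced with ``t + 1``, so the learning rate hit its final value almost immediately. The new code counts policy updates instead. A quick check of the corrected arithmetic with hypothetical numbers (not taken from the PR), using a linear stand-in scheduler:

```python
# Hypothetical configuration, for illustration only.
total_steps, update_after, update_every = 30000, 1000, 100
lr_a, lr_a_final = 1e-4, 1e-10

# 290.0 scheduler steps between the first and the last policy update.
lr_decay_steps = (total_steps - update_after) / update_every


def lr_a_scheduler(step):
    """Linear stand-in for get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)."""
    return lr_a + min(step / lr_decay_steps, 1.0) * (lr_a_final - lr_a)


lr_a_now = lr_a
for t in range(total_steps):
    if (t + 1) >= update_after and ((t + 1) - update_after) % update_every == 0:
        # Count policy updates rather than environment steps (the fix in this PR).
        n_update = ((t + 1) - update_after) // update_every
        lr_a_now = max(lr_a_scheduler(n_update + 1), lr_a_final)

# The schedule now spans the policy updates, so the final learning rate is
# reached at the end of training instead of almost immediately.
assert abs(lr_a_now - lr_a_final) < 1e-12
```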
90 changes: 68 additions & 22 deletions stable_learning_control/algos/tf2/sac/sac.py
@@ -821,6 +821,7 @@ def sac(
- replay_buffer (union[:class:`~stable_learning_control.algos.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.common.buffers.FiniteHorizonReplayBuffer`]):
The replay buffer used during training.
""" # noqa: E501, D301
update_after = max(1, update_after) # You can not update before the first step.
validate_args(**locals())

# Retrieve hyperparameters while filtering out the logger_kwargs.
@@ -950,11 +951,41 @@ def sac(
device,
)

# Parse learning rate decay type.
valid_lr_decay_options = ["step", "epoch"]
lr_decay_ref = lr_decay_ref.lower()
if lr_decay_ref not in valid_lr_decay_options:
options = [f"'{option}'" for option in valid_lr_decay_options]
logger.log(
f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
"which is not a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay reference "
"variable has been set to 'epoch'.",
type="warning",
)
lr_decay_ref = "epoch"

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
# NOTE: Decay applied at policy update to improve performance.
lr_decay_steps = (total_steps - update_after) / update_every
else:
lr_decay_steps = epochs

# Create learning rate schedulers.
# NOTE: Alpha currently uses the same scheduler as the actor.
lr_decay_ref_var = total_steps if lr_decay_ref.lower() == "steps" else epochs
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_ref_var)
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)

# Create step based learning rate schedulers.
# NOTE: Used to estimate the learning rate at each step.
if lr_decay_ref == "step":
lr_a_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
)
lr_c_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
)

# Restore policy if supplied.
if start_policy is not None:
@@ -1010,6 +1041,7 @@ def sac(
# Setup diagnostics tb_write dict and store initial learning rates.
diag_tb_log_list = ["LossQ", "LossPi", "Alpha", "LossAlpha", "Entropy"]
if use_tensorboard:
# NOTE: TensorBoard counts from 0.
logger.log_to_tb(
"Lr_a",
policy._pi_optimizer.lr.numpy(),
@@ -1062,6 +1094,7 @@ def sac(
# NOTE: Improved compared to Han et al. 2020. Previously, updates were based on
# memory size, which only changed at terminal states.
if (t + 1) >= update_after and ((t + 1) - update_after) % update_every == 0:
n_update = ((t + 1) - update_after) // update_every
for _ in range(steps_per_update):
batch = replay_buffer.sample_batch(batch_size)
update_diagnostics = policy.update(data=batch)
@@ -1070,18 +1103,20 @@ def sac(
# Step based learning rate decay.
if lr_decay_ref.lower() == "step":
lr_a_now = max(
lr_a_scheduler(t + 1), lr_a_final
lr_a_scheduler(n_update + 1), lr_a_final
) # Make sure lr is bounded above final lr.
lr_c_now = max(
lr_c_scheduler(t + 1), lr_c_final
lr_c_scheduler(n_update + 1), lr_c_final
) # Make sure lr is bounded above final lr.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
)

# SGD batch tb logging.
if use_tensorboard and not tb_low_log_freq:
logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
logger.log_to_tb(
keys=diag_tb_log_list, global_step=t
) # NOTE: TensorBoard counts from 0.

# End of epoch handling (Save model, test performance and log data)
if (t + 1) % steps_per_epoch == 0:
@@ -1100,21 +1135,20 @@ def sac(
extend=True,
)

# Epoch based learning rate decay.
if lr_decay_ref.lower() != "step":
lr_a_now = max(
lr_a_scheduler(epoch), lr_a_final
) # Make sure lr is bounded above final.
lr_c_now = max(
lr_c_scheduler(epoch), lr_c_final
) # Make sure lr is bounded above final.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
)
# Retrieve current learning rates.
if lr_decay_ref == "step":
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_step_scheduler(progress)
lr_critic = lr_c_step_scheduler(progress)
lr_alpha = lr_a_step_scheduler(progress)
else:
lr_actor = policy._pi_optimizer.lr.numpy()
lr_critic = policy._c_optimizer.lr.numpy()
lr_alpha = policy._log_alpha_optimizer.lr.numpy()

# Log info about epoch.
logger.log_tabular("Epoch", epoch)
logger.log_tabular("TotalEnvInteracts", t)
logger.log_tabular("TotalEnvInteracts", t + 1)
logger.log_tabular(
"EpRet",
with_min_and_max=True,
@@ -1134,19 +1168,19 @@ def sac(
)
logger.log_tabular(
"Lr_a",
policy._pi_optimizer.lr.numpy(),
lr_actor,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_c",
policy._c_optimizer.lr.numpy(),
lr_critic,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_alpha",
policy._log_alpha_optimizer.lr.numpy(),
lr_alpha,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
@@ -1180,7 +1214,19 @@ def sac(
tb_write=(use_tensorboard and tb_low_log_freq),
)
logger.log_tabular("Time", time.time() - start_time)
logger.dump_tabular(global_step=t)
logger.dump_tabular(global_step=t) # NOTE: TensorBoard counts from 0.

# Epoch based learning rate decay.
if lr_decay_ref.lower() != "step":
lr_a_now = max(
lr_a_scheduler(epoch), lr_a_final
) # Make sure lr is bounded above final.
lr_c_now = max(
lr_c_scheduler(epoch), lr_c_final
) # Make sure lr is bounded above final.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
)

# Export model to 'SavedModel'
if export:
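Both files also move the epoch-based decay to after logger.dump_tabular, so the Lr_* columns report the learning rate that was in effect during the epoch rather than the freshly decayed value for the next one. A schematic of the new ordering with hypothetical values:

```python
# Hypothetical epoch-based schedule, for illustration only.
epochs, lr_a, lr_a_final = 10, 1e-4, 1e-6


def lr_a_scheduler(epoch):
    """Linear stand-in for the epoch-based scheduler."""
    return lr_a + min(epoch / epochs, 1.0) * (lr_a_final - lr_a)


lr_a_now = lr_a
for epoch in range(1, epochs + 1):
    # Log first: this is the rate the optimizers actually used during the epoch.
    print(f"Epoch {epoch}: Lr_a = {lr_a_now:.2e}")
    # Decay afterwards, ready for the next epoch (previously this ran before logging).
    lr_a_now = max(lr_a_scheduler(epoch), lr_a_final)
```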