fix(tf2): correct step-based learning rate decay (#407)
This commit addresses an issue with the step-based learning rate decay
mechanism used when `lr_decay_ref` is set to 'step'. Previously, the learning
rate decayed too rapidly because of a bug in the decay logic. This fix
ensures that the learning rate decays at the pace specified by the
step-based decay configuration.
rickstaa authored Feb 12, 2024
1 parent 73c1374 commit 642a193
Showing 3 changed files with 141 additions and 47 deletions.
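
To make the fix concrete before the diffs: the change keys the decay schedule to the number of policy updates rather than the raw environment step. The sketch below illustrates that bookkeeping under a linear decay; `linear_scheduler` is a hypothetical stand-in for the repository's `get_lr_scheduler` helper and every value is arbitrary, so treat it as an illustration rather than the project's code.

```python
# Minimal sketch of the corrected step-based decay bookkeeping (arbitrary values).
total_steps = 100_000   # total environment interactions
update_after = 1_000    # first environment step at which policy updates start
update_every = 50       # environment steps between policy updates
lr_a, lr_a_final = 1e-3, 1e-5


def linear_scheduler(lr_start, lr_final, steps):
    """Hypothetical stand-in for get_lr_scheduler: linear decay over `steps` calls."""
    return lambda step: lr_start - (lr_start - lr_final) * min(step / steps, 1.0)


# Decay is applied once per policy update, so the scheduler horizon is the
# number of updates, not the number of environment steps.
lr_decay_steps = (total_steps - update_after) / update_every
lr_a_scheduler = linear_scheduler(lr_a, lr_a_final, lr_decay_steps)

for t in range(total_steps):
    if (t + 1) >= update_after and ((t + 1) - update_after) % update_every == 0:
        n_update = ((t + 1) - update_after) // update_every
        # Before the fix the scheduler was advanced with the environment step
        # (t + 1); advancing it with the update count keeps the decay on pace.
        lr_a_now = max(lr_a_scheduler(n_update + 1), lr_a_final)

print(f"Final actor learning rate: {lr_a_now:.2e}")  # ≈ lr_a_final
```
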
4 changes: 2 additions & 2 deletions stable_learning_control/algos/tf2/common/get_lr_scheduler.py
@@ -11,8 +11,8 @@ def get_lr_scheduler(decaying_lr_type, lr_start, lr_final, steps):
"""Creates a learning rate scheduler.
Args:
decaying_lr_type (str): The learning rate decay type that is used (
options are: ``linear`` and ``exponential`` and ``constant``).
decaying_lr_type (str): The learning rate decay type that is used (options are:
``linear`` and ``exponential`` and ``constant``).
lr_start (float): Initial learning rate.
lr_final (float): Final learning rate.
steps (int, optional): Number of steps/epochs used in the training. This
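
Judging from the docstring above and the call sites in the two files below, the helper appears to be used roughly as follows. This is a sketch only: the import path and keyword usage are inferred from the shown signature, not verified against the implementation.

```python
from stable_learning_control.algos.tf2.common.get_lr_scheduler import get_lr_scheduler

# Build an exponential decay schedule from 1e-3 down to 1e-5 over 100 scheduler steps.
lr_a_scheduler = get_lr_scheduler("exponential", lr_start=1e-3, lr_final=1e-5, steps=100)

# The training loops query the returned scheduler with the current decay step and
# clamp the result so it never drops below the final learning rate.
lr_a_now = max(lr_a_scheduler(10), 1e-5)
```
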
94 changes: 71 additions & 23 deletions stable_learning_control/algos/tf2/lac/lac.py
Expand Up @@ -952,6 +952,7 @@ def lac(
- replay_buffer (union[:class:`~stable_learning_control.algos.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.common.buffers.FiniteHorizonReplayBuffer`]):
The replay buffer used during training.
""" # noqa: E501, D301
update_after = max(1, update_after) # You can not update before the first step.
validate_args(**locals())

# Retrieve hyperparameters while filtering out the logger_kwargs.
@@ -1083,11 +1084,41 @@ def lac(
device,
)

# Parse learning rate decay type.
valid_lr_decay_options = ["step", "epoch"]
lr_decay_ref = lr_decay_ref.lower()
if lr_decay_ref not in valid_lr_decay_options:
options = [f"'{option}'" for option in valid_lr_decay_options]
logger.log(
f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
"which is not a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay reference "
"variable has been set to 'epoch'.",
type="warning",
)
lr_decay_ref = "epoch"

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
# NOTE: Decay applied at policy update to improve performance.
lr_decay_steps = (total_steps - update_after) / update_every
else:
lr_decay_steps = epochs

# Create learning rate schedulers.
# NOTE: Alpha and labda currently use the same scheduler as the actor.
lr_decay_ref_var = total_steps if lr_decay_ref.lower() == "steps" else epochs
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_ref_var)
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)

# Create step based learning rate schedulers.
# NOTE: Used to estimate the learning rate at each step.
if lr_decay_ref == "step":
lr_a_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
)
lr_c_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
)

# Restore policy if supplied.
if start_policy is not None:
@@ -1165,6 +1196,7 @@ def lac(
"Entropy",
]
if use_tensorboard:
# NOTE: TensorBoard counts from 0.
logger.log_to_tb(
"Lr_a",
policy._pi_optimizer.lr.numpy(),
@@ -1227,6 +1259,7 @@ def lac(
# NOTE: Improved compared to Han et al. 2020. Previously, updates were based on
# memory size, which only changed at terminal states.
if (t + 1) >= update_after and ((t + 1) - update_after) % update_every == 0:
n_update = ((t + 1) - update_after) // update_every
for _ in range(steps_per_update):
batch = replay_buffer.sample_batch(batch_size)
update_diagnostics = policy.update(data=batch)
@@ -1235,18 +1268,20 @@
# Step based learning rate decay.
if lr_decay_ref.lower() == "step":
lr_a_now = max(
lr_a_scheduler(t + 1), lr_a_final
lr_a_scheduler(n_update + 1), lr_a_final
) # Make sure lr is bounded above final lr.
lr_c_now = max(
lr_c_scheduler(t + 1), lr_c_final
lr_c_scheduler(n_update + 1), lr_c_final
) # Make sure lr is bounded above final lr.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
)

# SGD batch tb logging.
if use_tensorboard and not tb_low_log_freq:
logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
logger.log_to_tb(
keys=diag_tb_log_list, global_step=t
) # NOTE: TensorBoard counts from 0.

# End of epoch handling (Save model, test performance and log data)
if (t + 1) % steps_per_epoch == 0:
@@ -1265,21 +1300,22 @@
extend=True,
)

# Epoch based learning rate decay.
if lr_decay_ref.lower() != "step":
lr_a_now = max(
lr_a_scheduler(epoch), lr_a_final
) # Make sure lr is bounded above final.
lr_c_now = max(
lr_c_scheduler(epoch), lr_c_final
) # Make sure lr is bounded above final.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
)
# Retrieve current learning rates.
if lr_decay_ref == "step":
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_step_scheduler(progress)
lr_critic = lr_c_step_scheduler(progress)
lr_alpha = lr_a_step_scheduler(progress)
lr_labda = lr_a_step_scheduler(progress)
else:
lr_actor = policy._pi_optimizer.lr.numpy()
lr_critic = policy._c_optimizer.lr.numpy()
lr_alpha = policy._log_alpha_optimizer.lr.numpy()
lr_labda = policy._log_labda_optimizer.lr.numpy()

# Log info about epoch.
logger.log_tabular("Epoch", epoch)
logger.log_tabular("TotalEnvInteracts", t)
logger.log_tabular("TotalEnvInteracts", t + 1)
logger.log_tabular(
"EpRet",
with_min_and_max=True,
@@ -1299,25 +1335,25 @@
)
logger.log_tabular(
"Lr_a",
policy._pi_optimizer.lr.numpy(),
lr_actor,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_c",
policy._c_optimizer.lr.numpy(),
lr_critic,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_alpha",
policy._log_alpha_optimizer.lr.numpy(),
lr_alpha,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_labda",
policy._log_labda_optimizer.lr.numpy(),
lr_labda,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
@@ -1360,7 +1396,19 @@ def lac(
tb_write=(use_tensorboard and tb_low_log_freq),
)
logger.log_tabular("Time", time.time() - start_time)
logger.dump_tabular(global_step=t)
logger.dump_tabular(global_step=t) # NOTE: TensorBoard counts from 0.

# Epoch based learning rate decay.
if lr_decay_ref.lower() != "step":
lr_a_now = max(
lr_a_scheduler(epoch), lr_a_final
) # Make sure lr is bounded above final.
lr_c_now = max(
lr_c_scheduler(epoch), lr_c_final
) # Make sure lr is bounded above final.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
)

# Export model to 'SavedModel'
if export:
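
A related detail in the lac.py hunks above: with step-based decay, the epoch log now reports an estimate of the learning rate currently in use, computed from the separate scheduler built with `lr_decay_steps + 1`, while the epoch-based decay itself runs only after the tabular dump. A rough sketch of that reporting step, again with a hypothetical stand-in scheduler and arbitrary values:

```python
# Sketch of the epoch-end learning-rate estimate under step-based decay.
update_after, update_every, total_steps = 1_000, 50, 100_000
lr_a, lr_a_final = 1e-3, 1e-5


def linear_scheduler(lr_start, lr_final, steps):
    """Hypothetical stand-in for get_lr_scheduler: linear decay over `steps` calls."""
    return lambda step: lr_start - (lr_start - lr_final) * min(step / steps, 1.0)


lr_decay_steps = (total_steps - update_after) / update_every
# Separate scheduler (one extra step) used in the diff to estimate the current rate.
lr_a_step_scheduler = linear_scheduler(lr_a, lr_a_final, lr_decay_steps + 1)

t = 4_999  # last environment step of some epoch (0-indexed)
# Fractional number of updates performed so far; clamped at zero before the
# first update so the reported rate stays at the initial value.
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_step_scheduler(progress)  # value logged as "Lr_a"
print(f"Reported actor learning rate at epoch end: {lr_actor:.2e}")
```
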
90 changes: 68 additions & 22 deletions stable_learning_control/algos/tf2/sac/sac.py
@@ -821,6 +821,7 @@ def sac(
- replay_buffer (union[:class:`~stable_learning_control.algos.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.common.buffers.FiniteHorizonReplayBuffer`]):
The replay buffer used during training.
""" # noqa: E501, D301
update_after = max(1, update_after) # You can not update before the first step.
validate_args(**locals())

# Retrieve hyperparameters while filtering out the logger_kwargs.
@@ -950,11 +951,41 @@ def sac(
device,
)

# Parse learning rate decay type.
valid_lr_decay_options = ["step", "epoch"]
lr_decay_ref = lr_decay_ref.lower()
if lr_decay_ref not in valid_lr_decay_options:
options = [f"'{option}'" for option in valid_lr_decay_options]
logger.log(
f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
"which is not a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay reference "
"variable has been set to 'epoch'.",
type="warning",
)
lr_decay_ref = "epoch"

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
# NOTE: Decay applied at policy update to improve performance.
lr_decay_steps = (total_steps - update_after) / update_every
else:
lr_decay_steps = epochs

# Create learning rate schedulers.
# NOTE: Alpha currently uses the same scheduler as the actor.
lr_decay_ref_var = total_steps if lr_decay_ref.lower() == "steps" else epochs
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_ref_var)
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)

# Create step based learning rate schedulers.
# NOTE: Used to estimate the learning rate at each step.
if lr_decay_ref == "step":
lr_a_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
)
lr_c_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
)

# Restore policy if supplied.
if start_policy is not None:
@@ -1010,6 +1041,7 @@ def sac(
# Setup diagnostics tb_write dict and store initial learning rates.
diag_tb_log_list = ["LossQ", "LossPi", "Alpha", "LossAlpha", "Entropy"]
if use_tensorboard:
# NOTE: TensorBoard counts from 0.
logger.log_to_tb(
"Lr_a",
policy._pi_optimizer.lr.numpy(),
@@ -1062,6 +1094,7 @@ def sac(
# NOTE: Improved compared to Han et al. 2020. Previously, updates were based on
# memory size, which only changed at terminal states.
if (t + 1) >= update_after and ((t + 1) - update_after) % update_every == 0:
n_update = ((t + 1) - update_after) // update_every
for _ in range(steps_per_update):
batch = replay_buffer.sample_batch(batch_size)
update_diagnostics = policy.update(data=batch)
@@ -1070,18 +1103,20 @@
# Step based learning rate decay.
if lr_decay_ref.lower() == "step":
lr_a_now = max(
lr_a_scheduler(t + 1), lr_a_final
lr_a_scheduler(n_update + 1), lr_a_final
) # Make sure lr is bounded above final lr.
lr_c_now = max(
lr_c_scheduler(t + 1), lr_c_final
lr_c_scheduler(n_update + 1), lr_c_final
) # Make sure lr is bounded above final lr.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
)

# SGD batch tb logging.
if use_tensorboard and not tb_low_log_freq:
logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
logger.log_to_tb(
keys=diag_tb_log_list, global_step=t
) # NOTE: TensorBoard counts from 0.

# End of epoch handling (Save model, test performance and log data)
if (t + 1) % steps_per_epoch == 0:
@@ -1100,21 +1135,20 @@
extend=True,
)

# Epoch based learning rate decay.
if lr_decay_ref.lower() != "step":
lr_a_now = max(
lr_a_scheduler(epoch), lr_a_final
) # Make sure lr is bounded above final.
lr_c_now = max(
lr_c_scheduler(epoch), lr_c_final
) # Make sure lr is bounded above final.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
)
# Retrieve current learning rates.
if lr_decay_ref == "step":
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_step_scheduler(progress)
lr_critic = lr_c_step_scheduler(progress)
lr_alpha = lr_a_step_scheduler(progress)
else:
lr_actor = policy._pi_optimizer.lr.numpy()
lr_critic = policy._c_optimizer.lr.numpy()
lr_alpha = policy._log_alpha_optimizer.lr.numpy()

# Log info about epoch.
logger.log_tabular("Epoch", epoch)
logger.log_tabular("TotalEnvInteracts", t)
logger.log_tabular("TotalEnvInteracts", t + 1)
logger.log_tabular(
"EpRet",
with_min_and_max=True,
@@ -1134,19 +1168,19 @@
)
logger.log_tabular(
"Lr_a",
policy._pi_optimizer.lr.numpy(),
lr_actor,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_c",
policy._c_optimizer.lr.numpy(),
lr_critic,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_alpha",
policy._log_alpha_optimizer.lr.numpy(),
lr_alpha,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
@@ -1180,7 +1214,19 @@ def sac(
tb_write=(use_tensorboard and tb_low_log_freq),
)
logger.log_tabular("Time", time.time() - start_time)
logger.dump_tabular(global_step=t)
logger.dump_tabular(global_step=t) # NOTE: TensorBoard counts from 0.

# Epoch based learning rate decay.
if lr_decay_ref.lower() != "step":
lr_a_now = max(
lr_a_scheduler(epoch), lr_a_final
) # Make sure lr is bounded above final.
lr_c_now = max(
lr_c_scheduler(epoch), lr_c_final
) # Make sure lr is bounded above final.
policy.set_learning_rates(
lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
)

# Export model to 'SavedModel'
if export:
