From 6ab5001c588a97ec4379aed3c224d785964448c1 Mon Sep 17 00:00:00 2001
From: Rick Staa
Date: Mon, 19 Feb 2024 21:52:16 +0100
Subject: [PATCH] fix(tf2): correct off-by-one error in learning rate decay calculation (#415)

This commit resolves an issue that led to incorrect learning rate decay. The
root cause was an off-by-one error in the step count, which skewed the decay
calculation. With this fix, the learning rate now decays accurately according
to the specified schedule.
---
 stable_learning_control/algos/tf2/lac/lac.py | 22 +++++++---------------
 stable_learning_control/algos/tf2/sac/sac.py | 20 ++++++--------------
 2 files changed, 13 insertions(+), 29 deletions(-)

diff --git a/stable_learning_control/algos/tf2/lac/lac.py b/stable_learning_control/algos/tf2/lac/lac.py
index 02fadbf9..ca1f9857 100644
--- a/stable_learning_control/algos/tf2/lac/lac.py
+++ b/stable_learning_control/algos/tf2/lac/lac.py
@@ -1101,7 +1101,9 @@ def lac(
     # Calculate the number of learning rate scheduler steps.
     if lr_decay_ref == "step":
         # NOTE: Decay applied at policy update to improve performance.
-        lr_decay_steps = (total_steps - update_after) / update_every
+        lr_decay_steps = (
+            total_steps - update_after
+        ) / update_every + 1  # NOTE: +1 since we start at the initial learning rate.
     else:
         lr_decay_steps = epochs
 
@@ -1110,16 +1112,6 @@
     lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
     lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)
 
-    # Create step based learning rate schedulers.
-    # NOTE: Used to estimate the learning rate at each step.
-    if lr_decay_ref == "step":
-        lr_a_step_scheduler = get_lr_scheduler(
-            lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
-        )
-        lr_c_step_scheduler = get_lr_scheduler(
-            lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
-        )
-
     # Restore policy if supplied.
     if start_policy is not None:
         logger.log(f"Restoring model from '{start_policy}'.", type="info")
@@ -1303,10 +1295,10 @@
             # Retrieve current learning rates.
             if lr_decay_ref == "step":
                 progress = max((t + 1) - update_after, 0) / update_every
-                lr_actor = lr_a_step_scheduler(progress)
-                lr_critic = lr_c_step_scheduler(progress)
-                lr_alpha = lr_a_step_scheduler(progress)
-                lr_labda = lr_a_step_scheduler(progress)
+                lr_actor = lr_a_scheduler(progress)
+                lr_critic = lr_c_scheduler(progress)
+                lr_alpha = lr_a_scheduler(progress)
+                lr_labda = lr_a_scheduler(progress)
             else:
                 lr_actor = policy._pi_optimizer.lr.numpy()
                 lr_critic = policy._c_optimizer.lr.numpy()
diff --git a/stable_learning_control/algos/tf2/sac/sac.py b/stable_learning_control/algos/tf2/sac/sac.py
index c49a7023..c9a0d71b 100644
--- a/stable_learning_control/algos/tf2/sac/sac.py
+++ b/stable_learning_control/algos/tf2/sac/sac.py
@@ -968,7 +968,9 @@ def sac(
     # Calculate the number of learning rate scheduler steps.
     if lr_decay_ref == "step":
         # NOTE: Decay applied at policy update to improve performance.
-        lr_decay_steps = (total_steps - update_after) / update_every
+        lr_decay_steps = (
+            total_steps - update_after
+        ) / update_every + 1  # NOTE: +1 since we start at the initial learning rate.
     else:
         lr_decay_steps = epochs
 
@@ -977,16 +979,6 @@
     lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
     lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)
 
-    # Create step based learning rate schedulers.
-    # NOTE: Used to estimate the learning rate at each step.
-    if lr_decay_ref == "step":
-        lr_a_step_scheduler = get_lr_scheduler(
-            lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
-        )
-        lr_c_step_scheduler = get_lr_scheduler(
-            lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
-        )
-
     # Restore policy if supplied.
     if start_policy is not None:
         logger.log(f"Restoring model from '{start_policy}'.", type="info")
@@ -1138,9 +1130,9 @@
             # Retrieve current learning rates.
             if lr_decay_ref == "step":
                 progress = max((t + 1) - update_after, 0) / update_every
-                lr_actor = lr_a_step_scheduler(progress)
-                lr_critic = lr_c_step_scheduler(progress)
-                lr_alpha = lr_a_step_scheduler(progress)
+                lr_actor = lr_a_scheduler(progress)
+                lr_critic = lr_c_scheduler(progress)
+                lr_alpha = lr_a_scheduler(progress)
             else:
                 lr_actor = policy._pi_optimizer.lr.numpy()
                 lr_critic = policy._c_optimizer.lr.numpy()
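
As a counting sketch of the off-by-one being fixed above: the schedule has to
cover the initial, undecayed learning rate in addition to one step per update
interval, which is where the "+ 1" comes from. The sketch below only reproduces
the step arithmetic from the diff; the parameter values (total_steps,
update_after, update_every) are hypothetical and the actual decay curve
returned by get_lr_scheduler is not modeled.

    # Counting sketch only: the parameter values are hypothetical and the exact
    # decay curve returned by get_lr_scheduler is not modeled here.
    total_steps = 1_000   # total environment interaction steps (assumed)
    update_after = 100    # steps collected before policy updates start (assumed)
    update_every = 50     # environment steps between policy updates (assumed)

    # Before this patch the schedulers were built over the update intervals only,
    # while a second set of "*_step_scheduler" objects used one extra step for the
    # per-step learning-rate lookup, so the two schedules disagreed.
    old_lr_decay_steps = (total_steps - update_after) / update_every       # 18.0
    old_step_scheduler_steps = old_lr_decay_steps + 1                      # 19.0

    # After this patch the extra point (the initial, undecayed learning rate) is
    # part of lr_decay_steps itself, so one set of schedulers is used everywhere.
    lr_decay_steps = (total_steps - update_after) / update_every + 1       # 19.0

    # Per-step lookup as in lac.py and sac.py: progress stays 0 until update_after
    # and then grows by 1 per update interval, reaching 18.0 at the final step.
    for t in (update_after - 1, update_after + update_every - 1, total_steps - 1):
        progress = max((t + 1) - update_after, 0) / update_every
        print(f"t+1={t + 1:4d}  progress={progress:5.1f}")
    # t+1= 100  progress=  0.0
    # t+1= 150  progress=  1.0
    # t+1=1000  progress= 18.0

Because the extra step is now folded into lr_decay_steps, the separate
*_step_scheduler objects become redundant, which is why both files reuse
lr_a_scheduler and lr_c_scheduler for the per-step lookup.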