diff --git a/stable_learning_control/algos/tf2/lac/lac.py b/stable_learning_control/algos/tf2/lac/lac.py
index 02fadbf9..ca1f9857 100644
--- a/stable_learning_control/algos/tf2/lac/lac.py
+++ b/stable_learning_control/algos/tf2/lac/lac.py
@@ -1101,7 +1101,9 @@ def lac(
     # Calculate the number of learning rate scheduler steps.
     if lr_decay_ref == "step":
         # NOTE: Decay applied at policy update to improve performance.
-        lr_decay_steps = (total_steps - update_after) / update_every
+        lr_decay_steps = (
+            total_steps - update_after
+        ) / update_every + 1  # NOTE: +1 since we start at the initial learning rate.
     else:
         lr_decay_steps = epochs
 
@@ -1110,16 +1112,6 @@ def lac(
     lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
     lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)
 
-    # Create step based learning rate schedulers.
-    # NOTE: Used to estimate the learning rate at each step.
-    if lr_decay_ref == "step":
-        lr_a_step_scheduler = get_lr_scheduler(
-            lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
-        )
-        lr_c_step_scheduler = get_lr_scheduler(
-            lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
-        )
-
     # Restore policy if supplied.
     if start_policy is not None:
         logger.log(f"Restoring model from '{start_policy}'.", type="info")
@@ -1303,10 +1295,10 @@ def lac(
             # Retrieve current learning rates.
             if lr_decay_ref == "step":
                 progress = max((t + 1) - update_after, 0) / update_every
-                lr_actor = lr_a_step_scheduler(progress)
-                lr_critic = lr_c_step_scheduler(progress)
-                lr_alpha = lr_a_step_scheduler(progress)
-                lr_labda = lr_a_step_scheduler(progress)
+                lr_actor = lr_a_scheduler(progress)
+                lr_critic = lr_c_scheduler(progress)
+                lr_alpha = lr_a_scheduler(progress)
+                lr_labda = lr_a_scheduler(progress)
             else:
                 lr_actor = policy._pi_optimizer.lr.numpy()
                 lr_critic = policy._c_optimizer.lr.numpy()
diff --git a/stable_learning_control/algos/tf2/sac/sac.py b/stable_learning_control/algos/tf2/sac/sac.py
index c49a7023..c9a0d71b 100644
--- a/stable_learning_control/algos/tf2/sac/sac.py
+++ b/stable_learning_control/algos/tf2/sac/sac.py
@@ -968,7 +968,9 @@ def sac(
     # Calculate the number of learning rate scheduler steps.
     if lr_decay_ref == "step":
         # NOTE: Decay applied at policy update to improve performance.
-        lr_decay_steps = (total_steps - update_after) / update_every
+        lr_decay_steps = (
+            total_steps - update_after
+        ) / update_every + 1  # NOTE: +1 since we start at the initial learning rate.
     else:
         lr_decay_steps = epochs
 
@@ -977,16 +979,6 @@ def sac(
     lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
     lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)
 
-    # Create step based learning rate schedulers.
-    # NOTE: Used to estimate the learning rate at each step.
-    if lr_decay_ref == "step":
-        lr_a_step_scheduler = get_lr_scheduler(
-            lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
-        )
-        lr_c_step_scheduler = get_lr_scheduler(
-            lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
-        )
-
     # Restore policy if supplied.
     if start_policy is not None:
         logger.log(f"Restoring model from '{start_policy}'.", type="info")
@@ -1138,9 +1130,9 @@ def sac(
             # Retrieve current learning rates.
            if lr_decay_ref == "step":
                 progress = max((t + 1) - update_after, 0) / update_every
-                lr_actor = lr_a_step_scheduler(progress)
-                lr_critic = lr_c_step_scheduler(progress)
-                lr_alpha = lr_a_step_scheduler(progress)
+                lr_actor = lr_a_scheduler(progress)
+                lr_critic = lr_c_scheduler(progress)
+                lr_alpha = lr_a_scheduler(progress)
             else:
                 lr_actor = policy._pi_optimizer.lr.numpy()
                 lr_critic = policy._c_optimizer.lr.numpy()
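
For context, a minimal sketch of the step-based evaluation this patch converges on. This is not the project's get_lr_scheduler; the make_linear_lr_scheduler helper, the clamping behavior, and the hyperparameter values below are assumptions made only to illustrate why the +1 is folded into lr_decay_steps and why the same scheduler object can now serve the per-step estimate.

    # Minimal sketch (hypothetical helper, NOT the project's get_lr_scheduler):
    # a linear-decay scheduler evaluated at a possibly fractional step index,
    # mirroring how the diff reuses lr_a_scheduler / lr_c_scheduler per step.


    def make_linear_lr_scheduler(lr_init, lr_final, decay_steps):
        """Return a callable mapping a step index to a learning rate."""

        def scheduler(step):
            # Clamp the step into [0, decay_steps] and interpolate linearly.
            frac = min(max(float(step), 0.0), decay_steps) / decay_steps
            return lr_init + frac * (lr_final - lr_init)

        return scheduler


    # Hypothetical hyperparameters, chosen only for illustration.
    total_steps, update_after, update_every = 100_000, 1_000, 100
    lr_a, lr_a_final = 1e-4, 1e-10

    # +1 as in the diff's NOTE: the schedule starts at the initial learning
    # rate (progress == 0), so it spans one extra evaluation point.
    lr_decay_steps = (total_steps - update_after) / update_every + 1

    lr_a_scheduler = make_linear_lr_scheduler(lr_a, lr_a_final, lr_decay_steps)

    t = 4_999  # some environment step
    progress = max((t + 1) - update_after, 0) / update_every
    print(f"estimated actor lr at step {t + 1}: {lr_a_scheduler(progress):.3e}")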