From e669a23689be11a82a292c012807d48cfd46362f Mon Sep 17 00:00:00 2001
From: Rick Staa
Date: Sun, 11 Feb 2024 10:44:37 +0100
Subject: [PATCH] fix(pytorch): correct step-based learning rate decay

This commit addresses an issue with the step-based learning rate decay
mechanism when `lr_decay_ref` is set to 'step'. Previously, the learning rate
decayed far too rapidly: the schedulers were sized for `epochs` decay steps
(the reference value was compared against 'steps' instead of 'step') but were
stepped at every policy update. The schedulers are now sized by the number of
policy updates, so the learning rate decays at the pace the step-based decay
configuration prescribes.
---
 .../algos/pytorch/common/get_lr_scheduler.py |  68 ++++++--
 .../algos/pytorch/lac/lac.py                 | 147 +++++++++++++-----
 .../algos/pytorch/sac/sac.py                 | 121 ++++++++++----
 3 files changed, 256 insertions(+), 80 deletions(-)

diff --git a/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py b/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py
index 69019baa..dc49fb1e 100644
--- a/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py
+++ b/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py
@@ -22,11 +22,12 @@ def get_exponential_decay_rate(lr_start, lr_final, steps):
     return gamma
 
 
-def calc_linear_decay_rate(lr_init, lr_final, steps):
-    r"""Returns the linear decay factor (G) needed to achieve a given final learning
-    rate at a certain step. This decay factor can for example be used with a
-    :class:`torch.optim.lr_scheduler.LambdaLR` scheduler. Keep in mind that this
-    function assumes the following formula for the learning rate decay.
+def get_linear_decay_rate(lr_init, lr_final, steps):
+    r"""Returns a linear decay factor (G) that enables a learning rate to transition
+    from an initial value (`lr_init`) at step 0 to a final value (`lr_final`) at a
+    specified step (N). This decay factor is compatible with the
+    :class:`torch.optim.lr_scheduler.LambdaLR` scheduler. The decay factor is calculated
+    using the following formula:
 
     .. math::
         lr_{terminal} = lr_{init} * (1.0 - G \cdot step)
@@ -34,10 +35,11 @@ def calc_linear_decay_rate(lr_init, lr_final, steps):
     Args:
         lr_init (float): The initial learning rate.
         lr_final (float): The final learning rate you want to achieve.
-        steps (int): The step/epoch at which you want to achieve this learning rate.
+        steps (int): The number of steps/epochs over which the learning rate should
+            decay. This is equal to epochs - 1.
 
     Returns:
-        decimal.Decimal: Linear learning rate decay factor (G)
+        decimal.Decimal: Linear learning rate decay factor (G).
     """  # noqa: W605
     return -(
         ((Decimal(lr_final) / Decimal(lr_init)) - Decimal(1.0)) / Decimal(max(steps, 1))
@@ -53,7 +55,7 @@ def get_lr_scheduler(optimizer, decaying_lr_type, lr_start, lr_final, steps):
             (options are: ``linear`` and ``exponential`` and ``constant``).
         lr_start (float): Initial learning rate.
         lr_final (float): Final learning rate.
-        steps (int, optional): Number of steps/epochs used in the training. This
+        steps (int, optional): Number of steps/epochs used in the training. This
            includes the starting step.
 
     Returns:
@@ -83,7 +85,7 @@ def get_lr_scheduler(optimizer, decaying_lr_type, lr_start, lr_final, steps):
             return np.longdouble(
                 Decimal(1.0)
                 - (
-                    calc_linear_decay_rate(lr_start, lr_final, (steps - 1.0))
+                    get_linear_decay_rate(lr_start, lr_final, (steps - 1.0))
                     * Decimal(step)
                 )
            )
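
As a quick standalone check of the renamed `get_linear_decay_rate` helper (the
formula is re-implemented inline here rather than imported, and all numbers
are made up for illustration), the decay factor G drives a
`torch.optim.lr_scheduler.LambdaLR` from `lr_init` down to `lr_final`:

    # Sanity check of the linear decay factor formula (illustrative values only).
    from decimal import Decimal

    import torch

    lr_init, lr_final, steps = 1e-3, 1e-5, 100  # decay over 100 scheduler steps

    # G solves lr_final = lr_init * (1.0 - G * steps).
    G = -(((Decimal(lr_final) / Decimal(lr_init)) - Decimal(1.0)) / Decimal(steps))

    model = torch.nn.Linear(1, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_init)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: float(Decimal(1.0) - G * Decimal(step))
    )

    for _ in range(steps):
        optimizer.step()
        scheduler.step()

    print(optimizer.param_groups[0]["lr"])  # ~1e-05, i.e. lr_final

After `steps` scheduler steps the multiplier equals 1.0 - G * steps =
lr_final / lr_init, so the learning rate lands exactly on the requested final
value.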
@@ -96,3 +98,51 @@ def lr_multiplier_function(step):
     return torch.optim.lr_scheduler.LambdaLR(
         optimizer, lr_lambda=lambda step: np.longdouble(1.0)
     )  # Return a constant function.
+
+
+def estimate_step_learning_rate(
+    lr_scheduler, lr_start, lr_final, update_after, total_steps, step
+):
+    """Estimates the learning rate at a given step.
+
+    This function estimates the learning rate for a specific training step. It differs
+    from the `get_last_lr` method of the learning rate scheduler, which returns the
+    learning rate at the last scheduler step, not necessarily the current training step.
+
+    Args:
+        lr_scheduler (torch.optim.lr_scheduler): The learning rate scheduler.
+        lr_start (float): The initial learning rate.
+        lr_final (float): The final learning rate.
+        update_after (int): The step number after which the learning rate should start
+            decreasing.
+        total_steps (int): The total number of steps/epochs in the training process.
+            Excludes the initial step.
+        step (int): The current step number. Excludes the initial step.
+
+    Returns:
+        float: The learning rate at the given step.
+    """
+    if step < update_after:
+        return lr_start
+    else:
+        adjusted_step = step - update_after
+        adjusted_total_steps = total_steps - update_after
+        if isinstance(lr_scheduler, torch.optim.lr_scheduler.LambdaLR):
+            decay_rate = get_linear_decay_rate(lr_start, lr_final, adjusted_total_steps)
+            lr = float(
+                Decimal(lr_start) * (Decimal(1.0) - decay_rate * Decimal(adjusted_step))
+            )
+        elif isinstance(lr_scheduler, torch.optim.lr_scheduler.ExponentialLR):
+            decay_rate = get_exponential_decay_rate(
+                lr_start, lr_final, adjusted_total_steps
+            )
+            lr = float(
+                Decimal(lr_start) * (Decimal(decay_rate) ** Decimal(adjusted_step))
+            )
+        else:
+            supported_schedulers = ["LambdaLR", "ExponentialLR"]
+            raise ValueError(
+                "The learning rate scheduler is not supported by this function. "
+                f"Supported schedulers are: {', '.join(supported_schedulers)}."
+            )
+        return max(lr, lr_final)
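
Before the algorithm changes below, a worked example of the arithmetic in
`estimate_step_learning_rate` for the linear (`LambdaLR`) branch; the numbers
are illustrative only, and the Decimal bookkeeping is dropped in favour of
plain floats:

    lr_start, lr_final = 1e-3, 1e-5
    update_after, total_steps = 1_000, 11_000

    step = 6_000  # halfway through the decay window
    adjusted_step = step - update_after                # 5_000
    adjusted_total_steps = total_steps - update_after  # 10_000

    # Same formula as get_linear_decay_rate, in plain floats.
    G = (1.0 - lr_final / lr_start) / adjusted_total_steps

    lr = max(lr_start * (1.0 - G * adjusted_step), lr_final)
    print(lr)  # 0.000505, exactly halfway between lr_start and lr_final

Before `update_after` the function simply returns `lr_start`; the
`max(..., lr_final)` clamp plays the same role as `bound_lr` and keeps the
estimate from undershooting the final learning rate.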
diff --git a/stable_learning_control/algos/pytorch/lac/lac.py b/stable_learning_control/algos/pytorch/lac/lac.py
index a03e2ab6..a6a5c16a 100644
--- a/stable_learning_control/algos/pytorch/lac/lac.py
+++ b/stable_learning_control/algos/pytorch/lac/lac.py
@@ -36,6 +36,7 @@
 )
 from stable_learning_control.algos.pytorch.common.get_lr_scheduler import (
     get_lr_scheduler,
+    estimate_step_learning_rate,
 )
 from stable_learning_control.algos.pytorch.common.helpers import (
     count_vars,
@@ -1111,7 +1112,7 @@ def lac(
     actor_critic = LyapunovActorCritic if actor_critic is None else actor_critic
 
     # Ensure the environment is correctly seeded.
-    # NOTE: Done here since we donote:n't want to seed on every env.reset() call.
+    # NOTE: Done here since we don't want to seed on every env.reset() call.
     if seed is not None:
         env.np_random, _ = seeding.np_random(seed)
         env.action_space.seed(seed)
@@ -1197,29 +1198,51 @@ def lac(
         logger.log("Network structure:\n", type="info")
         logger.log(policy.ac, end="\n\n")
 
-    # Create learning rate schedulers.
-    opt_schedulers = []
-    lr_decay_ref_var = total_steps if lr_decay_ref.lower() == "steps" else epochs
-    pi_opt_scheduler = get_lr_scheduler(
-        policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var
-    )
-    opt_schedulers.append(pi_opt_scheduler)
-    alpha_opt_scheduler = get_lr_scheduler(
-        policy._log_alpha_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var
-    )
-    opt_schedulers.append(alpha_opt_scheduler)
-    c_opt_scheduler = get_lr_scheduler(
-        policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_ref_var
-    )
-    opt_schedulers.append(c_opt_scheduler)
-    labda_opt_scheduler = get_lr_scheduler(
-        policy._log_labda_optimizer,
-        lr_decay_type,
-        lr_a,
-        lr_a_final,
-        lr_decay_ref_var,
-    )
-    opt_schedulers.append(labda_opt_scheduler)
+    # Parse the learning rate decay reference.
+    valid_lr_decay_options = ["step", "epoch"]
+    lr_decay_ref = lr_decay_ref.lower()
+    if lr_decay_ref not in valid_lr_decay_options:
+        options = [f"'{option}'" for option in valid_lr_decay_options]
+        logger.log(
+            f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
+            "which is not a valid option. Valid options are "
+            f"{', '.join(options)}. Falling back to 'epoch'.",
+            type="warning",
+        )
+        lr_decay_ref = "epoch"
+
+    # Calculate the number of learning rate scheduler steps.
+    if lr_decay_ref == "step":
+        # NOTE: Decay is applied at each policy update to improve performance.
+        lr_decay_steps = (total_steps - update_after) / update_every
+    else:
+        lr_decay_steps = epochs
+
+    # Setup learning rate schedulers.
+    # NOTE: +1 since we start at the initial learning rate.
+    opt_schedulers = {
+        "pi": get_lr_scheduler(
+            policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
+        ),
+        "c": get_lr_scheduler(
+            policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
+        ),
+        "alpha": get_lr_scheduler(
+            policy._log_alpha_optimizer,
+            lr_decay_type,
+            lr_a,
+            lr_a_final,
+            lr_decay_steps + 1,
+        ),
+        "lambda": get_lr_scheduler(
+            policy._log_labda_optimizer,
+            lr_decay_type,
+            lr_a,
+            lr_a_final,
+            lr_decay_steps + 1,
+        ),
+    }
 
     logger.setup_pytorch_saver(policy)
@@ -1253,6 +1276,7 @@ def lac(
         "Entropy",
     ]
     if use_tensorboard:
+        # NOTE: TensorBoard counts from 0.
         logger.log_to_tb(
             "Lr_a",
             policy._pi_optimizer.param_groups[0]["lr"],
@@ -1321,8 +1345,8 @@ def lac(
                logger.store(**update_diagnostics)  # Log diagnostics.
 
            # Step based learning rate decay.
-            if lr_decay_ref.lower() == "step":
-                for scheduler in opt_schedulers:
+            if lr_decay_ref == "step":
+                for scheduler in opt_schedulers.values():
                    scheduler.step()
                policy.bound_lr(
                    lr_a_final, lr_c_final, lr_a_final, lr_a_final
                )  # Make sure lr is bounded above the final lr.
 
            # SGD batch tb logging.
            if use_tensorboard and not tb_low_log_freq:
-                logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
+                logger.log_to_tb(
+                    keys=diag_tb_log_list, global_step=t
+                )  # NOTE: TensorBoard counts from 0.
 
        # End of epoch handling (Save model, test performance and log data)
        if (t + 1) % steps_per_epoch == 0:
@@ -1349,17 +1375,50 @@ def lac(
                extend=True,
            )
 
-            # Epoch based learning rate decay.
-            if lr_decay_ref.lower() != "step":
-                for scheduler in opt_schedulers:
-                    scheduler.step()
-                policy.bound_lr(
-                    lr_a_final, lr_c_final, lr_a_final, lr_a_final
-                )  # Make sure lr is bounded above the final lr.
+            # Retrieve current learning rates.
+            if lr_decay_ref == "step":
+                # NOTE: Estimate since 'step' decay is applied at policy update.
+                lr_actor = estimate_step_learning_rate(
+                    opt_schedulers["pi"],
+                    lr_a,
+                    lr_a_final,
+                    update_after,
+                    total_steps,
+                    t + 1,
+                )
+                lr_critic = estimate_step_learning_rate(
+                    opt_schedulers["c"],
+                    lr_c,
+                    lr_c_final,
+                    update_after,
+                    total_steps,
+                    t + 1,
+                )
+                lr_alpha = estimate_step_learning_rate(
+                    opt_schedulers["alpha"],
+                    lr_a,
+                    lr_a_final,
+                    update_after,
+                    total_steps,
+                    t + 1,
+                )
+                lr_labda = estimate_step_learning_rate(
+                    opt_schedulers["lambda"],
+                    lr_a,
+                    lr_a_final,
+                    update_after,
+                    total_steps,
+                    t + 1,
+                )
+            else:
+                lr_actor = policy._pi_optimizer.param_groups[0]["lr"]
+                lr_critic = policy._c_optimizer.param_groups[0]["lr"]
+                lr_alpha = policy._log_alpha_optimizer.param_groups[0]["lr"]
+                lr_labda = policy._log_labda_optimizer.param_groups[0]["lr"]
 
            # Log info about epoch.
            logger.log_tabular("Epoch", epoch)
-            logger.log_tabular("TotalEnvInteracts", t)
+            logger.log_tabular("TotalEnvInteracts", t + 1)
            logger.log_tabular(
                "EpRet",
                with_min_and_max=True,
@@ -1379,25 +1438,25 @@ def lac(
            )
            logger.log_tabular(
                "Lr_a",
-                policy._pi_optimizer.param_groups[0]["lr"],
+                lr_actor,
                tb_write=use_tensorboard,
                tb_prefix="LearningRates",
            )
            logger.log_tabular(
                "Lr_c",
-                policy._c_optimizer.param_groups[0]["lr"],
+                lr_critic,
                tb_write=use_tensorboard,
                tb_prefix="LearningRates",
            )
            logger.log_tabular(
                "Lr_alpha",
-                policy._log_alpha_optimizer.param_groups[0]["lr"],
+                lr_alpha,
                tb_write=use_tensorboard,
                tb_prefix="LearningRates",
            )
            logger.log_tabular(
                "Lr_labda",
-                policy._log_labda_optimizer.param_groups[0]["lr"],
+                lr_labda,
                tb_write=use_tensorboard,
                tb_prefix="LearningRates",
            )
@@ -1440,7 +1499,15 @@ def lac(
                tb_write=(use_tensorboard and tb_low_log_freq),
            )
            logger.log_tabular("Time", time.time() - start_time)
-            logger.dump_tabular(global_step=t)
+            logger.dump_tabular(global_step=t)  # NOTE: TensorBoard counts from 0.
+
+            # Epoch based learning rate decay.
+            if lr_decay_ref != "step":
+                for scheduler in opt_schedulers.values():
+                    scheduler.step()
+                policy.bound_lr(
+                    lr_a_final, lr_c_final, lr_a_final, lr_a_final
+                )  # Make sure lr is bounded above the final lr.
 
    # Export model to 'TorchScript'
    if export:
diff --git a/stable_learning_control/algos/pytorch/sac/sac.py b/stable_learning_control/algos/pytorch/sac/sac.py
index 7c492c43..b3e23cd6 100644
--- a/stable_learning_control/algos/pytorch/sac/sac.py
+++ b/stable_learning_control/algos/pytorch/sac/sac.py
@@ -33,6 +33,7 @@
 from stable_learning_control.algos.pytorch.common.buffers import ReplayBuffer
 from stable_learning_control.algos.pytorch.common.get_lr_scheduler import (
     get_lr_scheduler,
+    estimate_step_learning_rate,
 )
 from stable_learning_control.algos.pytorch.common.helpers import (
     count_vars,
@@ -979,7 +980,7 @@ def sac(
     actor_critic = SoftActorCritic if actor_critic is None else actor_critic
 
     # Ensure the environment is correctly seeded.
-    # NOTE: Done here since we donote:n't want to seed on every env.reset() call.
+    # NOTE: Done here since we don't want to seed on every env.reset() call.
     if seed is not None:
         env.np_random, _ = seeding.np_random(seed)
         env.action_space.seed(seed)
@@ -1048,21 +1049,44 @@ def sac(
         logger.log("Network structure:\n", type="info")
         logger.log(policy.ac, end="\n\n")
 
-    # Create learning rate schedulers.
-    opt_schedulers = []
-    lr_decay_ref_var = total_steps if lr_decay_ref.lower() == "steps" else epochs
-    pi_opt_scheduler = get_lr_scheduler(
-        policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var
-    )
-    opt_schedulers.append(pi_opt_scheduler)
-    alpha_opt_scheduler = get_lr_scheduler(
-        policy._log_alpha_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var
-    )
-    opt_schedulers.append(alpha_opt_scheduler)
-    c_opt_scheduler = get_lr_scheduler(
-        policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_ref_var
-    )
-    opt_schedulers.append(c_opt_scheduler)
+    # Parse the learning rate decay reference.
+    valid_lr_decay_options = ["step", "epoch"]
+    lr_decay_ref = lr_decay_ref.lower()
+    if lr_decay_ref not in valid_lr_decay_options:
+        options = [f"'{option}'" for option in valid_lr_decay_options]
+        logger.log(
+            f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
+            "which is not a valid option. Valid options are "
+            f"{', '.join(options)}. Falling back to 'epoch'.",
+            type="warning",
+        )
+        lr_decay_ref = "epoch"
+
+    # Calculate the number of learning rate scheduler steps.
+    if lr_decay_ref == "step":
+        # NOTE: Decay is applied at each policy update to improve performance.
+        lr_decay_steps = (total_steps - update_after) / update_every
+    else:
+        lr_decay_steps = epochs
+
+    # Setup learning rate schedulers.
+    # NOTE: +1 since we start at the initial learning rate.
+    opt_schedulers = {
+        "pi": get_lr_scheduler(
+            policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
+        ),
+        "c": get_lr_scheduler(
+            policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
+        ),
+        "alpha": get_lr_scheduler(
+            policy._log_alpha_optimizer,
+            lr_decay_type,
+            lr_a,
+            lr_a_final,
+            lr_decay_steps + 1,
+        ),
+    }
 
     logger.setup_pytorch_saver(policy)
@@ -1088,6 +1112,7 @@ def sac(
     # Setup diagnostics tb_write dict and store initial learning rates.
     diag_tb_log_list = ["LossQ", "LossPi", "Alpha", "LossAlpha", "Entropy"]
     if use_tensorboard:
+        # NOTE: TensorBoard counts from 0.
         logger.log_to_tb(
             "Lr_a",
             policy._pi_optimizer.param_groups[0]["lr"],
@@ -1146,8 +1171,8 @@ def sac(
                logger.store(**update_diagnostics)  # Log diagnostics.
 
            # Step based learning rate decay.
-            if lr_decay_ref.lower() == "step":
-                for scheduler in opt_schedulers:
+            if lr_decay_ref == "step":
+                for scheduler in opt_schedulers.values():
                    scheduler.step()
                policy.bound_lr(
                    lr_a_final, lr_c_final, lr_a_final
                )  # Make sure lr is bounded above the final lr.
 
            # SGD batch tb logging.
            if use_tensorboard and not tb_low_log_freq:
-                logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
+                logger.log_to_tb(
+                    keys=diag_tb_log_list, global_step=t
+                )  # NOTE: TensorBoard counts from 0.
 
        # End of epoch handling (Save model, test performance and log data)
        if (t + 1) % steps_per_epoch == 0:
@@ -1174,17 +1201,41 @@ def sac(
                extend=True,
            )
 
-            # Epoch based learning rate decay.
-            if lr_decay_ref.lower() != "step":
-                for scheduler in opt_schedulers:
-                    scheduler.step()
-                policy.bound_lr(
-                    lr_a_final, lr_c_final, lr_a_final
-                )  # Make sure lr is bounded above the final lr.
+            # Retrieve current learning rates.
+            if lr_decay_ref == "step":
+                # NOTE: Estimate since 'step' decay is applied at policy update.
+                lr_actor = estimate_step_learning_rate(
+                    opt_schedulers["pi"],
+                    lr_a,
+                    lr_a_final,
+                    update_after,
+                    total_steps,
+                    t + 1,
+                )
+                lr_critic = estimate_step_learning_rate(
+                    opt_schedulers["c"],
+                    lr_c,
+                    lr_c_final,
+                    update_after,
+                    total_steps,
+                    t + 1,
+                )
+                lr_alpha = estimate_step_learning_rate(
+                    opt_schedulers["alpha"],
+                    lr_a,
+                    lr_a_final,
+                    update_after,
+                    total_steps,
+                    t + 1,
+                )
+            else:
+                lr_actor = policy._pi_optimizer.param_groups[0]["lr"]
+                lr_critic = policy._c_optimizer.param_groups[0]["lr"]
+                lr_alpha = policy._log_alpha_optimizer.param_groups[0]["lr"]
 
            # Log info about epoch.
            logger.log_tabular("Epoch", epoch)
-            logger.log_tabular("TotalEnvInteracts", t)
+            logger.log_tabular("TotalEnvInteracts", t + 1)
            logger.log_tabular(
                "EpRet",
                with_min_and_max=True,
@@ -1204,19 +1255,19 @@ def sac(
            )
            logger.log_tabular(
                "Lr_a",
-                policy._pi_optimizer.param_groups[0]["lr"],
+                lr_actor,
                tb_write=use_tensorboard,
                tb_prefix="LearningRates",
            )
            logger.log_tabular(
                "Lr_c",
-                policy._c_optimizer.param_groups[0]["lr"],
+                lr_critic,
                tb_write=use_tensorboard,
                tb_prefix="LearningRates",
            )
            logger.log_tabular(
                "Lr_alpha",
-                policy._log_alpha_optimizer.param_groups[0]["lr"],
+                lr_alpha,
                tb_write=use_tensorboard,
                tb_prefix="LearningRates",
            )
@@ -1250,7 +1301,15 @@ def sac(
                tb_write=(use_tensorboard and tb_low_log_freq),
            )
            logger.log_tabular("Time", time.time() - start_time)
-            logger.dump_tabular(global_step=t)
+            logger.dump_tabular(global_step=t)  # NOTE: TensorBoard counts from 0.
+
+            # Epoch based learning rate decay.
+            if lr_decay_ref != "step":
+                for scheduler in opt_schedulers.values():
+                    scheduler.step()
+                policy.bound_lr(
+                    lr_a_final, lr_c_final, lr_a_final
+                )  # Make sure lr is bounded above the final lr.
 
    # Export model to 'TorchScript'
    if export:
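
To see the corrected cadence end to end: with `lr_decay_ref='step'` the
schedulers are now sized by the number of policy updates,
`(total_steps - update_after) / update_every`, and stepped once per update.
The standalone sketch below uses made-up hyperparameters and a bare
`ExponentialLR` (omitting the patch's `+ 1` adjustment for the initial
learning rate):

    import torch

    lr_a, lr_a_final = 1e-3, 1e-5
    total_steps, update_after, update_every = 10_000, 1_000, 100

    # Number of policy updates, as computed by the patch.
    lr_decay_steps = (total_steps - update_after) / update_every  # 90.0

    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=lr_a)
    gamma = (lr_a_final / lr_a) ** (1.0 / lr_decay_steps)  # exponential decay rate
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    for t in range(total_steps):
        if t >= update_after and t % update_every == 0:
            optimizer.step()
            scheduler.step()  # one decay step per policy update, not per env step

    print(optimizer.param_groups[0]["lr"])  # ~1e-05 at the end of training

Under the old code the same scheduler was sized for `epochs` steps yet stepped
at every policy update, so the learning rate hit its floor long before
training ended.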