From 9b8b5b5c94da426a5592f18bf4c429d045dcc227 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 12 Feb 2024 17:21:18 +0100 Subject: [PATCH] fix(torch): handle 'update_after' set to zero This commit addresses a bug in the learning rate decay logic when 'update_after' is set to zero. Previously, the algorithm would malfunction under these conditions. With this fix, the algorithm can now correctly handle 'update_after' being set to zero, ensuring proper learning rate decay. --- stable_learning_control/algos/pytorch/lac/lac.py | 1 + stable_learning_control/algos/pytorch/sac/sac.py | 1 + 2 files changed, 2 insertions(+) diff --git a/stable_learning_control/algos/pytorch/lac/lac.py b/stable_learning_control/algos/pytorch/lac/lac.py index f808dd57..ad28d0bb 100644 --- a/stable_learning_control/algos/pytorch/lac/lac.py +++ b/stable_learning_control/algos/pytorch/lac/lac.py @@ -1021,6 +1021,7 @@ def lac( - replay_buffer (union[:class:`~stable_learning_control.algos.pytorch.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.pytorch.common.buffers.FiniteHorizonReplayBuffer`]): The replay buffer used during training. """ # noqa: E501, D301 + update_after = max(1, update_after) # You can not update before the first step. validate_args(**locals()) # Retrieve hyperparameters while filtering out the logger_kwargs. diff --git a/stable_learning_control/algos/pytorch/sac/sac.py b/stable_learning_control/algos/pytorch/sac/sac.py index a3e6c8e5..f5685f39 100644 --- a/stable_learning_control/algos/pytorch/sac/sac.py +++ b/stable_learning_control/algos/pytorch/sac/sac.py @@ -889,6 +889,7 @@ def sac( - replay_buffer (union[:class:`~stable_learning_control.algos.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.common.buffers.FiniteHorizonReplayBuffer`]): The replay buffer used during training. """ # noqa: E501, D301 + update_after = max(1, update_after) # You can not update before the first step. validate_args(**locals()) # Retrieve hyperparameters while filtering out the logger_kwargs.