diff --git a/stable_learning_control/algos/pytorch/lac/lac.py b/stable_learning_control/algos/pytorch/lac/lac.py index f808dd57..ad28d0bb 100644 --- a/stable_learning_control/algos/pytorch/lac/lac.py +++ b/stable_learning_control/algos/pytorch/lac/lac.py @@ -1021,6 +1021,7 @@ def lac( - replay_buffer (union[:class:`~stable_learning_control.algos.pytorch.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.pytorch.common.buffers.FiniteHorizonReplayBuffer`]): The replay buffer used during training. """ # noqa: E501, D301 + update_after = max(1, update_after) # You cannot update before the first step. validate_args(**locals()) # Retrieve hyperparameters while filtering out the logger_kwargs. diff --git a/stable_learning_control/algos/pytorch/sac/sac.py b/stable_learning_control/algos/pytorch/sac/sac.py index a3e6c8e5..f5685f39 100644 --- a/stable_learning_control/algos/pytorch/sac/sac.py +++ b/stable_learning_control/algos/pytorch/sac/sac.py @@ -889,6 +889,7 @@ def sac( - replay_buffer (union[:class:`~stable_learning_control.algos.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.common.buffers.FiniteHorizonReplayBuffer`]): The replay buffer used during training. """ # noqa: E501, D301 + update_after = max(1, update_after) # You cannot update before the first step. validate_args(**locals()) # Retrieve hyperparameters while filtering out the logger_kwargs.