diff --git a/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py b/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py
index aea79f42..a2935d10 100644
--- a/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py
+++ b/stable_learning_control/algos/pytorch/common/get_lr_scheduler.py
@@ -50,7 +50,7 @@ def get_linear_decay_rate(lr_init, lr_final, steps):
         lr_init (float): The initial learning rate.
         lr_final (float): The final learning rate you want to achieve.
         steps (int): The number of steps/epochs over which the learning rate should
-            decay. This is equal to epochs -1.
+            decay. This is equal to epochs - 1.
 
     Returns:
         decimal.Decimal: Linear learning rate decay factor (G).
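Note: a minimal sketch of the linear decay factor described in the docstring above, assuming the common convention lr(t) = lr_init * (1 - G * t) with lr(steps) == lr_final; the helper name linear_decay_factor is illustrative and not necessarily the module's actual implementation:

    from decimal import Decimal

    def linear_decay_factor(lr_init, lr_final, steps):
        # Choose G so that lr_init * (1 - G * steps) == lr_final.
        # Decimal avoids floating-point drift when G is applied many times.
        return (Decimal(1) - Decimal(str(lr_final)) / Decimal(str(lr_init))) / Decimal(steps)

    epochs = 100
    G = linear_decay_factor(1e-4, 1e-10, steps=epochs - 1)  # steps == epochs - 1
    lr_epoch_50 = float(Decimal(str(1e-4)) * (Decimal(1) - G * 50))  # rate at epoch 50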
diff --git a/stable_learning_control/algos/tf2/lac/lac.py b/stable_learning_control/algos/tf2/lac/lac.py
index ca1f9857..06e3d6d4 100644
--- a/stable_learning_control/algos/tf2/lac/lac.py
+++ b/stable_learning_control/algos/tf2/lac/lac.py
@@ -73,6 +73,10 @@
     "AverageLossPi",
     "AverageEntropy",
 ]
+VALID_DECAY_TYPES = ["linear", "exponential", "constant"]
+VALID_DECAY_REFERENCES = ["step", "epoch"]
+DEFAULT_DECAY_TYPE = "linear"
+DEFAULT_DECAY_REFERENCE = "epoch"
 
 # tf.config.run_functions_eagerly(True)  # NOTE: Uncomment for debugging.
 
@@ -106,6 +110,8 @@ def __init__(
         adaptive_temperature=True,
         lr_a=1e-4,
         lr_c=3e-4,
+        lr_alpha=1e-4,
+        lr_labda=3e-4,
         device="cpu",
         name="LAC",
     ):
@@ -194,6 +200,10 @@
                 ``1e-4``.
             lr_c (float, optional): Learning rate used for the (lyapunov) critic.
                 Defaults to ``1e-4``.
+            lr_alpha (float, optional): Learning rate used for the entropy temperature.
+                Defaults to ``1e-4``.
+            lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange
+                multiplier. Defaults to ``3e-4``.
             device (str, optional): The device the networks are placed on (options:
                 ``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
@@ -258,8 +268,8 @@
         self._alpha3 = alpha3
         self._lr_a = tf.Variable(lr_a, name="Lr_a")
         if self._adaptive_temperature:
-            self._lr_alpha = tf.Variable(lr_a, name="Lr_alpha")
-            self._lr_lag = tf.Variable(lr_a, name="Lr_lag")
+            self._lr_alpha = tf.Variable(lr_alpha, name="Lr_alpha")
+            self._lr_lag = tf.Variable(lr_labda, name="Lr_lag")
         self._lr_c = tf.Variable(lr_c, name="Lr_c")
         if not isinstance(target_entropy, (float, int)):
             self._target_entropy = heuristic_target_entropy(env.action_space)
@@ -801,10 +811,18 @@ def lac(
     adaptive_temperature=True,
     lr_a=1e-4,
     lr_c=3e-4,
+    lr_alpha=1e-4,
+    lr_labda=3e-4,
     lr_a_final=1e-10,
     lr_c_final=1e-10,
-    lr_decay_type="linear",
-    lr_decay_ref="epoch",
+    lr_alpha_final=1e-10,
+    lr_labda_final=1e-10,
+    lr_decay_type=DEFAULT_DECAY_TYPE,
+    lr_a_decay_type=None,
+    lr_c_decay_type=None,
+    lr_alpha_decay_type=None,
+    lr_labda_decay_type=None,
+    lr_decay_ref=DEFAULT_DECAY_REFERENCE,
     batch_size=256,
     replay_size=int(1e6),
     seed=None,
@@ -919,10 +937,33 @@ def lac(
            ``1e-4``.
         lr_c (float, optional): Learning rate used for the (lyapunov) critic.
             Defaults to ``1e-4``.
+        lr_alpha (float, optional): Learning rate used for the entropy temperature.
+            Defaults to ``1e-4``.
+        lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange
+            multiplier. Defaults to ``3e-4``.
         lr_a_final(float, optional): The final actor learning rate that is achieved
             at the end of the training. Defaults to ``1e-10``.
         lr_c_final(float, optional): The final critic learning rate that is achieved
             at the end of the training. Defaults to ``1e-10``.
+        lr_alpha_final(float, optional): The final alpha learning rate that is
+            achieved at the end of the training. Defaults to ``1e-10``.
+        lr_labda_final(float, optional): The final labda learning rate that is
+            achieved at the end of the training. Defaults to ``1e-10``.
+        lr_decay_type (str, optional): The learning rate decay type that is used (options
+            are: ``linear`` and ``exponential`` and ``constant``). Defaults to
+            ``linear``. Can be overridden by the specific learning rate decay types.
+        lr_a_decay_type (str, optional): The learning rate decay type that is used for
+            the actor learning rate (options are: ``linear`` and ``exponential`` and
+            ``constant``). If not specified, the general learning rate decay type is used.
+        lr_c_decay_type (str, optional): The learning rate decay type that is used for
+            the critic learning rate (options are: ``linear`` and ``exponential`` and
+            ``constant``). If not specified, the general learning rate decay type is used.
+        lr_alpha_decay_type (str, optional): The learning rate decay type that is used
+            for the alpha learning rate (options are: ``linear`` and ``exponential``
+            and ``constant``). If not specified, the general learning rate decay type is used.
+        lr_labda_decay_type (str, optional): The learning rate decay type that is used
+            for the labda learning rate (options are: ``linear`` and ``exponential``
+            and ``constant``). If not specified, the general learning rate decay type is used.
-        lr_decay_type (str, optional): The learning rate decay type that is used (
-            options are: ``linear`` and ``exponential`` and ``constant``). Defaults to
-            ``linear``.
@@ -1068,35 +1109,69 @@ def lac(
     # os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # Disable for reproducibility.
 
     policy = LAC(
-        env,
-        actor_critic,
-        ac_kwargs,
-        opt_type,
-        alpha,
-        alpha3,
-        labda,
-        gamma,
-        polyak,
-        target_entropy,
-        adaptive_temperature,
-        lr_a,
-        lr_c,
-        device,
-    )
-
-    # Parse learning rate decay type.
-    valid_lr_decay_options = ["step", "epoch"]
+        env=env,
+        actor_critic=actor_critic,
+        ac_kwargs=ac_kwargs,
+        opt_type=opt_type,
+        alpha=alpha,
+        alpha3=alpha3,
+        labda=labda,
+        gamma=gamma,
+        polyak=polyak,
+        target_entropy=target_entropy,
+        adaptive_temperature=adaptive_temperature,
+        lr_a=lr_a,
+        lr_c=lr_c,
+        lr_alpha=lr_alpha,
+        lr_labda=lr_labda,
+        device=device,
+    )
+
+    # Parse learning rate decay reference.
     lr_decay_ref = lr_decay_ref.lower()
-    if lr_decay_ref not in valid_lr_decay_options:
-        options = [f"'{option}'" for option in valid_lr_decay_options]
+    if lr_decay_ref not in VALID_DECAY_REFERENCES:
+        options = [f"'{option}'" for option in VALID_DECAY_REFERENCES]
         logger.log(
             f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
             "which is not a valid option. Valid options are "
             f"{', '.join(options)}. The learning rate decay reference "
-            "variable has been set to 'epoch'.",
+            f"variable has been set to '{DEFAULT_DECAY_REFERENCE}'.",
             type="warning",
         )
-        lr_decay_ref = "epoch"
+        lr_decay_ref = DEFAULT_DECAY_REFERENCE
+
+    # Parse learning rate decay types.
+    lr_decay_type = lr_decay_type.lower()
+    if lr_decay_type not in VALID_DECAY_TYPES:
+        options = [f"'{option}'" for option in VALID_DECAY_TYPES]
+        logger.log(
+            f"The learning rate decay type was set to '{lr_decay_type}', which is not "
+            "a valid option. Valid options are "
+            f"{', '.join(options)}. The learning rate decay type has been set to "
+            f"'{DEFAULT_DECAY_TYPE}'.",
+            type="warning",
+        )
+        lr_decay_type = DEFAULT_DECAY_TYPE
+    decay_types = {
+        "actor": lr_a_decay_type.lower() if lr_a_decay_type else None,
+        "critic": lr_c_decay_type.lower() if lr_c_decay_type else None,
+        "alpha": lr_alpha_decay_type.lower() if lr_alpha_decay_type else None,
+        "labda": lr_labda_decay_type.lower() if lr_labda_decay_type else None,
+    }
+    for name, decay_type in decay_types.items():
+        if decay_type is None:
+            decay_types[name] = lr_decay_type
+        else:
+            if decay_type not in VALID_DECAY_TYPES:
+                logger.log(
+                    f"Invalid {name} learning rate decay type: '{decay_type}'. Using "
+                    f"global learning rate decay type: '{lr_decay_type}' instead.",
+                    type="warning",
+                )
+                decay_types[name] = lr_decay_type
+    lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type, lr_labda_decay_type = (
+        decay_types.values()
+    )
 
     # Calculate the number of learning rate scheduler steps.
     if lr_decay_ref == "step":
@@ -1108,9 +1183,19 @@
         lr_decay_steps = epochs
 
     # Create learning rate schedulers.
-    # NOTE: Alpha and labda currently use the same scheduler as the actor.
-    lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
-    lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)
+    lr_a_init, lr_c_init, lr_alpha_init, lr_labda_init = lr_a, lr_c, lr_alpha, lr_labda
+    lr_a_scheduler = get_lr_scheduler(
+        lr_a_decay_type, lr_a_init, lr_a_final, lr_decay_steps
+    )
+    lr_c_scheduler = get_lr_scheduler(
+        lr_c_decay_type, lr_c_init, lr_c_final, lr_decay_steps
+    )
+    lr_alpha_scheduler = get_lr_scheduler(
+        lr_alpha_decay_type, lr_alpha_init, lr_alpha_final, lr_decay_steps
+    )
+    lr_labda_scheduler = get_lr_scheduler(
+        lr_labda_decay_type, lr_labda_init, lr_labda_final, lr_decay_steps
+    )
 
     # Restore policy if supplied.
     if start_policy is not None:
@@ -1265,8 +1350,17 @@
                 lr_c_now = max(
                     lr_c_scheduler(n_update + 1), lr_c_final
                 )  # Make sure lr is bounded above final lr.
+                lr_alpha_now = max(
+                    lr_alpha_scheduler(n_update + 1), lr_alpha_final
+                )  # Make sure lr is bounded above final lr.
+                lr_labda_now = max(
+                    lr_labda_scheduler(n_update + 1), lr_labda_final
+                )  # Make sure lr is bounded above final lr.
                 policy.set_learning_rates(
-                    lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
+                    lr_a=lr_a_now,
+                    lr_c=lr_c_now,
+                    lr_alpha=lr_alpha_now,
+                    lr_labda=lr_labda_now,
                 )
 
             # SGD batch tb logging.
@@ -1297,8 +1391,8 @@
                 progress = max((t + 1) - update_after, 0) / update_every
                 lr_actor = lr_a_scheduler(progress)
                 lr_critic = lr_c_scheduler(progress)
-                lr_alpha = lr_a_scheduler(progress)
-                lr_labda = lr_a_scheduler(progress)
+                lr_alpha = lr_alpha_scheduler(progress)
+                lr_labda = lr_labda_scheduler(progress)
             else:
                 lr_actor = policy._pi_optimizer.lr.numpy()
                 lr_critic = policy._c_optimizer.lr.numpy()
@@ -1398,8 +1492,17 @@
             lr_c_now = max(
                 lr_c_scheduler(epoch), lr_c_final
             )  # Make sure lr is bounded above final.
+            lr_alpha_now = max(
+                lr_alpha_scheduler(epoch), lr_alpha_final
+            )  # Make sure lr is bounded above final.
+            lr_labda_now = max(
+                lr_labda_scheduler(epoch), lr_labda_final
+            )  # Make sure lr is bounded above final.
             policy.set_learning_rates(
-                lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
+                lr_a=lr_a_now,
+                lr_c=lr_c_now,
+                lr_alpha=lr_alpha_now,
+                lr_labda=lr_labda_now,
             )
 
             # Export model to 'SavedModel'
@@ -1574,6 +1677,18 @@
     parser.add_argument(
         "--lr_c", type=float, default=3e-4, help="critic learning rate (default: 1e-4)"
     )
+    parser.add_argument(
+        "--lr_alpha",
+        type=float,
+        default=1e-4,
+        help="entropy temperature learning rate (default: 1e-4)",
+    )
+    parser.add_argument(
+        "--lr_labda",
+        type=float,
+        default=3e-4,
+        help="Lyapunov Lagrange multiplier learning rate (default: 3e-4)",
+    )
     parser.add_argument(
         "--lr_a_final",
         type=float,
@@ -1586,12 +1701,62 @@
         default=1e-10,
         help="the finalcritic learning rate (default: 1e-10)",
     )
+    parser.add_argument(
+        "--lr_alpha_final",
+        type=float,
+        default=1e-10,
+        help="the final entropy temperature learning rate (default: 1e-10)",
+    )
+    parser.add_argument(
+        "--lr_labda_final",
+        type=float,
+        default=1e-10,
+        help="the final Lyapunov Lagrange multiplier learning rate (default: 1e-10)",
+    )
     parser.add_argument(
         "--lr_decay_type",
         type=str,
         default="linear",
         help="the learning rate decay type (default: linear)",
     )
+    parser.add_argument(
+        "--lr_a_decay_type",
+        type=str,
+        default=None,
+        help=(
+            "the learning rate decay type that is used for the actor learning rate. "
+            "If not specified, the general learning rate decay type is used."
+        ),
+    )
+    parser.add_argument(
+        "--lr_c_decay_type",
+        type=str,
+        default=None,
+        help=(
+            "the learning rate decay type that is used for the critic learning rate. "
+            "If not specified, the general learning rate decay type is used."
+        ),
+    )
+    parser.add_argument(
+        "--lr_alpha_decay_type",
+        type=str,
+        default=None,
+        help=(
+            "the learning rate decay type that is used for the entropy temperature "
+            "learning rate. If not specified, the general learning rate decay type is "
+            "used."
+        ),
+    )
+    parser.add_argument(
+        "--lr_labda_decay_type",
+        type=str,
+        default=None,
+        help=(
+            "the learning rate decay type that is used for the Lyapunov Lagrange "
+            "multiplier learning rate. If not specified, the general learning rate "
+            "decay type is used."
+        ),
+    )
     parser.add_argument(
         "--lr_decay_ref",
         type=str,
@@ -1803,10 +1968,18 @@
         adaptive_temperature=args.adaptive_temperature,
         lr_a=args.lr_a,
         lr_c=args.lr_c,
+        lr_alpha=args.lr_alpha,
+        lr_labda=args.lr_labda,
         lr_a_final=args.lr_a_final,
         lr_c_final=args.lr_c_final,
+        lr_alpha_final=args.lr_alpha_final,
+        lr_labda_final=args.lr_labda_final,
         lr_decay_type=args.lr_decay_type,
         lr_decay_ref=args.lr_decay_ref,
+        lr_a_decay_type=args.lr_a_decay_type,
+        lr_c_decay_type=args.lr_c_decay_type,
+        lr_alpha_decay_type=args.lr_alpha_decay_type,
+        lr_labda_decay_type=args.lr_labda_decay_type,
         batch_size=args.batch_size,
         replay_size=args.replay_size,
         horizon_length=args.horizon_length,
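Note: a self-contained sketch (illustrative names, not the module's API) of the per-parameter decay-type fallback that both files now implement — a missing or invalid per-parameter type falls back to the already-validated global lr_decay_type with a warning:

    VALID_DECAY_TYPES = ["linear", "exponential", "constant"]

    def resolve_decay_type(specific, global_type):
        # Fall back to the global decay type when the per-parameter
        # one is missing or invalid (mirrors the parsing loop above).
        if specific is None:
            return global_type
        specific = specific.lower()
        if specific not in VALID_DECAY_TYPES:
            print(f"Warning: invalid decay type '{specific}'; using '{global_type}'.")
            return global_type
        return specific

    # Only alpha decays exponentially; everything else follows the global type.
    lr_alpha_decay_type = resolve_decay_type("exponential", "linear")  # "exponential"
    lr_labda_decay_type = resolve_decay_type(None, "linear")           # "linear"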
diff --git a/stable_learning_control/algos/tf2/sac/sac.py b/stable_learning_control/algos/tf2/sac/sac.py
index c9a0d71b..78f5c62c 100644
--- a/stable_learning_control/algos/tf2/sac/sac.py
+++ b/stable_learning_control/algos/tf2/sac/sac.py
@@ -63,6 +63,10 @@
     "AverageLossPi",
     "AverageEntropy",
 ]
+VALID_DECAY_TYPES = ["linear", "exponential", "constant"]
+VALID_DECAY_REFERENCES = ["step", "epoch"]
+DEFAULT_DECAY_TYPE = "linear"
+DEFAULT_DECAY_REFERENCE = "epoch"
 
 # tf.config.run_functions_eagerly(True)  # NOTE: Uncomment for debugging.
 
@@ -93,6 +97,7 @@ def __init__(
         adaptive_temperature=True,
         lr_a=1e-4,
         lr_c=3e-4,
+        lr_alpha=1e-4,
         device="cpu",
         name="SAC",
     ):
@@ -176,6 +180,8 @@
                 ``1e-4``.
             lr_c (float, optional): Learning rate used for the (soft) critic. Defaults
                 to ``1e-4``.
+            lr_alpha (float, optional): Learning rate used for the entropy temperature.
+                Defaults to ``1e-4``.
             device (str, optional): The device the networks are placed on (options:
                 ``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
         """  # noqa: E501, D301
@@ -234,7 +240,7 @@
         self._gamma = gamma
         self._lr_a = tf.Variable(lr_a, name="Lr_a")
         if self._adaptive_temperature:
-            self._lr_alpha = tf.Variable(lr_a, name="Lr_alpha")
+            self._lr_alpha = tf.Variable(lr_alpha, name="Lr_alpha")
         self._lr_c = tf.Variable(lr_c, name="Lr_c")
         if not isinstance(target_entropy, (float, int)):
             self._target_entropy = heuristic_target_entropy(env.action_space)
@@ -678,10 +684,15 @@ def sac(
     adaptive_temperature=True,
     lr_a=1e-4,
     lr_c=3e-4,
+    lr_alpha=1e-4,
     lr_a_final=1e-10,
     lr_c_final=1e-10,
-    lr_decay_type="linear",
-    lr_decay_ref="epoch",
+    lr_alpha_final=1e-10,
+    lr_decay_type=DEFAULT_DECAY_TYPE,
+    lr_a_decay_type=None,
+    lr_c_decay_type=None,
+    lr_alpha_decay_type=None,
+    lr_decay_ref=DEFAULT_DECAY_REFERENCE,
     batch_size=256,
     replay_size=int(1e6),
     seed=None,
@@ -791,6 +802,8 @@ def sac(
            ``1e-4``.
         lr_c (float, optional): Learning rate used for the (soft) critic. Defaults to
             ``1e-4``.
+        lr_alpha (float, optional): Learning rate used for the entropy temperature.
+            Defaults to ``1e-4``.
         lr_a_final(float, optional): The final actor learning rate that is achieved
             at the end of the training. Defaults to ``1e-10``.
         lr_c_final(float, optional): The final critic learning rate that is achieved
@@ -798,6 +811,20 @@
-        lr_decay_type (str, optional): The learning rate decay type that is used (
-            options are: ``linear`` and ``exponential`` and ``constant``). Defaults to
-            ``linear``.
+        lr_alpha_final(float, optional): The final alpha learning rate that is
+            achieved at the end of the training. Defaults to ``1e-10``.
+        lr_decay_type (str, optional): The learning rate decay type that is used (options
+            are: ``linear`` and ``exponential`` and ``constant``). Defaults to
+            ``linear``. Can be overridden by the specific learning rate decay types.
+        lr_a_decay_type (str, optional): The learning rate decay type that is used for
+            the actor learning rate (options are: ``linear`` and ``exponential`` and
+            ``constant``). If not specified, the general learning rate decay type is used.
+        lr_c_decay_type (str, optional): The learning rate decay type that is used for
+            the critic learning rate (options are: ``linear`` and ``exponential`` and
+            ``constant``). If not specified, the general learning rate decay type is used.
+        lr_alpha_decay_type (str, optional): The learning rate decay type that is used
+            for the alpha learning rate (options are: ``linear`` and ``exponential``
+            and ``constant``). If not specified, the general learning rate decay type is used.
         lr_decay_ref (str, optional): The reference variable that is used for decaying
             the learning rate (options: ``epoch`` and ``step``). Defaults to
             ``epoch``.
         batch_size (int, optional): Minibatch size for SGD. Defaults to ``256``.
@@ -937,33 +964,65 @@ def sac(
     # os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # Disable for reproducibility.
 
     policy = SAC(
-        env,
-        actor_critic,
-        ac_kwargs,
-        opt_type,
-        alpha,
-        gamma,
-        polyak,
-        target_entropy,
-        adaptive_temperature,
-        lr_a,
-        lr_c,
-        device,
-    )
-
-    # Parse learning rate decay type.
-    valid_lr_decay_options = ["step", "epoch"]
+        env=env,
+        actor_critic=actor_critic,
+        ac_kwargs=ac_kwargs,
+        opt_type=opt_type,
+        alpha=alpha,
+        gamma=gamma,
+        polyak=polyak,
+        target_entropy=target_entropy,
+        adaptive_temperature=adaptive_temperature,
+        lr_a=lr_a,
+        lr_c=lr_c,
+        lr_alpha=lr_alpha,
+        device=device,
+    )
+
+    # Parse learning rate decay reference.
     lr_decay_ref = lr_decay_ref.lower()
-    if lr_decay_ref not in valid_lr_decay_options:
-        options = [f"'{option}'" for option in valid_lr_decay_options]
+    if lr_decay_ref not in VALID_DECAY_REFERENCES:
+        options = [f"'{option}'" for option in VALID_DECAY_REFERENCES]
         logger.log(
             f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
             "which is not a valid option. Valid options are "
             f"{', '.join(options)}. The learning rate decay reference "
-            "variable has been set to 'epoch'.",
+            f"variable has been set to '{DEFAULT_DECAY_REFERENCE}'.",
             type="warning",
         )
-        lr_decay_ref = "epoch"
+        lr_decay_ref = DEFAULT_DECAY_REFERENCE
+
+    # Parse learning rate decay types.
+    lr_decay_type = lr_decay_type.lower()
+    if lr_decay_type not in VALID_DECAY_TYPES:
+        options = [f"'{option}'" for option in VALID_DECAY_TYPES]
+        logger.log(
+            f"The learning rate decay type was set to '{lr_decay_type}', which is not "
+            "a valid option. Valid options are "
+            f"{', '.join(options)}. The learning rate decay type has been set to "
+            f"'{DEFAULT_DECAY_TYPE}'.",
+            type="warning",
+        )
+        lr_decay_type = DEFAULT_DECAY_TYPE
+    decay_types = {
+        "actor": lr_a_decay_type.lower() if lr_a_decay_type else None,
+        "critic": lr_c_decay_type.lower() if lr_c_decay_type else None,
+        "alpha": lr_alpha_decay_type.lower() if lr_alpha_decay_type else None,
+    }
+    for name, decay_type in decay_types.items():
+        if decay_type is None:
+            decay_types[name] = lr_decay_type
+        else:
+            if decay_type not in VALID_DECAY_TYPES:
+                logger.log(
+                    f"Invalid {name} learning rate decay type: '{decay_type}'. Using "
+                    f"global learning rate decay type: '{lr_decay_type}' instead.",
+                    type="warning",
+                )
+                decay_types[name] = lr_decay_type
+    lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type = (
+        decay_types.values()
+    )
 
     # Calculate the number of learning rate scheduler steps.
     if lr_decay_ref == "step":
@@ -976,8 +1035,16 @@
 
     # Create learning rate schedulers.
-    # NOTE: Alpha currently uses the same scheduler as the actor.
-    lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
-    lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)
+    lr_a_init, lr_c_init, lr_alpha_init = lr_a, lr_c, lr_alpha
+    lr_a_scheduler = get_lr_scheduler(
+        lr_a_decay_type, lr_a_init, lr_a_final, lr_decay_steps
+    )
+    lr_c_scheduler = get_lr_scheduler(
+        lr_c_decay_type, lr_c_init, lr_c_final, lr_decay_steps
+    )
+    lr_alpha_scheduler = get_lr_scheduler(
+        lr_alpha_decay_type, lr_alpha_init, lr_alpha_final, lr_decay_steps
+    )
 
     # Restore policy if supplied.
     if start_policy is not None:
@@ -1100,8 +1167,11 @@
                 lr_c_now = max(
                     lr_c_scheduler(n_update + 1), lr_c_final
                 )  # Make sure lr is bounded above final lr.
+                lr_alpha_now = max(
+                    lr_alpha_scheduler(n_update + 1), lr_alpha_final
+                )  # Make sure lr is bounded above final lr.
                 policy.set_learning_rates(
-                    lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
+                    lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_alpha_now
                 )
 
             # SGD batch tb logging.
@@ -1132,7 +1202,7 @@
                 progress = max((t + 1) - update_after, 0) / update_every
                 lr_actor = lr_a_scheduler(progress)
                 lr_critic = lr_c_scheduler(progress)
-                lr_alpha = lr_a_scheduler(progress)
+                lr_alpha = lr_alpha_scheduler(progress)
             else:
                 lr_actor = policy._pi_optimizer.lr.numpy()
                 lr_critic = policy._c_optimizer.lr.numpy()
@@ -1216,9 +1286,12 @@
             lr_c_now = max(
                 lr_c_scheduler(epoch), lr_c_final
             )  # Make sure lr is bounded above final.
+            lr_alpha_now = max(
+                lr_alpha_scheduler(epoch), lr_alpha_final
+            )  # Make sure lr is bounded above final.
             policy.set_learning_rates(
-                lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
+                lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_alpha_now
             )
 
             # Export model to 'SavedModel'
             if export:
@@ -1386,6 +1459,12 @@
     parser.add_argument(
         "--lr_c", type=float, default=3e-4, help="critic learning rate (default: 1e-4)"
     )
+    parser.add_argument(
+        "--lr_alpha",
+        type=float,
+        default=1e-4,
+        help="entropy temperature learning rate (default: 1e-4)",
+    )
     parser.add_argument(
         "--lr_a_final",
         type=float,
@@ -1398,12 +1477,46 @@
         default=1e-10,
         help="the finalcritic learning rate (default: 1e-10)",
     )
+    parser.add_argument(
+        "--lr_alpha_final",
+        type=float,
+        default=1e-10,
+        help="the final entropy temperature learning rate (default: 1e-10)",
+    )
     parser.add_argument(
         "--lr_decay_type",
         type=str,
         default="linear",
         help="the learning rate decay type (default: linear)",
     )
+    parser.add_argument(
+        "--lr_a_decay_type",
+        type=str,
+        default=None,
+        help=(
+            "the learning rate decay type that is used for the actor learning rate. "
+            "If not specified, the general learning rate decay type is used."
+        ),
+    )
+    parser.add_argument(
+        "--lr_c_decay_type",
+        type=str,
+        default=None,
+        help=(
+            "the learning rate decay type that is used for the critic learning rate. "
+            "If not specified, the general learning rate decay type is used."
+        ),
+    )
+    parser.add_argument(
+        "--lr_alpha_decay_type",
+        type=str,
+        default=None,
+        help=(
+            "the learning rate decay type that is used for the entropy temperature "
+            "learning rate. If not specified, the general learning rate decay type is "
+            "used."
+        ),
+    )
     parser.add_argument(
         "--lr_decay_ref",
         type=str,
@@ -1605,9 +1718,14 @@
         adaptive_temperature=args.adaptive_temperature,
         lr_a=args.lr_a,
         lr_c=args.lr_c,
+        lr_alpha=args.lr_alpha,
         lr_a_final=args.lr_a_final,
         lr_c_final=args.lr_c_final,
+        lr_alpha_final=args.lr_alpha_final,
         lr_decay_type=args.lr_decay_type,
+        lr_a_decay_type=args.lr_a_decay_type,
+        lr_c_decay_type=args.lr_c_decay_type,
+        lr_alpha_decay_type=args.lr_alpha_decay_type,
         lr_decay_ref=args.lr_decay_ref,
         batch_size=args.batch_size,
         replay_size=args.replay_size,
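Note: end to end, each learning rate now gets its own scheduler and final-value clamp. A toy stand-in for get_lr_scheduler (assumed to return a callable mapping a step/epoch index to a rate; the real implementation lives in the repo's common helpers) that reproduces the update pattern used in both training loops:

    def make_scheduler(decay_type, lr_init, lr_final, steps):
        # Toy stand-in: returns a callable step -> learning rate.
        if decay_type == "constant":
            return lambda t: lr_init
        if decay_type == "exponential":
            gamma = (lr_final / lr_init) ** (1.0 / steps)
            return lambda t: lr_init * gamma**t
        slope = (lr_init - lr_final) / steps  # linear
        return lambda t: lr_init - slope * t

    lr_alpha_scheduler = make_scheduler("linear", 1e-4, 1e-10, steps=100)
    lr_labda_scheduler = make_scheduler("exponential", 3e-4, 1e-10, steps=100)
    for epoch in (1, 50, 100):
        # max(...) mirrors the clamping above: never drop below the final rate.
        lr_alpha_now = max(lr_alpha_scheduler(epoch), 1e-10)
        lr_labda_now = max(lr_labda_scheduler(epoch), 1e-10)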