feat(tf2): add alpha/lambda learning rate customization #416

Merged
merged 1 commit on Feb 20, 2024
@@ -50,7 +50,7 @@ def get_linear_decay_rate(lr_init, lr_final, steps):
lr_init (float): The initial learning rate.
lr_final (float): The final learning rate you want to achieve.
steps (int): The number of steps/epochs over which the learning rate should
-     decay. This is equal to epochs - 1.
+     decay. This is equal to epochs -1.

Returns:
decimal.Decimal: Linear learning rate decay factor (G).
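For intuition, the linear decay factor G can be derived from the three arguments documented above. A minimal sketch, assuming the schedule has the form lr(t) = lr_init - G * t (the repository helper may differ in details):

    from decimal import Decimal

    def linear_decay_rate(lr_init, lr_final, steps):
        # Solve lr_init - G * steps == lr_final for the per-step factor G.
        return (Decimal(str(lr_init)) - Decimal(str(lr_final))) / Decimal(steps)

    # Example: decay 1e-4 down to 1e-10 over 100 steps -> roughly 1e-6 per step.
    print(linear_decay_rate(1e-4, 1e-10, 100))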
239 changes: 206 additions & 33 deletions stable_learning_control/algos/tf2/lac/lac.py
@@ -73,6 +73,10 @@
"AverageLossPi",
"AverageEntropy",
]
+ VALID_DECAY_TYPES = ["linear", "exponential", "constant"]
+ VALID_DECAY_REFERENCES = ["step", "epoch"]
+ DEFAULT_DECAY_TYPE = "linear"
+ DEFAULT_DECAY_REFERENCE = "epoch"

# tf.config.run_functions_eagerly(True) # NOTE: Uncomment for debugging.

@@ -106,6 +110,8 @@ def __init__(
adaptive_temperature=True,
lr_a=1e-4,
lr_c=3e-4,
+     lr_alpha=1e-4,
+     lr_labda=3e-4,
device="cpu",
name="LAC",
):
@@ -194,6 +200,10 @@ def __init__(
``1e-4``.
lr_c (float, optional): Learning rate used for the (Lyapunov) critic.
    Defaults to ``3e-4``.
+ lr_alpha (float, optional): Learning rate used for the entropy temperature.
+     Defaults to ``1e-4``.
+ lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange
+     multiplier. Defaults to ``3e-4``.
device (str, optional): The device the networks are placed on (options:
``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.

@@ -258,8 +268,8 @@ def __init__(
self._alpha3 = alpha3
self._lr_a = tf.Variable(lr_a, name="Lr_a")
if self._adaptive_temperature:
-     self._lr_alpha = tf.Variable(lr_a, name="Lr_alpha")
-     self._lr_lag = tf.Variable(lr_a, name="Lr_lag")
+     self._lr_alpha = tf.Variable(lr_alpha, name="Lr_alpha")
+     self._lr_lag = tf.Variable(lr_labda, name="Lr_lag")
self._lr_c = tf.Variable(lr_c, name="Lr_c")
if not isinstance(target_entropy, (float, int)):
self._target_entropy = heuristic_target_entropy(env.action_space)
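With this change, the entropy temperature (alpha) and the Lyapunov Lagrange multiplier (labda) no longer inherit the actor learning rate lr_a but get their own initial values. A hypothetical construction sketch (the environment and the use of defaults for the remaining constructor arguments are assumptions, not part of this diff):

    import gymnasium as gym

    env = gym.make("Pendulum-v1")  # illustrative environment, not from this PR
    policy = LAC(
        env,
        lr_a=1e-4,      # actor
        lr_c=3e-4,      # (Lyapunov) critic
        lr_alpha=1e-4,  # entropy temperature (new in this PR)
        lr_labda=3e-4,  # Lyapunov Lagrange multiplier (new in this PR)
    )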
@@ -801,10 +811,18 @@ def lac(
adaptive_temperature=True,
lr_a=1e-4,
lr_c=3e-4,
+     lr_alpha=1e-4,
+     lr_labda=3e-4,
lr_a_final=1e-10,
lr_c_final=1e-10,
lr_decay_type="linear",
lr_decay_ref="epoch",
lr_alpha_final=1e-10,
lr_labda_final=1e-10,
lr_decay_type=DEFAULT_DECAY_TYPE,
lr_a_decay_type=None,
lr_c_decay_type=None,
lr_alpha_decay_type=None,
lr_labda_decay_type=None,
lr_decay_ref=DEFAULT_DECAY_REFERENCE,
batch_size=256,
replay_size=int(1e6),
seed=None,
@@ -919,10 +937,33 @@ def lac(
``1e-4``.
lr_c (float, optional): Learning rate used for the (Lyapunov) critic. Defaults
    to ``3e-4``.
+ lr_alpha (float, optional): Learning rate used for the entropy temperature.
+     Defaults to ``1e-4``.
+ lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange
+     multiplier. Defaults to ``3e-4``.
lr_a_final (float, optional): The final actor learning rate that is achieved
    at the end of the training. Defaults to ``1e-10``.
lr_c_final (float, optional): The final critic learning rate that is achieved
    at the end of the training. Defaults to ``1e-10``.
+ lr_alpha_final (float, optional): The final alpha learning rate that is
+     achieved at the end of the training. Defaults to ``1e-10``.
+ lr_labda_final (float, optional): The final labda learning rate that is
+     achieved at the end of the training. Defaults to ``1e-10``.
+ lr_decay_type (str, optional): The learning rate decay type that is used
+     (options are: ``linear``, ``exponential``, and ``constant``). Defaults to
+     ``linear``. Can be overridden by the specific learning rate decay types.
+ lr_a_decay_type (str, optional): The learning rate decay type that is used for
+     the actor learning rate (options are: ``linear``, ``exponential``, and
+     ``constant``). If not specified, the general learning rate decay type is used.
+ lr_c_decay_type (str, optional): The learning rate decay type that is used for
+     the critic learning rate (options are: ``linear``, ``exponential``, and
+     ``constant``). If not specified, the general learning rate decay type is used.
+ lr_alpha_decay_type (str, optional): The learning rate decay type that is used
+     for the alpha learning rate (options are: ``linear``, ``exponential``, and
+     ``constant``). If not specified, the general learning rate decay type is used.
+ lr_labda_decay_type (str, optional): The learning rate decay type that is used
+     for the labda learning rate (options are: ``linear``, ``exponential``, and
+     ``constant``). If not specified, the general learning rate decay type is used.
- lr_decay_type (str, optional): The learning rate decay type that is used (
-     options are: ``linear`` and ``exponential`` and ``constant``). Defaults to
-     ``linear``.
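To make the override semantics concrete: the global lr_decay_type applies to every learning rate unless a per-optimizer decay type is passed. A hypothetical call (values illustrative; env_fn stands for the environment factory expected elsewhere in the codebase):

    # Actor, alpha, and labda learning rates decay exponentially; the critic's
    # stays constant because its specific decay type overrides the global one.
    lac(
        env_fn,
        lr_decay_type="exponential",
        lr_c_decay_type="constant",
    )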
@@ -1068,35 +1109,69 @@ def lac(
# os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" # Disable for reproducibility.

policy = LAC(
-     env,
-     actor_critic,
-     ac_kwargs,
-     opt_type,
-     alpha,
-     alpha3,
-     labda,
-     gamma,
-     polyak,
-     target_entropy,
-     adaptive_temperature,
-     lr_a,
-     lr_c,
-     device,
- )

- # Parse learning rate decay type.
- valid_lr_decay_options = ["step", "epoch"]
+     env=env,
+     actor_critic=actor_critic,
+     ac_kwargs=ac_kwargs,
+     opt_type=opt_type,
+     alpha=alpha,
+     alpha3=alpha3,
+     labda=labda,
+     gamma=gamma,
+     polyak=polyak,
+     target_entropy=target_entropy,
+     adaptive_temperature=adaptive_temperature,
+     lr_a=lr_a,
+     lr_c=lr_c,
+     lr_alpha=lr_alpha,
+     lr_labda=lr_labda,
+     device=device,
+ )

+ # Parse learning rate decay reference.
lr_decay_ref = lr_decay_ref.lower()
- if lr_decay_ref not in valid_lr_decay_options:
-     options = [f"'{option}'" for option in valid_lr_decay_options]
+ if lr_decay_ref not in VALID_DECAY_REFERENCES:
+     options = [f"'{option}'" for option in VALID_DECAY_REFERENCES]
logger.log(
f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
"which is not a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay reference "
"variable has been set to 'epoch'.",
f"variable has been set to '{DEFAULT_DECAY_REFERENCE}'.",
type="warning",
)
lr_decay_ref = "epoch"
lr_decay_ref = DEFAULT_DECAY_REFERENCE

+ # Parse learning rate decay types.
+ lr_decay_type = lr_decay_type.lower()
+ if lr_decay_type not in VALID_DECAY_TYPES:
+     options = [f"'{option}'" for option in VALID_DECAY_TYPES]
+     logger.log(
+         f"The learning rate decay type was set to '{lr_decay_type}', which is not "
+         "a valid option. Valid options are "
+         f"{', '.join(options)}. The learning rate decay type has been set to "
+         f"'{DEFAULT_DECAY_TYPE}'.",
+         type="warning",
+     )
+     lr_decay_type = DEFAULT_DECAY_TYPE
+ decay_types = {
+     "actor": lr_a_decay_type.lower() if lr_a_decay_type else None,
+     "critic": lr_c_decay_type.lower() if lr_c_decay_type else None,
+     "alpha": lr_alpha_decay_type.lower() if lr_alpha_decay_type else None,
+     "labda": lr_labda_decay_type.lower() if lr_labda_decay_type else None,
+ }
+ for name, decay_type in decay_types.items():
+     if decay_type is None:
+         decay_types[name] = lr_decay_type
+     else:
+         if decay_type not in VALID_DECAY_TYPES:
+             logger.log(
+                 f"Invalid {name} learning rate decay type: '{decay_type}'. Using "
+                 f"global learning rate decay type: '{lr_decay_type}' instead.",
+                 type="warning",
+             )
+             decay_types[name] = lr_decay_type
+ lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type, lr_labda_decay_type = (
+     decay_types.values()
+ )
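The resolution rule implemented by the loop above can be captured in a few standalone lines (a sketch that omits the warning side effects):

    VALID_DECAY_TYPES = ["linear", "exponential", "constant"]

    def resolve_decay_type(specific, global_type):
        # None or an invalid value falls back to the global decay type.
        if specific is None or specific.lower() not in VALID_DECAY_TYPES:
            return global_type
        return specific.lower()

    assert resolve_decay_type(None, "linear") == "linear"
    assert resolve_decay_type("CONSTANT", "linear") == "constant"
    assert resolve_decay_type("bogus", "exponential") == "exponential"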

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
@@ -1108,9 +1183,19 @@ def lac(
lr_decay_steps = epochs

# Create learning rate schedulers.
- # NOTE: Alpha and labda currently use the same scheduler as the actor.
- lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
- lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)
+ lr_a_init, lr_c_init, lr_alpha_init, lr_labda_init = lr_a, lr_c, lr_alpha, lr_labda
+ lr_a_scheduler = get_lr_scheduler(
+     lr_a_decay_type, lr_a_init, lr_a_final, lr_decay_steps
+ )
+ lr_c_scheduler = get_lr_scheduler(
+     lr_c_decay_type, lr_c_init, lr_c_final, lr_decay_steps
+ )
+ lr_alpha_scheduler = get_lr_scheduler(
+     lr_alpha_decay_type, lr_alpha_init, lr_alpha_final, lr_decay_steps
+ )
+ lr_labda_scheduler = get_lr_scheduler(
+     lr_labda_decay_type, lr_labda_init, lr_labda_final, lr_decay_steps
+ )
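Each scheduler created above is used later in this diff as a callable that maps a step or epoch index to a learning rate, and every call site clamps the result so it can never undershoot the configured final value:

    # Clamp pattern repeated at the update sites below (taken from this diff).
    lr_alpha_now = max(lr_alpha_scheduler(n_update + 1), lr_alpha_final)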

# Restore policy if supplied.
if start_policy is not None:
@@ -1265,8 +1350,17 @@ def lac(
lr_c_now = max(
lr_c_scheduler(n_update + 1), lr_c_final
) # Make sure lr is bounded above final lr.
+ lr_alpha_now = max(
+     lr_alpha_scheduler(n_update + 1), lr_alpha_final
+ )  # Make sure lr is bounded above final lr.
+ lr_labda_now = max(
+     lr_labda_scheduler(n_update + 1), lr_labda_final
+ )  # Make sure lr is bounded above final lr.
policy.set_learning_rates(
-     lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
+     lr_a=lr_a_now,
+     lr_c=lr_c_now,
+     lr_alpha=lr_alpha_now,
+     lr_labda=lr_labda_now,
)

# SGD batch tb logging.
@@ -1297,8 +1391,8 @@ def lac(
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_scheduler(progress)
lr_critic = lr_c_scheduler(progress)
- lr_alpha = lr_a_scheduler(progress)
- lr_labda = lr_a_scheduler(progress)
+ lr_alpha = lr_alpha_scheduler(progress)
+ lr_labda = lr_labda_scheduler(progress)
else:
lr_actor = policy._pi_optimizer.lr.numpy()
lr_critic = policy._c_optimizer.lr.numpy()
@@ -1398,8 +1492,17 @@ def lac(
lr_c_now = max(
lr_c_scheduler(epoch), lr_c_final
) # Make sure lr is bounded above final.
+ lr_alpha_now = max(
+     lr_alpha_scheduler(epoch), lr_alpha_final
+ )  # Make sure lr is bounded above final.
+ lr_labda_now = max(
+     lr_labda_scheduler(epoch), lr_labda_final
+ )  # Make sure lr is bounded above final.
policy.set_learning_rates(
-     lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
+     lr_a=lr_a_now,
+     lr_c=lr_c_now,
+     lr_alpha=lr_alpha_now,
+     lr_labda=lr_labda_now,
)

# Export model to 'SavedModel'
@@ -1574,6 +1677,18 @@ def lac(
parser.add_argument(
"--lr_c", type=float, default=3e-4, help="critic learning rate (default: 1e-4)"
)
+ parser.add_argument(
+     "--lr_alpha",
+     type=float,
+     default=1e-4,
+     help="entropy temperature learning rate (default: 1e-4)",
+ )
+ parser.add_argument(
+     "--lr_labda",
+     type=float,
+     default=3e-4,
+     help="Lyapunov Lagrange multiplier learning rate (default: 3e-4)",
+ )
parser.add_argument(
"--lr_a_final",
type=float,
@@ -1586,12 +1701,62 @@ def lac(
default=1e-10,
help="the finalcritic learning rate (default: 1e-10)",
)
+ parser.add_argument(
+     "--lr_alpha_final",
+     type=float,
+     default=1e-10,
+     help="the final entropy temperature learning rate (default: 1e-10)",
+ )
+ parser.add_argument(
+     "--lr_labda_final",
+     type=float,
+     default=1e-10,
+     help="the final Lyapunov Lagrange multiplier learning rate (default: 1e-10)",
+ )
parser.add_argument(
"--lr_decay_type",
type=str,
default="linear",
help="the learning rate decay type (default: linear)",
)
+ parser.add_argument(
+     "--lr_a_decay_type",
+     type=str,
+     default=None,
+     help=(
+         "the learning rate decay type that is used for the actor learning rate. "
+         "If not specified, the general learning rate decay type is used."
+     ),
+ )
+ parser.add_argument(
+     "--lr_c_decay_type",
+     type=str,
+     default=None,
+     help=(
+         "the learning rate decay type that is used for the critic learning rate. "
+         "If not specified, the general learning rate decay type is used."
+     ),
+ )
+ parser.add_argument(
+     "--lr_alpha_decay_type",
+     type=str,
+     default=None,
+     help=(
+         "the learning rate decay type that is used for the entropy temperature "
+         "learning rate. If not specified, the general learning rate decay type is "
+         "used."
+     ),
+ )
+ parser.add_argument(
+     "--lr_labda_decay_type",
+     type=str,
+     default=None,
+     help=(
+         "the learning rate decay type that is used for the Lyapunov Lagrange "
+         "multiplier learning rate. If not specified, the general learning rate "
+         "decay type is used."
+     ),
+ )
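Taken together, the new flags allow per-optimizer learning rate tuning from the command line. A hypothetical invocation (script path and flag values are illustrative only):

    python stable_learning_control/algos/tf2/lac/lac.py \
        --lr_alpha 2e-4 --lr_labda 1e-4 \
        --lr_alpha_final 1e-10 --lr_labda_final 1e-10 \
        --lr_decay_type exponential --lr_labda_decay_type constant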
parser.add_argument(
"--lr_decay_ref",
type=str,
@@ -1803,10 +1968,18 @@ def lac(
adaptive_temperature=args.adaptive_temperature,
lr_a=args.lr_a,
lr_c=args.lr_c,
+     lr_alpha=args.lr_alpha,
+     lr_labda=args.lr_labda,
lr_a_final=args.lr_a_final,
lr_c_final=args.lr_c_final,
+     lr_alpha_final=args.lr_alpha_final,
+     lr_labda_final=args.lr_labda_final,
lr_decay_type=args.lr_decay_type,
lr_decay_ref=args.lr_decay_ref,
+     lr_a_decay_type=args.lr_a_decay_type,
+     lr_c_decay_type=args.lr_c_decay_type,
+     lr_alpha_decay_type=args.lr_alpha_decay_type,
+     lr_labda_decay_type=args.lr_labda_decay_type,
batch_size=args.batch_size,
replay_size=args.replay_size,
horizon_length=args.horizon_length,