Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(tf2): correct off-by-one error in learning rate decay calculation #415

Merged
merged 1 commit into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions stable_learning_control/algos/tf2/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,9 @@ def lac(
# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
# NOTE: Decay applied at policy update to improve performance.
lr_decay_steps = (total_steps - update_after) / update_every
lr_decay_steps = (
total_steps - update_after
) / update_every + 1 # NOTE: +1 since we start at the initial learning rate.
else:
lr_decay_steps = epochs

Expand All @@ -1110,16 +1112,6 @@ def lac(
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)

# Create step based learning rate schedulers.
# NOTE: Used to estimate the learning rate at each step.
if lr_decay_ref == "step":
lr_a_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
)
lr_c_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
)

# Restore policy if supplied.
if start_policy is not None:
logger.log(f"Restoring model from '{start_policy}'.", type="info")
Expand Down Expand Up @@ -1303,10 +1295,10 @@ def lac(
# Retrieve current learning rates.
if lr_decay_ref == "step":
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_step_scheduler(progress)
lr_critic = lr_c_step_scheduler(progress)
lr_alpha = lr_a_step_scheduler(progress)
lr_labda = lr_a_step_scheduler(progress)
lr_actor = lr_a_scheduler(progress)
lr_critic = lr_c_scheduler(progress)
lr_alpha = lr_a_scheduler(progress)
lr_labda = lr_a_scheduler(progress)
else:
lr_actor = policy._pi_optimizer.lr.numpy()
lr_critic = policy._c_optimizer.lr.numpy()
Expand Down
20 changes: 6 additions & 14 deletions stable_learning_control/algos/tf2/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,9 @@ def sac(
# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
# NOTE: Decay applied at policy update to improve performance.
lr_decay_steps = (total_steps - update_after) / update_every
lr_decay_steps = (
total_steps - update_after
) / update_every + 1 # NOTE: +1 since we start at the initial learning rate.
else:
lr_decay_steps = epochs

Expand All @@ -977,16 +979,6 @@ def sac(
lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)

# Create step based learning rate schedulers.
# NOTE: Used to estimate the learning rate at each step.
if lr_decay_ref == "step":
lr_a_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
)
lr_c_step_scheduler = get_lr_scheduler(
lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
)

# Restore policy if supplied.
if start_policy is not None:
logger.log(f"Restoring model from '{start_policy}'.", type="info")
Expand Down Expand Up @@ -1138,9 +1130,9 @@ def sac(
# Retrieve current learning rates.
if lr_decay_ref == "step":
progress = max((t + 1) - update_after, 0) / update_every
lr_actor = lr_a_step_scheduler(progress)
lr_critic = lr_c_step_scheduler(progress)
lr_alpha = lr_a_step_scheduler(progress)
lr_actor = lr_a_scheduler(progress)
lr_critic = lr_c_scheduler(progress)
lr_alpha = lr_a_scheduler(progress)
else:
lr_actor = policy._pi_optimizer.lr.numpy()
lr_critic = policy._c_optimizer.lr.numpy()
Expand Down