From 8d936101a280c068dfcf61c1d7919453ce1a0b3c Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Tue, 13 Feb 2024 16:48:26 +0100 Subject: [PATCH] fix(run): resolve issue with data_dir input argument (#409) This commit addresses a bug where the `--data_dir` input argument was not correctly utilized. The issue prevented the application from accessing the specified data directory. This fix ensures that the `--data_dir` argument is appropriately recognised and applied. --- stable_learning_control/algos/pytorch/lac/lac.py | 2 +- stable_learning_control/algos/pytorch/latc/latc.py | 2 +- stable_learning_control/algos/pytorch/sac/sac.py | 4 ++-- stable_learning_control/algos/tf2/lac/lac.py | 2 +- stable_learning_control/algos/tf2/latc/latc.py | 2 +- stable_learning_control/algos/tf2/sac/sac.py | 2 +- stable_learning_control/utils/run_utils.py | 3 +-- 7 files changed, 8 insertions(+), 9 deletions(-) diff --git a/stable_learning_control/algos/pytorch/lac/lac.py b/stable_learning_control/algos/pytorch/lac/lac.py index ad28d0bb..f430c88e 100644 --- a/stable_learning_control/algos/pytorch/lac/lac.py +++ b/stable_learning_control/algos/pytorch/lac/lac.py @@ -1870,7 +1870,7 @@ def lac( # Setup output dir for logger and return output kwargs. logger_kwargs = setup_logger_kwargs( args.exp_name, - args.seed, + seed=args.seed, save_checkpoints=args.save_checkpoints, use_tensorboard=args.use_tensorboard, tb_log_freq=args.tb_log_freq, diff --git a/stable_learning_control/algos/pytorch/latc/latc.py b/stable_learning_control/algos/pytorch/latc/latc.py index f747bc04..47e085c7 100644 --- a/stable_learning_control/algos/pytorch/latc/latc.py +++ b/stable_learning_control/algos/pytorch/latc/latc.py @@ -452,7 +452,7 @@ def latc(env_fn, actor_critic=None, *args, **kwargs): # Setup output dir for logger and return output kwargs. logger_kwargs = setup_logger_kwargs( args.exp_name, - args.seed, + seed=args.seed, save_checkpoints=args.save_checkpoints, use_tensorboard=args.use_tensorboard, tb_log_freq=args.tb_log_freq, diff --git a/stable_learning_control/algos/pytorch/sac/sac.py b/stable_learning_control/algos/pytorch/sac/sac.py index f5685f39..51ef26f3 100644 --- a/stable_learning_control/algos/pytorch/sac/sac.py +++ b/stable_learning_control/algos/pytorch/sac/sac.py @@ -1657,15 +1657,15 @@ def sac( # Setup output dir for logger and return output kwargs. logger_kwargs = setup_logger_kwargs( args.exp_name, - args.seed, + seed=args.seed, save_checkpoints=args.save_checkpoints, use_tensorboard=args.use_tensorboard, + tb_log_freq=args.tb_log_freq, use_wandb=args.use_wandb, wandb_job_type=args.wandb_job_type, wandb_project=args.wandb_project, wandb_group=args.wandb_group, wandb_run_name=args.wandb_run_name, - tb_log_freq=args.tb_log_freq, quiet=args.quiet, verbose_fmt=args.verbose_fmt, verbose_vars=args.verbose_vars, diff --git a/stable_learning_control/algos/tf2/lac/lac.py b/stable_learning_control/algos/tf2/lac/lac.py index 3be40cd8..02fadbf9 100644 --- a/stable_learning_control/algos/tf2/lac/lac.py +++ b/stable_learning_control/algos/tf2/lac/lac.py @@ -1769,7 +1769,7 @@ def lac( # Setup output dir for logger and return output kwargs. logger_kwargs = setup_logger_kwargs( args.exp_name, - args.seed, + seed=args.seed, save_checkpoints=args.save_checkpoints, use_tensorboard=args.use_tensorboard, tb_log_freq=args.tb_log_freq, diff --git a/stable_learning_control/algos/tf2/latc/latc.py b/stable_learning_control/algos/tf2/latc/latc.py index b1677ec7..04731ad5 100644 --- a/stable_learning_control/algos/tf2/latc/latc.py +++ b/stable_learning_control/algos/tf2/latc/latc.py @@ -454,7 +454,7 @@ def latc(env_fn, actor_critic=None, *args, **kwargs): # Setup output dir for logger and return output kwargs. logger_kwargs = setup_logger_kwargs( args.exp_name, - args.seed, + seed=args.seed, save_checkpoints=args.save_checkpoints, use_tensorboard=args.use_tensorboard, tb_log_freq=args.tb_log_freq, diff --git a/stable_learning_control/algos/tf2/sac/sac.py b/stable_learning_control/algos/tf2/sac/sac.py index 00e9cdc8..c49a7023 100644 --- a/stable_learning_control/algos/tf2/sac/sac.py +++ b/stable_learning_control/algos/tf2/sac/sac.py @@ -1573,7 +1573,7 @@ def sac( # Setup output dir for logger and return output kwargs. logger_kwargs = setup_logger_kwargs( args.exp_name, - args.seed, + seed=args.seed, save_checkpoints=args.save_checkpoints, use_tensorboard=args.use_tensorboard, tb_log_freq=args.tb_log_freq, diff --git a/stable_learning_control/utils/run_utils.py b/stable_learning_control/utils/run_utils.py index e94b78be..b36ee29d 100644 --- a/stable_learning_control/utils/run_utils.py +++ b/stable_learning_control/utils/run_utils.py @@ -87,9 +87,8 @@ def call_experiment( # Set up logger output directory. if "logger_kwargs" not in kwargs: kwargs["logger_kwargs"] = setup_logger_kwargs( - exp_name, seed, data_dir, datestamp + exp_name, seed=seed, data_dir=data_dir, datestamp=datestamp ) - else: print("Note: Call experiment is not handling logger_kwargs.\n") kwargs["logger_kwargs"] = setup_logger_kwargs(