From 28450b5ebee600159055d046b14c4d7d882c83f3 Mon Sep 17 00:00:00 2001
From: Paul-Edouard Sarlin
Date: Mon, 23 Oct 2023 13:16:50 +0200
Subject: [PATCH 1/2] Remove unused losses key in checkpoint

---
 gluefactory/train.py             | 4 ----
 gluefactory/utils/experiments.py | 2 --
 2 files changed, 6 deletions(-)

diff --git a/gluefactory/train.py b/gluefactory/train.py
index 12ad2075..98adab1c 100644
--- a/gluefactory/train.py
+++ b/gluefactory/train.py
@@ -341,7 +341,6 @@ def sigint_handler(signal, frame):
     logger.info(
         "Starting training with configuration:\n%s", OmegaConf.to_yaml(conf)
     )
-    losses_ = None
 
     def trace_handler(p):
         # torch.profiler.tensorboard_trace_handler(str(output_dir))
@@ -551,7 +550,6 @@ def trace_handler(p):
                         optimizer,
                         lr_scheduler,
                         conf,
-                        losses_,
                         results,
                         best_eval,
                         epoch,
@@ -586,7 +584,6 @@ def trace_handler(p):
                     optimizer,
                     lr_scheduler,
                     conf,
-                    losses_,
                     results,
                     best_eval,
                     epoch,
@@ -605,7 +602,6 @@ def trace_handler(p):
                 optimizer,
                 lr_scheduler,
                 conf,
-                losses_,
                 results,
                 best_eval,
                 epoch,
diff --git a/gluefactory/utils/experiments.py b/gluefactory/utils/experiments.py
index 7723fcea..ae029fad 100644
--- a/gluefactory/utils/experiments.py
+++ b/gluefactory/utils/experiments.py
@@ -97,7 +97,6 @@ def save_experiment(
     optimizer,
     lr_scheduler,
     conf,
-    losses,
     results,
     best_eval,
     epoch,
@@ -116,7 +115,6 @@ def save_experiment(
         "lr_scheduler": lr_scheduler.state_dict(),
         "conf": OmegaConf.to_container(conf, resolve=True),
         "epoch": epoch,
-        "losses": losses,
         "eval": results,
     }
     if cp_name is None:

From 4c52c0373d4b9b3e0d1a9f2e44cffecc0a678233 Mon Sep 17 00:00:00 2001
From: Paul-Edouard Sarlin
Date: Mon, 23 Oct 2023 13:21:54 +0200
Subject: [PATCH 2/2] Avoid unnecessary generation and storage of figures

---
 gluefactory/train.py | 61 +++++++++++++++++++++++++++-----------------
 1 file changed, 37 insertions(+), 24 deletions(-)

diff --git a/gluefactory/train.py b/gluefactory/train.py
index 98adab1c..cb8499b3 100644
--- a/gluefactory/train.py
+++ b/gluefactory/train.py
@@ -78,12 +78,12 @@
 
 
 @torch.no_grad()
-def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
+def do_evaluation(model, loader, device, loss_fn, conf, rank, pbar=True):
     model.eval()
     results = {}
     pr_metrics = defaultdict(PRMetric)
     figures = []
-    if conf.plot is not None:
+    if conf.plot is not None and rank == 0:
         n, plot_fn = conf.plot
         plot_ids = np.random.choice(len(loader), min(len(loader), n), replace=False)
     for i, data in enumerate(
@@ -120,7 +120,8 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
                 results[k + f"_recall{int(q)}"].update(v)
         del numbers
     results = {k: results[k].compute() for k in results}
-    return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
+    pr_metrics = {k: v.compute() for k, v in pr_metrics.items()}
+    return results, pr_metrics, figures
 
 
 def filter_parameters(params, regexp):
@@ -184,6 +185,27 @@ def pack_lr_parameters(params, base_lr, lr_scaling):
     return lr_params
 
 
+def write_dict_summaries(writer, name, items, step):
+    for k, v in items.items():
+        key = f"{name}/{k}"
+        if isinstance(v, dict):
+            writer.add_scalars(key, v, step)
+        elif isinstance(v, tuple):
+            writer.add_pr_curve(key, *v, step)
+        else:
+            writer.add_scalar(key, v, step)
+
+
+def write_image_summaries(writer, name, figures, step):
+    if isinstance(figures, list):
+        for i, figs in enumerate(figures):
+            for k, fig in figs.items():
+                writer.add_figure(f"{name}/{i}_{k}", fig, step)
+    else:
+        for k, fig in figures.items():
+            writer.add_figure(f"{name}/{k}", fig, step)
+
+
 def training(rank, conf, output_dir, args):
     if args.restore:
         logger.info(f"Restoring from previous training of {args.experiment}")
@@ -370,17 +392,16 @@ def trace_handler(p):
         ):
             for bname, eval_conf in conf.get("benchmarks", {}).items():
                 logger.info(f"Running eval on {bname}")
-                s, f, r = run_benchmark(
+                results, figures, _ = run_benchmark(
                     bname,
                     eval_conf,
                     EVAL_PATH / bname / args.experiment / str(epoch),
                     model.eval(),
                 )
-                logger.info(str(s))
-                for metric_name, value in s.items():
-                    writer.add_scalar(f"test/{bname}/{metric_name}", value, epoch)
-                for fig_name, fig in f.items():
-                    writer.add_figure(f"figures/{bname}/{fig_name}", fig, epoch)
+                logger.info(str(results))
+                write_dict_summaries(writer, f"test/{bname}", results, epoch)
+                write_image_summaries(writer, f"figures/{bname}", figures, epoch)
+                del results, figures
 
         # set the seed
         set_seed(conf.train.seed + epoch)
@@ -487,8 +508,7 @@ def trace_handler(p):
                         epoch, it, ", ".join(str_losses)
                     )
                 )
-                for k, v in losses.items():
-                    writer.add_scalar("training/" + k, v, tot_n_samples)
+                write_dict_summaries(writer, "training", losses, tot_n_samples)
                 writer.add_scalar(
                     "training/lr", optimizer.param_groups[0]["lr"], tot_n_samples
                 )
@@ -525,6 +545,7 @@ def trace_handler(p):
                     device,
                     loss_fn,
                     conf.train,
+                    rank,
                     pbar=(rank == -1),
                 )
 
@@ -535,13 +556,9 @@ def trace_handler(p):
                     if isinstance(v, float)
                 ]
                 logger.info(f'[Validation] {{{", ".join(str_results)}}}')
-                for k, v in results.items():
-                    if isinstance(v, dict):
-                        writer.add_scalars(f"figure/val/{k}", v, tot_n_samples)
-                    else:
-                        writer.add_scalar("val/" + k, v, tot_n_samples)
-                for k, v in pr_metrics.items():
-                    writer.add_pr_curve("val/" + k, *v, tot_n_samples)
+                write_dict_summaries(writer, "val", results, tot_n_samples)
+                write_dict_summaries(writer, "val", pr_metrics, tot_n_samples)
+                write_image_summaries(writer, "figures", figures, tot_n_samples)
                 # @TODO: optional always save checkpoint
                 if results[conf.train.best_key] < best_eval:
                     best_eval = results[conf.train.best_key]
@@ -560,13 +577,8 @@ def trace_handler(p):
                         cp_name="checkpoint_best.tar",
                     )
                     logger.info(f"New best val: {conf.train.best_key}={best_eval}")
-                if len(figures) > 0:
-                    for i, figs in enumerate(figures):
-                        for name, fig in figs.items():
-                            writer.add_figure(
-                                f"figures/{i}_{name}", fig, tot_n_samples
-                            )
                 torch.cuda.empty_cache()  # should be cleared at the first iter
+                del results, pr_metrics, figures
 
             if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
                 if results is None:
@@ -576,6 +588,7 @@ def trace_handler(p):
                         device,
                         loss_fn,
                         conf.train,
+                        rank,
                         pbar=(rank == -1),
                     )
                     best_eval = results[conf.train.best_key]
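
A note on the summary helpers introduced in PATCH 2/2: write_dict_summaries
dispatches on the value type, so plain metrics, grouped scalars, and PR metrics
all share one call site (dict -> add_scalars, tuple -> add_pr_curve, anything
else -> add_scalar). Below is a minimal, self-contained sketch of that dispatch
against a plain torch.utils.tensorboard.SummaryWriter; the log directory, tag
names, and dummy values are hypothetical, for illustration only.

import torch
from torch.utils.tensorboard import SummaryWriter

def write_dict_summaries(writer, name, items, step):
    # Same type-based dispatch as the helper added in PATCH 2/2.
    for k, v in items.items():
        key = f"{name}/{k}"
        if isinstance(v, dict):
            writer.add_scalars(key, v, step)    # grouped scalars, one chart
        elif isinstance(v, tuple):
            writer.add_pr_curve(key, *v, step)  # expects (labels, predictions)
        else:
            writer.add_scalar(key, v, step)     # single scalar

writer = SummaryWriter("runs/demo")   # hypothetical log directory
labels = torch.randint(0, 2, (100,))  # dummy binary match labels
scores = torch.rand(100)              # dummy predicted scores in [0, 1]
write_dict_summaries(
    writer,
    "val",
    {
        "loss/total": 0.42,                          # -> add_scalar
        "match_recall": {"easy": 0.9, "hard": 0.6},  # -> add_scalars
        "pr_curve": (labels, scores),                # -> add_pr_curve
    },
    step=100,
)
writer.close()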