Avoid unnecessary generation and storage of figures
sarlinpe committed Oct 23, 2023
1 parent 28450b5 commit 4c52c03
Showing 1 changed file with 37 additions and 24 deletions.
61 changes: 37 additions & 24 deletions gluefactory/train.py
@@ -78,12 +78,12 @@


@torch.no_grad()
- def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
+ def do_evaluation(model, loader, device, loss_fn, conf, rank, pbar=True):
model.eval()
results = {}
pr_metrics = defaultdict(PRMetric)
figures = []
- if conf.plot is not None:
+ if conf.plot is not None and rank == 0:
n, plot_fn = conf.plot
plot_ids = np.random.choice(len(loader), min(len(loader), n), replace=False)
for i, data in enumerate(
@@ -120,7 +120,8 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
results[k + f"_recall{int(q)}"].update(v)
del numbers
results = {k: results[k].compute() for k in results}
- return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
+ pr_metrics = {k: v.compute() for k, v in pr_metrics.items()}
+ return results, pr_metrics, figures


def filter_parameters(params, regexp):
@@ -184,6 +185,27 @@ def pack_lr_parameters(params, base_lr, lr_scaling):
return lr_params


+ def write_dict_summaries(writer, name, items, step):
+ for k, v in items.items():
+ key = f"{name}/{k}"
+ if isinstance(v, dict):
+ writer.add_scalars(key, v, step)
+ elif isinstance(v, tuple):
+ writer.add_pr_curve(key, *v, step)
+ else:
+ writer.add_scalar(key, v, step)
+
+
+ def write_image_summaries(writer, name, figures, step):
+ if isinstance(figures, list):
+ for i, figs in enumerate(figures):
+ for k, fig in figs.items():
+ writer.add_figure(f"{name}/{i}_{k}", fig, step)
+ else:
+ for k, fig in figures.items():
+ writer.add_figure(f"{name}/{k}", fig, step)
+
+
def training(rank, conf, output_dir, args):
if args.restore:
logger.info(f"Restoring from previous training of {args.experiment}")
@@ -370,17 +392,16 @@ def trace_handler(p):
):
for bname, eval_conf in conf.get("benchmarks", {}).items():
logger.info(f"Running eval on {bname}")
- s, f, r = run_benchmark(
+ results, figures, _ = run_benchmark(
bname,
eval_conf,
EVAL_PATH / bname / args.experiment / str(epoch),
model.eval(),
)
- logger.info(str(s))
- for metric_name, value in s.items():
- writer.add_scalar(f"test/{bname}/{metric_name}", value, epoch)
- for fig_name, fig in f.items():
- writer.add_figure(f"figures/{bname}/{fig_name}", fig, epoch)
+ logger.info(str(results))
+ write_dict_summaries(writer, f"test/{bname}", results, epoch)
+ write_image_summaries(writer, f"figures/{bname}", figures, epoch)
+ del results, figures

# set the seed
set_seed(conf.train.seed + epoch)
@@ -487,8 +508,7 @@ def trace_handler(p):
epoch, it, ", ".join(str_losses)
)
)
- for k, v in losses.items():
- writer.add_scalar("training/" + k, v, tot_n_samples)
+ write_dict_summaries(writer, "training/", losses, tot_n_samples)
writer.add_scalar(
"training/lr", optimizer.param_groups[0]["lr"], tot_n_samples
)
@@ -525,6 +545,7 @@ def trace_handler(p):
device,
loss_fn,
conf.train,
+ rank,
pbar=(rank == -1),
)

@@ -535,13 +556,9 @@ def trace_handler(p):
if isinstance(v, float)
]
logger.info(f'[Validation] {{{", ".join(str_results)}}}')
- for k, v in results.items():
- if isinstance(v, dict):
- writer.add_scalars(f"figure/val/{k}", v, tot_n_samples)
- else:
- writer.add_scalar("val/" + k, v, tot_n_samples)
- for k, v in pr_metrics.items():
- writer.add_pr_curve("val/" + k, *v, tot_n_samples)
+ write_dict_summaries(writer, "val", results, tot_n_samples)
+ write_dict_summaries(writer, "val", pr_metrics, tot_n_samples)
+ write_image_summaries(writer, "figures", figures, tot_n_samples)
# @TODO: optional always save checkpoint
if results[conf.train.best_key] < best_eval:
best_eval = results[conf.train.best_key]
@@ -560,13 +577,8 @@ def trace_handler(p):
cp_name="checkpoint_best.tar",
)
logger.info(f"New best val: {conf.train.best_key}={best_eval}")
- if len(figures) > 0:
- for i, figs in enumerate(figures):
- for name, fig in figs.items():
- writer.add_figure(
- f"figures/{i}_{name}", fig, tot_n_samples
- )
torch.cuda.empty_cache() # should be cleared at the first iter
+ del results, pr_metrics, figures

if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
if results is None:
@@ -576,6 +588,7 @@ def trace_handler(p):
device,
loss_fn,
conf.train,
+ rank,
pbar=(rank == -1),
)
best_eval = results[conf.train.best_key]

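For reference, below is a minimal sketch of how the two new helpers are exercised once evaluation results come back, using a TensorBoard SummaryWriter. The helper bodies follow the diff above; the log directory, the scalar values, the PR-curve tensors, and the matplotlib figure are made-up stand-ins for what do_evaluation and run_benchmark return in the real training loop.

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import torch
from torch.utils.tensorboard import SummaryWriter


def write_dict_summaries(writer, name, items, step):
    # One entry point for plain scalars, grouped scalars (dict), and PR curves (tuple).
    for k, v in items.items():
        key = f"{name}/{k}"
        if isinstance(v, dict):
            writer.add_scalars(key, v, step)
        elif isinstance(v, tuple):
            writer.add_pr_curve(key, *v, step)
        else:
            writer.add_scalar(key, v, step)


def write_image_summaries(writer, name, figures, step):
    # Validation passes a list of per-batch figure dicts, benchmarks a flat dict.
    if isinstance(figures, list):
        for i, figs in enumerate(figures):
            for k, fig in figs.items():
                writer.add_figure(f"{name}/{i}_{k}", fig, step)
    else:
        for k, fig in figures.items():
            writer.add_figure(f"{name}/{k}", fig, step)


rank, step = 0, 100  # figures are generated and logged on the main process only
writer = SummaryWriter("/tmp/gluefactory_summary_demo")  # hypothetical log directory

results = {"loss/total": 0.42, "match_recall": 0.87}  # made-up scalar metrics
pr_metrics = {"match_pr": (torch.randint(0, 2, (64,)), torch.rand(64))}  # (labels, scores)
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
figures = [{"matches": fig}]  # one figure dict per plotted validation batch

if rank == 0:
    write_dict_summaries(writer, "val", results, step)
    write_dict_summaries(writer, "val", pr_metrics, step)
    write_image_summaries(writer, "figures", figures, step)
writer.close()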