From ef5400d59bf9d8deb31dd346d41c210c79c00b78 Mon Sep 17 00:00:00 2001
From: Pete
Date: Mon, 19 Jul 2021 15:37:33 -0700
Subject: [PATCH] make W&B callback resumable (#5312)

Co-authored-by: Dirk Groeneveld
---
 CHANGELOG.md                         |  1 +
 allennlp/training/callbacks/wandb.py | 16 +++++++++++++++-
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bb8d28ba95f..98a979de9dd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed the docs for `PytorchTransformerWrapper`
 - Fixed recovering training jobs with models that expect `get_metrics()` to not be called until they have seen at least one batch.
 - Made the Transformer Toolkit compatible with transformers that don't start their positional embeddings at 0.
+- Weights & Biases training callback ("wandb") now works when resuming training jobs.
 
 ### Changed
 
diff --git a/allennlp/training/callbacks/wandb.py b/allennlp/training/callbacks/wandb.py
index b09301af62b..a1983d7bdcc 100644
--- a/allennlp/training/callbacks/wandb.py
+++ b/allennlp/training/callbacks/wandb.py
@@ -88,6 +88,7 @@ def __init__(
         self._watch_model = watch_model
         self._files_to_save = files_to_save
 
+        self._run_id: Optional[str] = None
         self._wandb_kwargs: Dict[str, Any] = dict(
             dir=os.path.abspath(serialization_dir),
             project=project,
@@ -141,7 +142,11 @@ def on_start(
         import wandb
 
         self.wandb = wandb
-        self.wandb.init(**self._wandb_kwargs)
+
+        if self._run_id is None:
+            self._run_id = self.wandb.util.generate_id()
+
+        self.wandb.init(id=self._run_id, **self._wandb_kwargs)
 
         for fpath in self._files_to_save:
             self.wandb.save(  # type: ignore
@@ -155,3 +160,12 @@ def close(self) -> None:
     def close(self) -> None:
         super().close()
         self.wandb.finish()  # type: ignore
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "run_id": self._run_id,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        self._wandb_kwargs["resume"] = "auto"
+        self._run_id = state_dict["run_id"]
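
Note: the resume pattern this patch implements can be sketched outside of AllenNLP. The snippet below is a minimal illustration, not part of the patch; it relies only on the public wandb calls that appear in the diff (wandb.util.generate_id(), wandb.init(id=..., resume="auto"), wandb.finish()), and a hypothetical checkpoint.json path stands in for the trainer state that state_dict()/load_state_dict() would round-trip.

    # Minimal sketch of the resume pattern in this patch (not AllenNLP code).
    # "checkpoint.json" and "my-project" are hypothetical stand-ins.
    import json
    import os

    import wandb

    CHECKPOINT = "checkpoint.json"  # hypothetical checkpoint path
    wandb_kwargs = {"project": "my-project"}  # hypothetical project name

    if os.path.exists(CHECKPOINT):
        # Resuming: restore the saved run id and ask wandb to resume that run,
        # mirroring load_state_dict() setting resume="auto".
        with open(CHECKPOINT) as f:
            run_id = json.load(f)["run_id"]
        wandb_kwargs["resume"] = "auto"
    else:
        # Fresh start: generate a stable id up front so it can be checkpointed,
        # mirroring the new logic in on_start().
        run_id = wandb.util.generate_id()

    wandb.init(id=run_id, **wandb_kwargs)

    # ... training loop would run here ...

    # Persist the run id so a restarted job logs to the same W&B run,
    # mirroring state_dict().
    with open(CHECKPOINT, "w") as f:
        json.dump({"run_id": run_id}, f)

    wandb.finish()

The key design point, visible in both the sketch and the patch, is that the run id is generated eagerly on the first start rather than left to wandb.init(): only then can it be saved in the training checkpoint and reused, so a recovered job resumes the original W&B run instead of creating a new one.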