
Commit

Fixes Checkpointing (#5220)
* Removes unused variable

* Formatting

* Make sure we always restore the model's weights properly

* Give TrainerCallbacks the ability to save and load state dicts

* Give MovingAverage the ability to save and load state dicts

* Do not set gradients to None

* Typo

* Remove unused variable

* Typo

* Entirely new checkpointing code

* Formatting

* Make mypy happy

lol

* Makes the no-op trainer work with the new checkpointer

* Mark epochs as completed when they're skipped

* Changelog

* Fixes how we get the best weights after a training run

* Mypy is annoying

* Callback fixes

* Fix the no op trainer

* Simplify

* Assorted checkpointer fixes

* Mypy is now happy

* Fixed all the tests except for one

* Removed unused variable

* Fix trainer restore logic

* Fix test for trainer restore logic

* Check the Checkpointing branch of the models repo

* Help mypy along

* Fixed finalizing logic

* More mypy stuff

* Update allennlp/training/checkpointer.py

Co-authored-by: Pete <[email protected]>

* Make weaker claims

Co-authored-by: Pete <[email protected]>
dirkgr and epwalsh authored May 29, 2021
1 parent 3d5799d commit c5bff8b
Showing 21 changed files with 1,455 additions and 1,393 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -156,6 +156,7 @@ jobs:
run: |
git clone https://github.com/allenai/allennlp-models.git
cd allennlp-models
git checkout Checkpointing
pip install --upgrade --upgrade-strategy eager -e . -r dev-requirements.txt
- name: Run models tests
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
an actual `torch.nn.Module`. Other parameters to this method have changed as well.
- Print the first batch to the console by default.
- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).
- Trainer callbacks can now store and restore state in case a training run gets interrupted.
- VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions.

### Added
@@ -41,6 +42,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids.
- Fixed documentation for `GradientDescentTrainer.cuda_device`.
- Re-starting a training run from a checkpoint in the middle of an epoch now works correctly.
- When using the "moving average" weights-smoothing feature of the trainer, training checkpoints would also get smoothed, which led to strange results when resuming a training job. This has been fixed.
- When re-starting an interrupted training job, the trainer will now read out the data loader even for epochs and batches that can be skipped. We do this to try to get any random number generators used by the reader or data loader into the same state they were in the first time the training job ran.
- Fixed the potential for a race condition with `cached_path()` when extracting archives, although the race
condition is still possible if used with `force_extract=True`.
- Fixed `wandb` callback to work in distributed training.
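The fix about re-starting interrupted jobs is easier to see with a small sketch. This is not the actual trainer code, just an illustration of the idea using a generic data loader and a hypothetical `train_one_batch` helper: on resume, batches from already-completed epochs are still drawn from the loader and discarded, so any random number generator inside the reader or loader advances exactly as it did in the original run.

def resume_epochs(data_loader, completed_epochs, num_epochs, train_one_batch):
    # Hypothetical helper, for illustration only.
    for epoch in range(num_epochs):
        for batch in data_loader:
            if epoch < completed_epochs:
                # Skip the work, but keep consuming batches so the loader's
                # RNG ends up in the same state it had during the first run.
                continue
            train_one_batch(batch)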
21 changes: 16 additions & 5 deletions allennlp/commands/train.py
@@ -471,11 +471,22 @@ def _train_worker(
except KeyboardInterrupt:
# if we have completed an epoch, try to create a model archive.
if primary and os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
logging.info(
"Training interrupted by the user. Attempting to create "
"a model archive using the current best epoch weights."
)
archive_model(serialization_dir, include_in_archive=include_in_archive)
best_weights_path = train_loop.trainer.get_best_weights_path()
if best_weights_path is None:
logging.info(
"Training interrupted by the user, and no best model has been saved. "
"No model archive created."
)
else:
logging.info(
"Training interrupted by the user. Attempting to create "
"a model archive using the current best epoch weights."
)
archive_model(
serialization_dir,
weights=best_weights_path,
include_in_archive=include_in_archive,
)
raise

if primary:
7 changes: 6 additions & 1 deletion allennlp/models/archival.py
@@ -2,6 +2,7 @@
Helper functions for archiving models and restoring archived models.
"""
from os import PathLike
from pathlib import Path
from typing import Tuple, NamedTuple, Union, Dict, Any, List, Optional
import logging
import os
@@ -130,7 +131,11 @@ def archive_model(
include_in_archive : `List[str]`, optional, (default = `None`)
Paths relative to `serialization_dir` that should be archived in addition to the default ones.
"""
weights_file = os.path.join(serialization_dir, weights)
extra_copy_of_weights_just_for_mypy = Path(weights)
if extra_copy_of_weights_just_for_mypy.is_absolute():
weights_file = extra_copy_of_weights_just_for_mypy
else:
weights_file = Path(serialization_dir) / extra_copy_of_weights_just_for_mypy
if not os.path.exists(weights_file):
logger.error("weights file %s does not exist, unable to archive model", weights_file)
return
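Taken together with the `train.py` change above, this lets callers pass either a relative or an absolute weights path to `archive_model`. A rough usage sketch; the paths here are made up for illustration:

from allennlp.models.archival import archive_model

serialization_dir = "/tmp/my_run"  # hypothetical run directory

# A relative weights path is resolved against serialization_dir.
archive_model(serialization_dir, weights="best.th")

# An absolute path, such as one returned by trainer.get_best_weights_path(),
# is used as-is.
archive_model(serialization_dir, weights="/tmp/my_run/best.th")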
6 changes: 2 additions & 4 deletions allennlp/training/__init__.py
@@ -1,7 +1,5 @@
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.no_op_trainer import NoOpTrainer
from allennlp.training.callbacks import TrainerCallback
from allennlp.training.trainer import (
Trainer,
GradientDescentTrainer,
)
from allennlp.training.trainer import Trainer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer
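This split moves `GradientDescentTrainer` into its own `gradient_descent_trainer` module while the package `__init__` continues to re-export it, so, assuming no other imports change, both of the following should resolve to the same class:

from allennlp.training import GradientDescentTrainer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer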
8 changes: 7 additions & 1 deletion allennlp/training/callbacks/callback.py
@@ -5,7 +5,7 @@


if TYPE_CHECKING:
from allennlp.training.trainer import GradientDescentTrainer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


class TrainerCallback(Registrable):
@@ -77,5 +77,11 @@ def on_end(
"""
pass

def state_dict(self) -> Dict[str, Any]:
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
pass


TrainerCallback.register("null")(TrainerCallback)
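The new `state_dict` and `load_state_dict` hooks let a callback carry its own state across a checkpoint and restore. A minimal sketch, assuming the `on_epoch` signature and `serialization_dir` constructor argument of this callback API; the counter and the registered name are made up for illustration:

from typing import Any, Dict

from allennlp.training.callbacks.callback import TrainerCallback


@TrainerCallback.register("epoch_counter")  # hypothetical name
class EpochCounterCallback(TrainerCallback):
    def __init__(self, serialization_dir: str) -> None:
        super().__init__(serialization_dir)
        self.epochs_seen = 0

    def on_epoch(self, trainer, metrics, epoch, is_primary=True, **kwargs) -> None:
        self.epochs_seen += 1

    def state_dict(self) -> Dict[str, Any]:
        # Saved as part of the trainer's checkpoint.
        return {"epochs_seen": self.epochs_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restored when training resumes from that checkpoint.
        self.epochs_seen = state_dict["epochs_seen"]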
2 changes: 1 addition & 1 deletion allennlp/training/callbacks/confidence_checks.py
@@ -6,7 +6,7 @@


if TYPE_CHECKING:
from allennlp.training.trainer import GradientDescentTrainer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


# `sanity_checks` is deprecated and will be removed.
2 changes: 1 addition & 1 deletion allennlp/training/callbacks/console_logger.py
@@ -8,7 +8,7 @@
from allennlp.data import TensorDict

if TYPE_CHECKING:
from allennlp.training.trainer import GradientDescentTrainer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


logger = logging.getLogger(__name__)
10 changes: 6 additions & 4 deletions allennlp/training/callbacks/log_writer.py
@@ -10,7 +10,7 @@
from allennlp.training.util import get_train_and_validation_metrics, get_batch_size

if TYPE_CHECKING:
from allennlp.training.trainer import GradientDescentTrainer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


logger = logging.getLogger(__name__)
@@ -289,15 +289,17 @@ def log_epoch(
)

def _should_log_distributions_next_batch(self) -> bool:
assert self.trainer is not None
return (
self._distribution_interval is not None
and (self.trainer._batch_num_total + 1) % self._distribution_interval == 0 # type: ignore[union-attr]
and (self.trainer._total_batches_completed + 1) % self._distribution_interval == 0
)

def _should_log_distributions_this_batch(self) -> bool:
assert self.trainer is not None
return (
self._distribution_interval is not None
and self.trainer._batch_num_total % self._distribution_interval == 0 # type: ignore[union-attr]
and self.trainer._total_batches_completed % self._distribution_interval == 0
)

def _enable_activation_logging(self) -> None:
@@ -318,7 +320,7 @@ def hook(module_, inputs, outputs):
self._module_hook_handles.append(module.register_forward_hook(hook))

def _should_log_this_batch(self) -> bool:
return self.trainer._batch_num_total % self._summary_interval == 0 # type: ignore[union-attr]
return self.trainer._total_batches_completed % self._summary_interval == 0 # type: ignore[union-attr]

def _log_activation_distribution(self, outputs: Any, module_name: str) -> None:
activations_to_log: Dict[str, torch.Tensor] = {}
6 changes: 4 additions & 2 deletions allennlp/training/callbacks/tensorboard.py
@@ -49,7 +49,8 @@ def log_scalars(
log_prefix: str = "",
epoch: Optional[int] = None,
) -> None:
timestep = epoch if epoch is not None else self.trainer._batch_num_total # type: ignore[union-attr]
assert self.trainer is not None
timestep = epoch if epoch is not None else self.trainer._total_batches_completed
log = self._train_log if not log_prefix.startswith("validation") else self._validation_log
for key, value in scalars.items():
name = f"{log_prefix}/{key}" if log_prefix else key
@@ -59,7 +60,8 @@ def log_scalars(
def log_tensors(
self, tensors: Dict[str, torch.Tensor], log_prefix: str = "", epoch: Optional[int] = None
) -> None:
timestep = epoch if epoch is not None else self.trainer._batch_num_total # type: ignore[union-attr]
assert self.trainer is not None
timestep = epoch if epoch is not None else self.trainer._total_batches_completed
log = self._train_log if not log_prefix.startswith("validation") else self._validation_log
for key, values in tensors.items():
name = f"{log_prefix}/{key}" if log_prefix else key
2 changes: 1 addition & 1 deletion allennlp/training/callbacks/track_epoch.py
@@ -3,7 +3,7 @@
from allennlp.training.callbacks.callback import TrainerCallback

if TYPE_CHECKING:
from allennlp.training.trainer import GradientDescentTrainer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


@TrainerCallback.register("track_epoch_callback")
4 changes: 2 additions & 2 deletions allennlp/training/callbacks/wandb.py
@@ -11,7 +11,7 @@


if TYPE_CHECKING:
from allennlp.training.trainer import GradientDescentTrainer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


logger = logging.getLogger(__name__)
@@ -127,7 +127,7 @@ def _log(
dict_to_log = {f"{log_prefix}/{k}": v for k, v in dict_to_log.items()}
if epoch is not None:
dict_to_log["epoch"] = epoch
self.wandb.log(dict_to_log, step=self.trainer._batch_num_total) # type: ignore
self.wandb.log(dict_to_log, step=self.trainer._total_batches_completed) # type: ignore

@overrides
def on_start(