diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a53b1b6e12b..e0e3494a954 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -156,6 +156,7 @@ jobs: run: | git clone https://github.com/allenai/allennlp-models.git cd allennlp-models + git checkout Checkpointing pip install --upgrade --upgrade-strategy eager -e . -r dev-requirements.txt - name: Run models tests diff --git a/CHANGELOG.md b/CHANGELOG.md index deab8ed058a..281391a0e14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 an actual `torch.nn.Module`. Other parameters to this method have changed as well. - Print the first batch to the console by default. - Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0). +- Trainer callbacks can now store and restore state in case a training run gets interrupted. - VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions. ### Added @@ -41,6 +42,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids. - Fixed documentation for `GradientDescentTrainer.cuda_device`. +- Re-starting a training run from a checkpoint in the middle of an epoch now works correctly. +- When using the "moving average" weights smoothing feature of the trainer, training checkpoints would also get smoothed, with strange results for resuming a training job. This has been fixed. +- When re-starting an interrupted training job, the trainer will now read out the data loader even for epochs and batches that can be skipped. We do this to try to get any random number generators used by the reader or data loader into the same state as they were the first time the training job ran. - Fixed the potential for a race condition with `cached_path()` when extracting archives. Although the race condition is still possible if used with `force_extract=True`. - Fixed `wandb` callback to work in distributed training. diff --git a/allennlp/commands/train.py b/allennlp/commands/train.py index 5304e6e0735..c5be97c990f 100644 --- a/allennlp/commands/train.py +++ b/allennlp/commands/train.py @@ -471,11 +471,22 @@ def _train_worker( except KeyboardInterrupt: # if we have completed an epoch, try to create a model archive. if primary and os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)): - logging.info( - "Training interrupted by the user. Attempting to create " - "a model archive using the current best epoch weights." - ) - archive_model(serialization_dir, include_in_archive=include_in_archive) + best_weights_path = train_loop.trainer.get_best_weights_path() + if best_weights_path is None: + logging.info( + "Training interrupted by the user, and no best model has been saved. " + "No model archive created." + ) + else: + logging.info( + "Training interrupted by the user. Attempting to create " + "a model archive using the current best epoch weights." + ) + archive_model( + serialization_dir, + weights=best_weights_path, + include_in_archive=include_in_archive, + ) raise if primary: diff --git a/allennlp/models/archival.py b/allennlp/models/archival.py index e1d48fcb76f..e49bd9dec6a 100644 --- a/allennlp/models/archival.py +++ b/allennlp/models/archival.py @@ -2,6 +2,7 @@ Helper functions for archiving models and restoring archived models. """ from os import PathLike +from pathlib import Path from typing import Tuple, NamedTuple, Union, Dict, Any, List, Optional import logging import os @@ -130,7 +131,11 @@ def archive_model( include_in_archive : `List[str]`, optional, (default = `None`) Paths relative to `serialization_dir` that should be archived in addition to the default ones. """ - weights_file = os.path.join(serialization_dir, weights) + extra_copy_of_weights_just_for_mypy = Path(weights) + if extra_copy_of_weights_just_for_mypy.is_absolute(): + weights_file = extra_copy_of_weights_just_for_mypy + else: + weights_file = Path(serialization_dir) / extra_copy_of_weights_just_for_mypy if not os.path.exists(weights_file): logger.error("weights file %s does not exist, unable to archive model", weights_file) return diff --git a/allennlp/training/__init__.py b/allennlp/training/__init__.py index c309005246c..cf95606636b 100644 --- a/allennlp/training/__init__.py +++ b/allennlp/training/__init__.py @@ -1,7 +1,5 @@ from allennlp.training.checkpointer import Checkpointer from allennlp.training.no_op_trainer import NoOpTrainer from allennlp.training.callbacks import TrainerCallback -from allennlp.training.trainer import ( - Trainer, - GradientDescentTrainer, -) +from allennlp.training.trainer import Trainer +from allennlp.training.gradient_descent_trainer import GradientDescentTrainer diff --git a/allennlp/training/callbacks/callback.py b/allennlp/training/callbacks/callback.py index 19c14cc0dc6..301e9cb4387 100644 --- a/allennlp/training/callbacks/callback.py +++ b/allennlp/training/callbacks/callback.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer class TrainerCallback(Registrable): @@ -77,5 +77,11 @@ def on_end( """ pass + def state_dict(self) -> Dict[str, Any]: + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + pass + TrainerCallback.register("null")(TrainerCallback) diff --git a/allennlp/training/callbacks/confidence_checks.py b/allennlp/training/callbacks/confidence_checks.py index e57a0a0a626..584dcd137b4 100644 --- a/allennlp/training/callbacks/confidence_checks.py +++ b/allennlp/training/callbacks/confidence_checks.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer # `sanity_checks` is deprecated and will be removed. diff --git a/allennlp/training/callbacks/console_logger.py b/allennlp/training/callbacks/console_logger.py index 68565ed247a..768ba0d5936 100644 --- a/allennlp/training/callbacks/console_logger.py +++ b/allennlp/training/callbacks/console_logger.py @@ -8,7 +8,7 @@ from allennlp.data import TensorDict if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer logger = logging.getLogger(__name__) diff --git a/allennlp/training/callbacks/log_writer.py b/allennlp/training/callbacks/log_writer.py index 253b35de3df..8b3183f28c3 100644 --- a/allennlp/training/callbacks/log_writer.py +++ b/allennlp/training/callbacks/log_writer.py @@ -10,7 +10,7 @@ from allennlp.training.util import get_train_and_validation_metrics, get_batch_size if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer logger = logging.getLogger(__name__) @@ -289,15 +289,17 @@ def log_epoch( ) def _should_log_distributions_next_batch(self) -> bool: + assert self.trainer is not None return ( self._distribution_interval is not None - and (self.trainer._batch_num_total + 1) % self._distribution_interval == 0 # type: ignore[union-attr] + and (self.trainer._total_batches_completed + 1) % self._distribution_interval == 0 ) def _should_log_distributions_this_batch(self) -> bool: + assert self.trainer is not None return ( self._distribution_interval is not None - and self.trainer._batch_num_total % self._distribution_interval == 0 # type: ignore[union-attr] + and self.trainer._total_batches_completed % self._distribution_interval == 0 ) def _enable_activation_logging(self) -> None: @@ -318,7 +320,7 @@ def hook(module_, inputs, outputs): self._module_hook_handles.append(module.register_forward_hook(hook)) def _should_log_this_batch(self) -> bool: - return self.trainer._batch_num_total % self._summary_interval == 0 # type: ignore[union-attr] + return self.trainer._total_batches_completed % self._summary_interval == 0 # type: ignore[union-attr] def _log_activation_distribution(self, outputs: Any, module_name: str) -> None: activations_to_log: Dict[str, torch.Tensor] = {} diff --git a/allennlp/training/callbacks/tensorboard.py b/allennlp/training/callbacks/tensorboard.py index 0f6302dfcb4..73bc04a686a 100644 --- a/allennlp/training/callbacks/tensorboard.py +++ b/allennlp/training/callbacks/tensorboard.py @@ -49,7 +49,8 @@ def log_scalars( log_prefix: str = "", epoch: Optional[int] = None, ) -> None: - timestep = epoch if epoch is not None else self.trainer._batch_num_total # type: ignore[union-attr] + assert self.trainer is not None + timestep = epoch if epoch is not None else self.trainer._total_batches_completed log = self._train_log if not log_prefix.startswith("validation") else self._validation_log for key, value in scalars.items(): name = f"{log_prefix}/{key}" if log_prefix else key @@ -59,7 +60,8 @@ def log_scalars( def log_tensors( self, tensors: Dict[str, torch.Tensor], log_prefix: str = "", epoch: Optional[int] = None ) -> None: - timestep = epoch if epoch is not None else self.trainer._batch_num_total # type: ignore[union-attr] + assert self.trainer is not None + timestep = epoch if epoch is not None else self.trainer._total_batches_completed log = self._train_log if not log_prefix.startswith("validation") else self._validation_log for key, values in tensors.items(): name = f"{log_prefix}/{key}" if log_prefix else key diff --git a/allennlp/training/callbacks/track_epoch.py b/allennlp/training/callbacks/track_epoch.py index ea08459b390..b15da434248 100644 --- a/allennlp/training/callbacks/track_epoch.py +++ b/allennlp/training/callbacks/track_epoch.py @@ -3,7 +3,7 @@ from allennlp.training.callbacks.callback import TrainerCallback if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer @TrainerCallback.register("track_epoch_callback") diff --git a/allennlp/training/callbacks/wandb.py b/allennlp/training/callbacks/wandb.py index 5adc9f1520d..b09301af62b 100644 --- a/allennlp/training/callbacks/wandb.py +++ b/allennlp/training/callbacks/wandb.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer logger = logging.getLogger(__name__) @@ -127,7 +127,7 @@ def _log( dict_to_log = {f"{log_prefix}/{k}": v for k, v in dict_to_log.items()} if epoch is not None: dict_to_log["epoch"] = epoch - self.wandb.log(dict_to_log, step=self.trainer._batch_num_total) # type: ignore + self.wandb.log(dict_to_log, step=self.trainer._total_batches_completed) # type: ignore @overrides def on_start( diff --git a/allennlp/training/checkpointer.py b/allennlp/training/checkpointer.py index 95105a0820f..38d2692273d 100644 --- a/allennlp/training/checkpointer.py +++ b/allennlp/training/checkpointer.py @@ -1,5 +1,5 @@ import glob -from typing import Union, Dict, Any, List, Tuple, Optional +from typing import Dict, Any, Tuple, Optional, Set, Union import logging import os @@ -8,10 +8,9 @@ import torch -import allennlp from allennlp.common import Registrable from allennlp.nn import util as nn_util -from allennlp.training import util as training_util +from allennlp.training.trainer import Trainer logger = logging.getLogger(__name__) @@ -20,28 +19,26 @@ class Checkpointer(Registrable): """ This class implements the functionality for checkpointing your model and trainer state during training. It is agnostic as to what those states look like (they are typed as - Dict[str, Any]), but they will be fed to `torch.save` so they should be serializable - in that sense. They will also be restored as Dict[str, Any], which means the calling + `Dict[str, Any]`), but they will be fed to `torch.save` so they should be serializable + in that sense. They will also be restored as `Dict[str, Any]`, which means the calling code is responsible for knowing what to do with them. # Parameters - num_serialized_models_to_keep : `int`, optional (default=`2`) - Number of previous model checkpoints to retain. Default is to keep 2 checkpoints. - A value of None or -1 means all checkpoints will be kept. - - In a typical AllenNLP configuration file, this argument does not get an entry under the - "checkpointer", it gets passed in separately. - keep_serialized_model_every_num_seconds : `int`, optional (default=`None`) - If num_serialized_models_to_keep is not None, then occasionally it's useful to - save models at a given interval in addition to the last num_serialized_models_to_keep. - To do so, specify keep_serialized_model_every_num_seconds as the number of seconds - between permanently saved checkpoints. Note that this option is only used if - num_serialized_models_to_keep is not None, otherwise all checkpoints are kept. - model_save_interval : `float`, optional (default=`None`) - If provided, then serialize models every `model_save_interval` - seconds within single epochs. In all cases, models are also saved - at the end of every epoch if `serialization_dir` is provided. + save_completed_epochs : `bool`, (default=`True`) + Saves model and trainer state at the end of each completed epoch. + save_every_num_seconds : `int`, optional (default=`None`) + If set, makes sure we never go longer than this number of seconds between saving a model. + save_every_num_batches : `int`, optional (default=`None`) + If set, makes sure we never go longer than this number of batches between saving a model. + keep_most_recent_by_count : `int`, optional (default=`2`) + Sets the number of model checkpoints to keep on disk. If both `keep_most_recent_by_count` and + `keep_most_recent_by_age` are set, we'll keep checkpoints that satisfy either criterion. + If both are `None`, we keep all checkpoints. + keep_most_recent_by_age : `int`, optional (default=`None`) + Sets the number of seconds we'll keep a checkpoint before deleting it. If both + `keep_most_recent_by_count` and `keep_most_recent_by_age` are set, we'll keep checkpoints + that satisfy either criterion. If both are `None`, we keep all checkpoints. """ default_implementation = "default" @@ -49,183 +46,179 @@ class Checkpointer(Registrable): def __init__( self, serialization_dir: str, - keep_serialized_model_every_num_seconds: int = None, - num_serialized_models_to_keep: int = 2, - model_save_interval: float = None, + save_completed_epochs: bool = True, + save_every_num_seconds: Optional[float] = None, + save_every_num_batches: Optional[int] = None, + keep_most_recent_by_count: Optional[int] = 2, + keep_most_recent_by_age: Optional[int] = None, ) -> None: self._serialization_dir = serialization_dir - self._keep_serialized_model_every_num_seconds = keep_serialized_model_every_num_seconds - self._num_serialized_models_to_keep = num_serialized_models_to_keep - self._model_save_interval = model_save_interval - - self._last_permanent_saved_checkpoint_time = time.time() - self._serialized_paths: List[Tuple[float, str, str]] = [] + self._save_completed_epochs = save_completed_epochs + self._save_every_num_seconds = save_every_num_seconds + self._save_every_num_batches = save_every_num_batches + self._keep_most_recent_by_count = keep_most_recent_by_count + self._keep_most_recent_by_age = keep_most_recent_by_age self._last_save_time = time.time() + self._last_save_num_epochs_completed = 0 + self._last_save_num_batches_in_epoch_completed = 0 + + def _model_state_path(self, epochs_completed: int, batches_in_epoch_completed: int) -> str: + return os.path.join( + self._serialization_dir, + f"model_state_e{epochs_completed}_b{batches_in_epoch_completed}.th", + ) + + def _training_state_path(self, epochs_completed: int, batches_in_epoch_completed: int) -> str: + return os.path.join( + self._serialization_dir, + f"training_state_e{epochs_completed}_b{batches_in_epoch_completed}.th", + ) + + _model_state_file_re = re.compile(r"(.*/)?model_state_e(\d+)_b(\d+)\.th$") + _training_state_file_re = re.compile(r"(.*/)?training_state_e(\d+)_b(\d+)\.th$") + + @classmethod + def _parse_model_state_path(cls, path: Union[str, os.PathLike]) -> Optional[Tuple[int, int]]: + match = cls._model_state_file_re.match(str(path)) + if match is None: + return None + else: + try: + return int(match.group(2)), int(match.group(3)) + except ValueError: + return None + + @classmethod + def _parse_training_state_path(cls, path: Union[str, os.PathLike]) -> Optional[Tuple[int, int]]: + match = cls._training_state_file_re.match(str(path)) + if match is None: + return None + else: + try: + return int(match.group(2)), int(match.group(3)) + except ValueError: + return None + + def _find_all_checkpoints(self) -> Set[Tuple[int, int]]: + """Returns a set of integers, each of which is a number of batches that were completed at the + time a checkpoint wsa saved.""" + checkpoints = set() + for model_state_file in glob.iglob( + os.path.join(self._serialization_dir, "model_state_e*_b*.th") + ): + point_in_time = self._parse_model_state_path(model_state_file) + if point_in_time is None: + continue + else: + checkpoints.add(point_in_time) + return checkpoints def maybe_save_checkpoint( - self, trainer: "allennlp.training.trainer.Trainer", epoch: int, batches_this_epoch: int + self, + trainer: Trainer, + num_epochs_completed: int, + num_batches_in_epoch_completed: int, ) -> None: """ - Given amount of time lapsed between the last save and now (tracked internally), the - current epoch, and the number of batches seen so far this epoch, this method decides whether - to save a checkpoint or not. If we decide to save a checkpoint, we grab whatever state we - need out of the `Trainer` and save it. - - This function is intended to be called at the end of each batch in an epoch (perhaps because - your data is large enough that you don't really have "epochs"). The default implementation - only looks at time, not batch or epoch number, though those parameters are available to you - if you want to customize the behavior of this function. + Figures out whether we need to save a checkpoint, and does so if necessary. """ - if self._model_save_interval is None: - return - if time.time() - self._last_save_time < self._model_save_interval: - return - - self._last_save_time = time.time() - epoch_str = f"{epoch}.{training_util.time_to_str(int(self._last_save_time))}" - self.save_checkpoint(epoch_str, trainer) - - def shelve_model(self, epoch: Union[int, str], trainer: "allennlp.training.trainer.Trainer"): - if self._serialization_dir is None: - return - - # back up the model - with trainer.get_checkpoint_state() as state: - model_state, _ = state - model_backup_path = os.path.join( - self._serialization_dir, "model_state_backup_epoch_{}.th".format(epoch) + end_of_epoch = num_batches_in_epoch_completed == 0 + if num_epochs_completed == self._last_save_num_epochs_completed: + last_save_num_batches_in_epoch_completed = ( + self._last_save_num_batches_in_epoch_completed ) - torch.save(model_state, model_backup_path) + else: + last_save_num_batches_in_epoch_completed = 0 - def remove_shelved_models(self): - if self._serialization_dir is None: - return + should_save = ( + (end_of_epoch and self._save_completed_epochs) + or ( + self._save_every_num_seconds is not None + and (time.time() - self._last_save_time >= self._save_every_num_seconds) + ) + or ( + self._save_every_num_batches is not None + and ( + num_batches_in_epoch_completed - last_save_num_batches_in_epoch_completed + >= self._save_every_num_batches + ) + ) + ) - for old_model_backup_path in glob.glob( - os.path.join(self._serialization_dir, "model_state_backup_epoch_*.th") - ): - os.remove(old_model_backup_path) + if should_save: + self.save_checkpoint(trainer) def save_checkpoint( self, - epoch: Union[int, str], - trainer: "allennlp.training.trainer.Trainer", - is_best_so_far: bool = False, + trainer: Trainer, ) -> None: if self._serialization_dir is None: return - with trainer.get_checkpoint_state() as state: - model_state, training_states = state - model_path = os.path.join( - self._serialization_dir, "model_state_epoch_{}.th".format(epoch) - ) - if not os.path.isfile(model_path): - model_backup_path = os.path.join( - self._serialization_dir, "model_state_backup_epoch_{}.th".format(epoch) - ) - if os.path.isfile(model_backup_path): - os.rename(model_backup_path, model_path) - else: - torch.save(model_state, model_path) + tcps = trainer.get_checkpoint_state() + epochs_completed = tcps.trainer_state["epochs_completed"] + batches_in_epoch_completed = tcps.trainer_state["batches_in_epoch_completed"] - training_path = os.path.join( - self._serialization_dir, "training_state_epoch_{}.th".format(epoch) - ) - if not os.path.isfile(training_path): - torch.save({**training_states, "epoch": epoch}, training_path) + model_state_path = self._model_state_path(epochs_completed, batches_in_epoch_completed) + if not os.path.isfile(model_state_path): + torch.save(tcps.model_state, model_state_path) - # The main checkpointing logic is now done, this is just shuffling files around, to keep - # track of best weights, and to remove old checkpoints, if desired. - self.remove_shelved_models() + trainer_state_path = self._training_state_path(epochs_completed, batches_in_epoch_completed) + if not os.path.isfile(trainer_state_path): + torch.save(tcps.trainer_state, trainer_state_path) - if is_best_so_far: - logger.info( - "Best validation performance so far. Copying weights to '%s/best.th'.", - self._serialization_dir, - ) - dest_path = os.path.join(self._serialization_dir, "best.th") - if os.path.exists(dest_path): - os.remove(dest_path) - os.link(model_path, dest_path) - - if ( - self._num_serialized_models_to_keep is not None - and self._num_serialized_models_to_keep >= 0 - ): - self._serialized_paths.append((time.time(), model_path, training_path)) - if len(self._serialized_paths) > self._num_serialized_models_to_keep: - paths_to_remove = self._serialized_paths.pop(0) - # Check to see if we should keep this checkpoint, if it has been longer - # then self._keep_serialized_model_every_num_seconds since the last - # kept checkpoint. - remove_path = True - if self._keep_serialized_model_every_num_seconds is not None: - save_time = paths_to_remove[0] - time_since_checkpoint_kept = ( - save_time - self._last_permanent_saved_checkpoint_time + self._last_save_time = time.time() + self._last_save_num_epochs_completed = epochs_completed + self._last_save_num_batches_in_epoch_completed = batches_in_epoch_completed + + if self._keep_most_recent_by_age is not None or self._keep_most_recent_by_count is not None: + checkpoints = list(self._find_all_checkpoints()) + checkpoints.sort(reverse=True) + + # Keep the most recent n checkpoints + if self._keep_most_recent_by_count is not None: + checkpoints_to_keep = set(checkpoints[: self._keep_most_recent_by_count]) + else: + checkpoints_to_keep = set() + + # Keep the youngest checkpoints by age + now = time.time() + if self._keep_most_recent_by_age is not None: + for checkpoint in checkpoints: + checkpoint_mtime = max( + os.path.getmtime(n) + for n in [ + self._model_state_path(*checkpoint), + self._training_state_path(*checkpoint), + ] ) - if time_since_checkpoint_kept > self._keep_serialized_model_every_num_seconds: - # We want to keep this checkpoint. - remove_path = False - self._last_permanent_saved_checkpoint_time = save_time - if remove_path: - for fname in paths_to_remove[1:]: - if os.path.isfile(fname): - os.remove(fname) - - def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]: + if now - checkpoint_mtime <= self._keep_most_recent_by_age: + checkpoints_to_keep.add(checkpoint) + + # Remove everything we're not keeping + for checkpoint in checkpoints: + if checkpoint not in checkpoints_to_keep: + os.remove(self._model_state_path(*checkpoint)) + os.remove(self._training_state_path(*checkpoint)) + + def _find_latest_checkpoint(self) -> Optional[Tuple[str, str]]: """ Return the location of the latest model and training state files. If there isn't a valid checkpoint then return None. """ - have_checkpoint = self._serialization_dir is not None and any( - "model_state_epoch_" in x for x in os.listdir(self._serialization_dir) - ) - - if not have_checkpoint: + checkpoints = self._find_all_checkpoints() + if len(checkpoints) <= 0: return None + last_checkpoint = max(checkpoints) + return self._model_state_path(*last_checkpoint), self._training_state_path(*last_checkpoint) - serialization_files = os.listdir(self._serialization_dir) - model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x] - # Get the last checkpoint file. Epochs are specified as either an - # int (for end of epoch files) or with epoch and timestamp for - # within epoch checkpoints, e.g. 5.2018-02-02-15-33-42 - found_epochs = [ - re.search(r"model_state_epoch_([0-9\.\-]+)\.th", x).group(1) for x in model_checkpoints # type: ignore - ] - int_epochs: Any = [] - for epoch in found_epochs: - pieces = epoch.split(".") - if len(pieces) == 1: - # Just a single epoch without timestamp - int_epochs.append([int(pieces[0]), "0"]) - else: - # has a timestamp - int_epochs.append([int(pieces[0]), pieces[1]]) - last_epoch = sorted(int_epochs, reverse=True)[0] - if last_epoch[1] == "0": - epoch_to_load = str(last_epoch[0]) - else: - epoch_to_load = "{0}.{1}".format(last_epoch[0], last_epoch[1]) - - model_path = os.path.join( - self._serialization_dir, "model_state_epoch_{}.th".format(epoch_to_load) - ) - training_state_path = os.path.join( - self._serialization_dir, "training_state_epoch_{}.th".format(epoch_to_load) - ) - - return (model_path, training_state_path) - - def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def load_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ - Restores a model from a serialization_dir to the last saved checkpoint. - This includes a training state (typically consisting of an epoch count and optimizer state), - which is serialized separately from model parameters. This function should only be used to - continue training - if you wish to load a model for inference/load parts of a model into a new - computation graph, you should use the native Pytorch functions: - ` model.load_state_dict(torch.load("/path/to/model/weights.th"))` + Loads model state from a `serialization_dir` corresponding to the last saved checkpoint. + This includes a training state, which is serialized separately from model parameters. This function + should only be used to continue training - if you wish to load a model for inference/load parts + of a model into a new computation graph, you should use the native Pytorch functions: + `model.load_state_dict(torch.load("/path/to/model/weights.th"))` If `self._serialization_dir` does not exist or does not contain any checkpointed weights, this function will do nothing and return empty dicts. @@ -235,12 +228,9 @@ def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: states : `Tuple[Dict[str, Any], Dict[str, Any]]` The model state and the training state. """ - latest_checkpoint = self.find_latest_checkpoint() - + latest_checkpoint = self._find_latest_checkpoint() if latest_checkpoint is None: - # No checkpoint to restore, start at 0 return {}, {} - model_path, training_state_path = latest_checkpoint # Load the parameters onto CPU, then transfer to GPU. @@ -251,17 +241,5 @@ def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: training_state = torch.load(training_state_path, map_location=nn_util.device_mapping(-1)) return model_state, training_state - def best_model_state(self) -> Dict[str, Any]: - if self._serialization_dir: - logger.info("loading best weights") - best_model_state_path = os.path.join(self._serialization_dir, "best.th") - return torch.load(best_model_state_path, map_location=nn_util.device_mapping(-1)) - else: - logger.info( - "cannot load best weights without `serialization_dir`, " - "so you're just getting the last weights" - ) - return {} - Checkpointer.register("default")(Checkpointer) diff --git a/allennlp/training/gradient_descent_trainer.py b/allennlp/training/gradient_descent_trainer.py new file mode 100644 index 00000000000..0e3f3cb0816 --- /dev/null +++ b/allennlp/training/gradient_descent_trainer.py @@ -0,0 +1,1072 @@ +import datetime +import logging +import math +import os +import re +import time +import warnings +from typing import Optional, Union, List, Dict, Tuple, Any, Type + +import torch +from torch.cuda import amp +from torch.nn.parallel import DistributedDataParallel +from torch.nn.utils import clip_grad_norm_ +import torch.distributed as dist + +from allennlp.common.checks import ConfigurationError, check_for_gpu +from allennlp.common import util as common_util, Tqdm, Lazy +from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict +from allennlp.models.model import Model +from allennlp.training.callbacks import ConsoleLoggerCallback +from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback +from allennlp.training.checkpointer import Checkpointer +from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler +from allennlp.training.metric_tracker import MetricTracker +from allennlp.training.momentum_schedulers.momentum_scheduler import MomentumScheduler +from allennlp.training.moving_average import MovingAverage +from allennlp.training.optimizers import Optimizer +from allennlp.training.trainer import Trainer, TrainerCheckpoint +from allennlp.training.callbacks import TrainerCallback +from allennlp.training import util as training_util + +logger = logging.getLogger(__name__) + + +@Trainer.register("gradient_descent", constructor="from_partial_objects") +class GradientDescentTrainer(Trainer): + """ + A trainer for doing supervised learning with gradient descent. It just takes a labeled dataset + and a `DataLoader`, and uses the supplied `Optimizer` to learn the weights for your model over + some fixed number of epochs. You can also pass in a validation data_loader and enable early + stopping. There are many other bells and whistles as well. + + Registered as a `Trainer` with the name "gradient_descent" (and is also the default `Trainer`). + The constructor that is registered is [`from_partial_objects`](#from_partial_objects) - + see the arguments to that function for the exact keys that should be used, if you are using + a configuration file. They largely match the arguments to `__init__`, and we don't repeat their + docstrings in `from_partial_objects`. + + [0]: https://tinyurl.com/y5mv44fw + + # Parameters + + model : `Model`, required. + An AllenNLP model to be optimized. Pytorch Modules can also be optimized if + their `forward` method returns a dictionary with a "loss" key, containing a + scalar tensor representing the loss function to be optimized. + + If you are training your model using GPUs, your model should already be + on the correct device. (If you are using our `train` command this will be + handled for you.) + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + optimizer : `torch.nn.Optimizer`, required. + An instance of a Pytorch Optimizer, instantiated with the parameters of the + model to be optimized. + + data_loader : `DataLoader`, required. + A `DataLoader` containing your `Dataset`, yielding padded indexed batches. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + patience : `Optional[int] > 0`, optional (default=`None`) + Number of epochs to be patient before early stopping: the training is stopped + after `patience` epochs with no improvement. If given, it must be `> 0`. + If None, early stopping is disabled. + + validation_metric : `Union[str, List[str]]`, optional (default=`"-loss"`) + Validation metric to measure for whether to stop training using patience + and whether to serialize an `is_best` model each epoch. The metric name + must be prepended with either "+" or "-", which specifies whether the metric + is an increasing or decreasing function. If you specify more than one metric, + the metrics will be summed to make the `is_best` decision. + + validation_data_loader : `DataLoader`, optional (default=`None`) + A `DataLoader` to use for the validation set. If `None`, then + use the training `DataLoader` with the validation data. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + num_epochs : `int`, optional (default = `20`) + Number of training epochs. + + serialization_dir : `str`, optional (default=`None`) + Path to directory for saving and loading model files. Models will not be saved if + this parameter is not passed. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + checkpointer : `Checkpointer`, optional (default=`None`) + A `Checkpointer` is responsible for periodically saving model weights. If none is given + here, we will construct one with default parameters. + + cuda_device : `Optional[Union[int, torch.device]]`, optional (default = `None`) + An integer or `torch.device` specifying the CUDA device to use for this process. + If -1, the CPU is used. If `None` and you have a GPU available, that GPU will be used. + + !!! Note + If you *don't* intend to use a GPU, but you have one available, you'll need + to explicitly set `cuda_device=-1`. + + !!! Note + If you intend to use a GPU, your model already needs to be on the correct device, + which you can do with `model = model.cuda()`. + + !!! Note + Data parallelism is controlled at the allennlp train level, so each trainer will have a single GPU. + + grad_norm : `float`, optional, (default = `None`). + If provided, gradient norms will be rescaled to have a maximum of this value. + + grad_clipping : `float`, optional (default = `None`). + If provided, gradients will be clipped `during the backward pass` to have an (absolute) + maximum of this value. If you are getting `NaNs` in your gradients during training + that are not solved by using `grad_norm`, you may need this. + + learning_rate_scheduler : `LearningRateScheduler`, optional (default = `None`) + If specified, the learning rate will be decayed with respect to + this schedule at the end of each epoch (or batch, if the scheduler implements + the `step_batch` method). If you use `torch.optim.lr_scheduler.ReduceLROnPlateau`, + this will use the `validation_metric` provided to determine if learning has plateaued. + To support updating the learning rate on every batch, this can optionally implement + `step_batch(batch_num_total)` which updates the learning rate given the batch number. + + momentum_scheduler : `MomentumScheduler`, optional (default = `None`) + If specified, the momentum will be updated at the end of each batch or epoch + according to the schedule. + + moving_average : `MovingAverage`, optional, (default = `None`) + If provided, we will maintain moving averages for all parameters. During training, we + employ a shadow variable for each parameter, which maintains the moving average. During + evaluation, we backup the original parameters and assign the moving averages to corresponding + parameters. Be careful that when saving the checkpoint, we will save the moving averages of + parameters. This is necessary because we want the saved model to perform as well as the validated + model if we load it later. But this may cause problems if you restart the training from checkpoint. + + callbacks : `List[Lazy[TrainerCallback]]`, optional (default = `None`) + A list of callbacks that can be called at certain events: e.g. each batch, epoch, and at the start + and end of training, etc. + + distributed : `bool`, optional, (default = `False`) + If set, PyTorch's `DistributedDataParallel` is used to train the model in multiple GPUs. This also + requires `world_size` to be greater than 1. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately (you need a top-level "distributed" key, next to + the "trainer" entry, that specifies a list of "cuda_devices"). + + local_rank : `int`, optional, (default = `0`) + This is the unique identifier of the `Trainer` in a distributed process group. The GPU device id is + used as the rank. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + world_size : `int`, (default = `1`) + The number of `Trainer` workers participating in the distributed training. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + num_gradient_accumulation_steps : `int`, optional, (default = `1`) + Gradients are accumulated for the given number of steps before doing an optimizer step. This can + be useful to accommodate batches that are larger than the RAM size. Refer [Thomas Wolf's + post][0] for details on Gradient Accumulation. + + use_amp : `bool`, optional, (default = `False`) + If `True`, we'll train using [Automatic Mixed Precision](https://pytorch.org/docs/stable/amp.html). + + enable_default_callbacks : `bool`, optional (default = `True`) + When `True`, the [`DEFAULT_CALLBACKS`](#default_callbacks) will be used in + addition to any other callbacks listed in the `callbacks` parameter. + When set to `False`, `DEFAULT_CALLBACKS` are not used. + + run_confidence_checks : `bool`, optional (default = `True`) + Determines whether model confidence checks, such as + [`NormalizationBiasVerification`](../../confidence_checks/normalization_bias_verification/), + are run. + + run_sanity_checks : `bool`, optional (default = `True`) + This parameter is deprecated. Please use `run_confidence_checks` instead. + + """ + + def __init__( + self, + model: Model, + optimizer: torch.optim.Optimizer, + data_loader: DataLoader, + patience: Optional[int] = None, + validation_metric: Union[str, List[str]] = "-loss", + validation_data_loader: DataLoader = None, + num_epochs: int = 20, + serialization_dir: Optional[str] = None, + checkpointer: Checkpointer = None, + cuda_device: Optional[Union[int, torch.device]] = None, + grad_norm: Optional[float] = None, + grad_clipping: Optional[float] = None, + learning_rate_scheduler: Optional[LearningRateScheduler] = None, + momentum_scheduler: Optional[MomentumScheduler] = None, + moving_average: Optional[MovingAverage] = None, + callbacks: List[TrainerCallback] = None, + distributed: bool = False, + local_rank: int = 0, + world_size: int = 1, + num_gradient_accumulation_steps: int = 1, + use_amp: bool = False, + enable_default_callbacks: bool = True, + run_confidence_checks: bool = True, + **kwargs, + ) -> None: + super().__init__( + serialization_dir=serialization_dir, + cuda_device=cuda_device, + distributed=distributed, + local_rank=local_rank, + world_size=world_size, + ) + + if "run_sanity_checks" in kwargs: + warnings.warn( + "'run_sanity_checks' is deprecated, please use 'run_confidence_checks' instead.", + DeprecationWarning, + ) + run_confidence_checks = kwargs["run_sanity_checks"] + + # I am not calling move_to_gpu here, because if the model is + # not already on the GPU then the optimizer is going to be wrong. + self.model = model + + self.data_loader = data_loader + self.data_loader.set_target_device(self.cuda_device) + self._validation_data_loader = validation_data_loader + if self._validation_data_loader is not None: + self._validation_data_loader.set_target_device(self.cuda_device) + self.optimizer = optimizer + + if patience is None: # no early stopping + if validation_data_loader is not None: + logger.warning( + "You provided a validation dataset but patience was set to None, " + "meaning that early stopping is disabled" + ) + elif (not isinstance(patience, int)) or patience <= 0: + raise ConfigurationError( + '{} is an invalid value for "patience": it must be a positive integer ' + "or None (if you want to disable early stopping)".format(patience) + ) + + # For tracking is_best_so_far and should_stop_early + self._metric_tracker = MetricTracker(validation_metric, patience) + + self._num_epochs = num_epochs + + self._checkpointer: Optional[Checkpointer] = checkpointer + if checkpointer is None and serialization_dir is not None: + self._checkpointer = Checkpointer(serialization_dir) + + self._grad_norm = grad_norm + self._grad_clipping = grad_clipping + + self._learning_rate_scheduler = learning_rate_scheduler + self._momentum_scheduler = momentum_scheduler + self._moving_average = moving_average + + self._callbacks = callbacks or [] + default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else [] + + if run_confidence_checks: + default_callbacks.append(ConfidenceChecksCallback) + for callback_cls in default_callbacks: + for callback in self._callbacks: + if callback.__class__ == callback_cls: + break + else: + self._callbacks.append(callback_cls(self._serialization_dir)) + + self._num_gradient_accumulation_steps = num_gradient_accumulation_steps + + # Enable automatic mixed precision training. + self._scaler: Optional[amp.GradScaler] = None + self._use_amp = use_amp + if self._use_amp: + if self.cuda_device == torch.device("cpu"): + raise ValueError("Using AMP requires a cuda device") + self._scaler = amp.GradScaler() + + # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its + # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model` + # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc. + # + # Hence a reference to Pytorch's object is maintained in the case of distributed training and in the + # normal case, reference to `Model` is retained. This reference is only used in + # these places: `model.__call__`, `model.train` and `model.eval`. + if self._distributed: + self._pytorch_model = DistributedDataParallel( + self.model, + device_ids=None if self.cuda_device == torch.device("cpu") else [self.cuda_device], + find_unused_parameters=True, + ) + else: + self._pytorch_model = self.model + + # training state management + self._epochs_completed: int = 0 + self._start_after_epochs_completed: int = 0 + self._batches_in_epoch_completed: int = 0 + self._start_after_batches_in_epoch_completed: int = 0 + self._best_model_filename: Optional[str] = None + + # This is a kind of training state, but it is not serialized with the trainer state, because we can + # re-create it with `epochs_completed` and `batches_in_epoch_completed`. + self._total_batches_completed: int = 0 + + def rescale_gradients(self) -> float: + """ + Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled. + + Returns the norm of the gradients. + """ + parameters_to_clip = [p for p in self.model.parameters() if p.grad is not None] + if self._grad_norm: + if self._scaler is not None: + # Need to first unscale gradients in order to clip as usual. + self._scaler.unscale_(self.optimizer) + return clip_grad_norm_(parameters_to_clip, self._grad_norm) + else: + return torch.norm( + torch.stack([torch.norm(p.grad.detach()) for p in parameters_to_clip]) + ) + + def batch_outputs(self, batch: TensorDict, for_training: bool) -> Dict[str, torch.Tensor]: + """ + Does a forward pass on the given batch and returns the output dictionary that the model + returns, after adding any specified regularization penalty to the loss (if training). + """ + output_dict = self._pytorch_model(**batch) + + if for_training: + try: + assert "loss" in output_dict + regularization_penalty = self.model.get_regularization_penalty() + + if regularization_penalty is not None: + output_dict["reg_loss"] = regularization_penalty + output_dict["loss"] += regularization_penalty + + except AssertionError: + if for_training: + raise RuntimeError( + "The model you are trying to optimize does not contain a" + " 'loss' key in the output of model.forward(inputs)." + ) + + return output_dict + + def _train_epoch(self, epoch: int) -> Dict[str, float]: + """ + Trains one epoch and returns metrics. + """ + logger.info("Epoch %d/%d", epoch, self._num_epochs - 1) + cpu_memory_usage = [] + for worker, memory in common_util.peak_cpu_memory().items(): + cpu_memory_usage.append((worker, memory)) + logger.info(f"Worker {worker} memory usage: {common_util.format_size(memory)}") + gpu_memory_usage = [] + for gpu, memory in common_util.peak_gpu_memory().items(): + gpu_memory_usage.append((gpu, memory)) + logger.info(f"GPU {gpu} memory usage: {common_util.format_size(memory)}") + + regularization_penalty = self.model.get_regularization_penalty() + + train_loss = 0.0 + train_reg_loss = None if regularization_penalty is None else 0.0 + batch_reg_loss = None if regularization_penalty is None else 0.0 + + # Set the model to "train" mode. + self._pytorch_model.train() + + # Get tqdm for the training batches + batch_generator = iter(self.data_loader) + batch_group_generator = common_util.lazy_groups_of( + batch_generator, self._num_gradient_accumulation_steps + ) + + logger.info("Training") + + num_training_batches: Union[int, float] + try: + len_data_loader = len(self.data_loader) + num_training_batches = math.ceil( + len_data_loader / self._num_gradient_accumulation_steps + ) + except TypeError: + num_training_batches = float("inf") + + # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the primary's + # progress is shown + if self._primary: + batch_group_generator_tqdm = Tqdm.tqdm( + batch_group_generator, total=num_training_batches + ) + else: + batch_group_generator_tqdm = batch_group_generator + + done_early = False + for batch_group in batch_group_generator_tqdm: + if done_early: + break + + if self._epochs_completed < self._start_after_epochs_completed or ( + self._epochs_completed == self._start_after_epochs_completed + and self._batches_in_epoch_completed < self._start_after_batches_in_epoch_completed + ): + self._batches_in_epoch_completed += 1 + self._total_batches_completed += 1 + continue + + self.optimizer.zero_grad() + + batch_loss = 0.0 + batch_group_outputs = [] + for batch in batch_group: + if self._distributed: + # Check whether the other workers have stopped already (due to differing amounts of + # data in each). If so, we can't proceed because we would hang when we hit the + # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor + # here because NCCL process groups apparently don't support BoolTensor. + done = torch.tensor(0, device=self.cuda_device) + torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) + if done.item() > 0: + done_early = True + logger.warning( + f"Worker {torch.distributed.get_rank()} finishing training early! " + "This implies that there is an imbalance in your training " + "data across the workers and that some amount of it will be " + "ignored. A small amount of this is fine, but a major imbalance " + "should be avoided. Note: This warning will appear unless your " + "data is perfectly balanced." + ) + break + + with amp.autocast(self._use_amp): + batch_outputs = self.batch_outputs(batch, for_training=True) + batch_group_outputs.append(batch_outputs) + loss = batch_outputs["loss"] + reg_loss = batch_outputs.get("reg_loss") + if torch.isnan(loss): + raise ValueError("nan loss encountered") + loss = loss / len(batch_group) + + batch_loss += loss.item() + if reg_loss is not None: + reg_loss = reg_loss / len(batch_group) + batch_reg_loss = reg_loss.item() + train_reg_loss += batch_reg_loss # type: ignore + + if self._scaler is not None: + self._scaler.scale(loss).backward() + else: + loss.backward() + if len(batch_group_outputs) <= 0: + continue + + train_loss += batch_loss + + batch_grad_norm = self.rescale_gradients() + + if self._learning_rate_scheduler: + self._learning_rate_scheduler.step_batch(self._total_batches_completed + 1) + if self._momentum_scheduler: + self._momentum_scheduler.step_batch(self._total_batches_completed + 1) + + if self._scaler is not None: + self._scaler.step(self.optimizer) + self._scaler.update() + else: + self.optimizer.step() + + # Update moving averages + if self._moving_average is not None: + self._moving_average.apply(self._total_batches_completed + 1) + + self._batches_in_epoch_completed += 1 + self._total_batches_completed += 1 + + # Update the description with the latest metrics + metrics = training_util.get_metrics( + self.model, + train_loss, + train_reg_loss, + batch_loss, + batch_reg_loss, + self._batches_in_epoch_completed, + world_size=self._world_size, + cuda_device=self.cuda_device, + ) + + for callback in self._callbacks: + callback.on_batch( + self, + batch_group, + batch_group_outputs, + metrics, + epoch, + self._batches_in_epoch_completed, + is_training=True, + is_primary=self._primary, + batch_grad_norm=batch_grad_norm, + ) + + if self._primary: + # Updating tqdm only for the primary as the trainers wouldn't have one + description = training_util.description_from_metrics(metrics) + batch_group_generator_tqdm.set_description(description, refresh=False) + + if self._checkpointer is not None: + self._checkpointer.maybe_save_checkpoint( + self, self._epochs_completed, self._batches_in_epoch_completed + ) + + if self._distributed and not done_early: + logger.warning( + f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)." + ) + # Indicate that we're done so that any workers that have remaining data stop the epoch early. + done = torch.tensor(1, device=self.cuda_device) + torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) + assert done.item() + + # Let all workers finish their epoch before computing + # the final statistics for the epoch. + if self._distributed: + dist.barrier() + + metrics = training_util.get_metrics( + self.model, + train_loss, + train_reg_loss, + batch_loss=None, + batch_reg_loss=None, + num_batches=self._batches_in_epoch_completed, + reset=True, + world_size=self._world_size, + cuda_device=self.cuda_device, + ) + + for (worker, memory) in cpu_memory_usage: + metrics["worker_" + str(worker) + "_memory_MB"] = memory / (1024 * 1024) + for (gpu_num, memory) in gpu_memory_usage: + metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory / (1024 * 1024) + return metrics + + def _validation_loss(self, epoch: int) -> Tuple[float, Optional[float], int]: + """ + Computes the validation loss. Returns it and the number of batches. + """ + logger.info("Validating") + + self._pytorch_model.eval() + + # Replace parameter values with the shadow values from the moving averages. + if self._moving_average is not None: + self._moving_average.assign_average_value() + try: + if self._validation_data_loader is not None: + validation_data_loader = self._validation_data_loader + else: + raise ConfigurationError( + "Validation results cannot be calculated without a validation_data_loader" + ) + + regularization_penalty = self.model.get_regularization_penalty() + + # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the primary's + # progress is shown + if self._primary: + val_generator_tqdm = Tqdm.tqdm(validation_data_loader) + else: + val_generator_tqdm = validation_data_loader + + batches_this_epoch = 0 + val_loss = 0.0 + val_batch_loss = 0.0 + val_reg_loss = None if regularization_penalty is None else 0.0 + val_batch_reg_loss = None if regularization_penalty is None else 0.0 + done_early = False + for batch in val_generator_tqdm: + if self._distributed: + # Check whether the other workers have stopped already (due to differing amounts of + # data in each). If so, we can't proceed because we would hang when we hit the + # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor + # here because NCCL process groups apparently don't support BoolTensor. + done = torch.tensor(0, device=self.cuda_device) + torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) + if done.item() > 0: + done_early = True + logger.warning( + f"Worker {torch.distributed.get_rank()} finishing validation early! " + "This implies that there is an imbalance in your validation " + "data across the workers and that some amount of it will be " + "ignored. A small amount of this is fine, but a major imbalance " + "should be avoided. Note: This warning will appear unless your " + "data is perfectly balanced." + ) + break + + with amp.autocast(self._use_amp): + batch_outputs = self.batch_outputs(batch, for_training=False) + loss = batch_outputs.get("loss") + reg_loss = batch_outputs.get("reg_loss") + if loss is not None: + # You shouldn't necessarily have to compute a loss for validation, so we allow for + # `loss` to be None. We need to be careful, though - `batches_this_epoch` is + # currently only used as the divisor for the loss function, so we can safely only + # count those batches for which we actually have a loss. If this variable ever + # gets used for something else, we might need to change things around a bit. + batches_this_epoch += 1 + val_batch_loss = loss.item() + val_loss += val_batch_loss + if reg_loss is not None: + val_batch_reg_loss = reg_loss.item() + val_reg_loss += val_batch_reg_loss # type: ignore + + # Update the description with the latest metrics + val_metrics = training_util.get_metrics( + self.model, + val_loss, + val_reg_loss, + val_batch_loss, + val_batch_reg_loss, + batches_this_epoch, + world_size=self._world_size, + cuda_device=self.cuda_device, + ) + + description = training_util.description_from_metrics(val_metrics) + if self._primary: + val_generator_tqdm.set_description(description, refresh=False) + + for callback in self._callbacks: + callback.on_batch( + self, + [batch], + [batch_outputs], + val_metrics, + epoch, + batches_this_epoch, + is_training=False, + is_primary=self._primary, + ) + + if self._distributed and not done_early: + logger.warning( + f"Worker {torch.distributed.get_rank()} completed its entire epoch (validation)." + ) + # Indicate that we're done so that any workers that have remaining data stop validation early. + done = torch.tensor(1, device=self.cuda_device) + torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) + assert done.item() + + return val_loss, val_reg_loss, batches_this_epoch + finally: + # Now restore the original parameter values. + if self._moving_average is not None: + self._moving_average.restore() + + def train(self) -> Dict[str, Any]: + """ + Trains the supplied model with the supplied parameters. + """ + try: + self._restore_checkpoint() + except RuntimeError as e: + configuration_error = ConfigurationError( + "Could not recover training from the checkpoint. Did you mean to output to " + "a different serialization directory or delete the existing serialization " + "directory?" + ) + configuration_error.__cause__ = e + raise configuration_error + + # Callbacks get their `on_start` call even when we're starting from a checkpoint. + for callback in self._callbacks: + callback.on_start(self, is_primary=self._primary) + + # Set default values in case of failure + epoch = None + metrics = None + + try: + metrics, epoch = self._try_train() + return metrics + finally: + for callback in self._callbacks: + callback.on_end(self, metrics=metrics, epoch=epoch, is_primary=self._primary) + + def _try_train(self) -> Tuple[Dict[str, Any], int]: + training_util.enable_gradient_clipping(self.model, self._grad_clipping) + + logger.info("Beginning training.") + + val_metrics: Dict[str, float] = {} + metrics: Dict[str, Any] = {} + training_start_time = None + + metrics["best_epoch"] = self._metric_tracker.best_epoch + for key, value in self._metric_tracker.best_epoch_metrics.items(): + metrics["best_validation_" + key] = value + + for epoch in range(self._num_epochs): + epoch_start_time = time.time() + train_metrics = self._train_epoch(epoch) + + if self._epochs_completed < self._start_after_epochs_completed: + # We're still catching up with the checkpoint, so we do nothing. + # Note that we have to call _train_epoch() even when we know the epoch is skipped. We have to + # read from the data loader, because the data loader and dataset readers might use randomness, + # and we have to make sure we consume exactly the same instances in exactly the same way every + # time we train, even when starting from a checkpoint, so that we update the randomness + # generators in the same way each time. + self._epochs_completed += 1 + self._batches_in_epoch_completed = 0 + continue + if training_start_time is None: + training_start_time = epoch_start_time + + # get peak of memory usage + for key, value in train_metrics.items(): + if key.startswith("gpu_") and key.endswith("_memory_MB"): + metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) + elif key.startswith("worker_") and key.endswith("_memory_MB"): + metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) + + this_epoch_val_metric: float = 0.0 + if self._validation_data_loader is not None: + with torch.no_grad(): + # We have a validation set, so compute all the metrics on it. + val_loss, val_reg_loss, num_batches = self._validation_loss(epoch) + + # It is safe again to wait till the validation is done. This is + # important to get the metrics right. + if self._distributed: + dist.barrier() + + val_metrics = training_util.get_metrics( + self.model, + val_loss, + val_reg_loss, + batch_loss=None, + batch_reg_loss=None, + num_batches=num_batches, + reset=True, + world_size=self._world_size, + cuda_device=self.cuda_device, + ) + + # Check validation metric for early stopping + this_epoch_val_metric = self._metric_tracker.combined_score(val_metrics) + self._metric_tracker.add_metrics(val_metrics) + + # Create overall metrics dict + training_elapsed_time = time.time() - training_start_time + metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time)) + metrics["epoch"] = epoch + + for key, value in train_metrics.items(): + metrics["training_" + key] = value + for key, value in val_metrics.items(): + metrics["validation_" + key] = value + + if self._metric_tracker.is_best_so_far(): + # Update all the best_ metrics. + # (Otherwise they just stay the same as they were.) + metrics["best_epoch"] = epoch + for key, value in val_metrics.items(): + metrics["best_validation_" + key] = value + + self._metric_tracker.best_epoch_metrics = val_metrics + + if self._serialization_dir and self._primary: + common_util.dump_metrics( + os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), + metrics, + ) + + # The Scheduler API is agnostic to whether your schedule requires a validation metric - + # if it doesn't, the validation metric passed here is ignored. + if self._learning_rate_scheduler: + self._learning_rate_scheduler.step(this_epoch_val_metric) + if self._momentum_scheduler: + self._momentum_scheduler.step(this_epoch_val_metric) + for callback in self._callbacks: + callback.on_epoch(self, metrics=metrics, epoch=epoch, is_primary=self._primary) + + self._epochs_completed += 1 + self._batches_in_epoch_completed = 0 + + # The checkpointer saves state from the learning rate scheduler, momentum scheduler, moving + # average, and callbacks, so we have to make sure those are updated before we save the + # checkpoint here. + if self._primary and self._checkpointer is not None: + self._checkpointer.maybe_save_checkpoint( + self, self._epochs_completed, self._batches_in_epoch_completed + ) + # Wait for the primary process to finish saving the checkpoint + if self._distributed: + dist.barrier() + + if self._primary and self._serialization_dir and self._metric_tracker.is_best_so_far(): + self._best_model_filename = os.path.join(self._serialization_dir, "best.th") + if self._moving_average is None: + torch.save(self.model.state_dict(), self._best_model_filename) + else: + self._moving_average.assign_average_value() + try: + torch.save(self.model.state_dict(), self._best_model_filename) + finally: + self._moving_average.restore() + # Wait for the primary process to finish saving the best + if self._distributed: + dist.barrier() + + epoch_elapsed_time = time.time() - epoch_start_time + logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time)) + + if self._metric_tracker.should_stop_early(): + logger.info("Ran out of patience. Stopping training.") + break + + if epoch < self._num_epochs - 1: + time_per_epoch = training_elapsed_time / ( + (epoch + 1) - self._start_after_epochs_completed + ) + # Note: If the first non-skipped epoch is half skipped (because it was checkpointed half-way + # through), then this estimate is going to be optimistic. + estimated_time_remaining = ( + time_per_epoch * self._num_epochs + ) - training_elapsed_time + formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining))) + logger.info("Estimated training time remaining: %s", formatted_time) + else: + epoch = self._num_epochs - 1 + + # Load the best model state before returning + if self._best_model_filename is None or self._metric_tracker.is_best_so_far(): + self._finalize_model() + else: + # The model we're loading here has already been finalized. + self.model.load_state_dict(torch.load(self._best_model_filename)) + + return metrics, epoch + + def _finalize_model(self) -> None: + """If we have a moving average, we have to finalize the model at the end of training.""" + if self._moving_average is not None: + self._moving_average.assign_average_value() + + def get_checkpoint_state(self) -> TrainerCheckpoint: + model_state = self.model.state_dict() + + # These are the training states we need to persist. + training_states = { + "version": 1, + "metric_tracker": self._metric_tracker.state_dict(), + "optimizer": self.optimizer.state_dict(), + "callbacks": [cb.state_dict() for cb in self._callbacks], + "epochs_completed": self._epochs_completed, + "batches_in_epoch_completed": self._batches_in_epoch_completed, + "best_model_filename": self._best_model_filename, + } + + # If we have any of these optional objects, we should persist them too. + if self._learning_rate_scheduler is not None: + training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict() + if self._momentum_scheduler is not None: + training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict() + if self._moving_average is not None: + training_states["moving_average"] = self._moving_average.state_dict() + + return TrainerCheckpoint(model_state, training_states) + + def _restore_checkpoint(self) -> None: + """ + Restores the model and training state from the last saved checkpoint. + This includes an epoch count and optimizer state, which is serialized separately + from model parameters. This function should only be used to continue training - + if you wish to load a model for inference/load parts of a model into a new + computation graph, you should use the native Pytorch functions: + `model.load_state_dict(torch.load("/path/to/model/weights.th"))` + + If `self._serialization_dir` does not exist or does not contain any checkpointed weights, + this function will do nothing. + """ + if self._checkpointer is None: + return + + model_state, training_state = self._checkpointer.load_checkpoint() + if len(model_state) <= 0 and len(training_state) <= 0: + self._start_after_epochs_completed = 0 + self._start_after_batches_in_epoch_completed = 0 + self._best_model_filename = None + return + if training_state["version"] != 1: + raise ValueError( + f"This version of {self.__class__.__name__} only supports checkpoints of version 1. " + f"Found version {training_state['version']}" + ) + + self.model.load_state_dict(model_state) + self._metric_tracker.load_state_dict(training_state["metric_tracker"]) + self.optimizer.load_state_dict(training_state["optimizer"]) + + for cb, state_dict in zip(self._callbacks, training_state["callbacks"]): + cb.load_state_dict(state_dict) + + if self._learning_rate_scheduler is not None: + self._learning_rate_scheduler.load_state_dict(training_state["learning_rate_scheduler"]) + if self._momentum_scheduler is not None: + self._momentum_scheduler.load_state_dict(training_state["momentum_scheduler"]) + if self._moving_average is not None: + self._moving_average.load_state_dict(training_state["moving_average"]) + + self._start_after_epochs_completed = training_state["epochs_completed"] + self._start_after_batches_in_epoch_completed = training_state["batches_in_epoch_completed"] + self._best_model_filename = training_state["best_model_filename"] + + @classmethod + def from_partial_objects( + cls, + model: Model, + serialization_dir: str, + data_loader: DataLoader, + validation_data_loader: DataLoader = None, + local_rank: int = 0, + patience: int = None, + validation_metric: Union[str, List[str]] = "-loss", + num_epochs: int = 20, + cuda_device: Optional[Union[int, torch.device]] = None, + grad_norm: float = None, + grad_clipping: float = None, + distributed: bool = False, + world_size: int = 1, + num_gradient_accumulation_steps: int = 1, + use_amp: bool = False, + no_grad: List[str] = None, + optimizer: Lazy[Optimizer] = Lazy(Optimizer.default), + learning_rate_scheduler: Lazy[LearningRateScheduler] = None, + momentum_scheduler: Lazy[MomentumScheduler] = None, + moving_average: Lazy[MovingAverage] = None, + checkpointer: Lazy[Checkpointer] = Lazy(Checkpointer), + callbacks: List[Lazy[TrainerCallback]] = None, + enable_default_callbacks: bool = True, + run_confidence_checks: bool = True, + **kwargs, + ) -> Trainer: + """ + This method exists so that we can have a documented method to construct this class using + `FromParams`. If you are not using `FromParams` or config files, you can safely ignore this + method. + + The reason we can't just use `__init__` with `FromParams` here is because there are + sequential dependencies to this class's arguments. Anything that has a `Lazy[]` type + annotation needs something from one of the non-`Lazy` arguments. The `Optimizer` needs to + have the parameters from the `Model` before it's constructed, and the `Schedulers` need to + have the `Optimizer`. Because of this, the typical way we construct things `FromParams` + doesn't work, so we use `Lazy` to allow for constructing the objects sequentially. + + If you're not using `FromParams`, you can just construct these arguments in the right order + yourself in your code and call the constructor directly. + """ + if cuda_device is None: + from torch import cuda + + if cuda.device_count() > 0: + cuda_device = 0 + else: + cuda_device = -1 + + check_for_gpu(cuda_device) + if cuda_device >= 0: + # Moving model to GPU here so that the optimizer state gets constructed on + # the right device. + model = model.cuda(cuda_device) + + if no_grad: + for name, parameter in model.named_parameters(): + if any(re.search(regex, name) for regex in no_grad): + parameter.requires_grad_(False) + + parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad] + optimizer_ = optimizer.construct(model_parameters=parameters) + + common_util.log_frozen_and_tunable_parameter_names(model) + + batches_per_epoch: Optional[int] + try: + batches_per_epoch = len(data_loader) + batches_per_epoch = math.ceil(batches_per_epoch / num_gradient_accumulation_steps) + except TypeError: + batches_per_epoch = None + + moving_average_ = ( + None if moving_average is None else moving_average.construct(parameters=parameters) + ) + learning_rate_scheduler_ = ( + None + if learning_rate_scheduler is None + else learning_rate_scheduler.construct( + optimizer=optimizer_, num_epochs=num_epochs, num_steps_per_epoch=batches_per_epoch + ) + ) + momentum_scheduler_ = ( + None + if momentum_scheduler is None + else momentum_scheduler.construct(optimizer=optimizer_) + ) + checkpointer_ = checkpointer.construct(serialization_dir=serialization_dir) + + callbacks_: List[TrainerCallback] = [] + for callback_ in callbacks or []: + callbacks_.append(callback_.construct(serialization_dir=serialization_dir)) + + return cls( + model, + optimizer_, + data_loader, + patience=patience, + validation_metric=validation_metric, + validation_data_loader=validation_data_loader, + num_epochs=num_epochs, + serialization_dir=serialization_dir, + cuda_device=cuda_device, + grad_norm=grad_norm, + grad_clipping=grad_clipping, + learning_rate_scheduler=learning_rate_scheduler_, + momentum_scheduler=momentum_scheduler_, + checkpointer=checkpointer_, + moving_average=moving_average_, + callbacks=callbacks_, + distributed=distributed, + local_rank=local_rank, + world_size=world_size, + num_gradient_accumulation_steps=num_gradient_accumulation_steps, + use_amp=use_amp, + enable_default_callbacks=enable_default_callbacks, + run_confidence_checks=run_confidence_checks, + **kwargs, + ) + + def get_best_weights_path(self) -> Optional[str]: + return self._best_model_filename + + +DEFAULT_CALLBACKS: Tuple[Type[TrainerCallback]] = (ConsoleLoggerCallback,) +""" +The default callbacks used by `GradientDescentTrainer`. +""" diff --git a/allennlp/training/moving_average.py b/allennlp/training/moving_average.py index 205eec973fa..4657e2d45dd 100644 --- a/allennlp/training/moving_average.py +++ b/allennlp/training/moving_average.py @@ -1,4 +1,4 @@ -from typing import Iterable, Tuple, Optional +from typing import Iterable, Tuple, Optional, Any, Dict import torch @@ -41,6 +41,14 @@ def restore(self) -> None: for name, parameter in self._parameters: parameter.data.copy_(self._backups[name]) + def state_dict(self) -> Dict[str, Any]: + return {"parameters": self._parameters, "shadows": self._shadows, "backups": self._backups} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._parameters = state_dict["parameters"] + self._shadows = state_dict["shadows"] + self._backups = state_dict["backups"] + @MovingAverage.register("exponential") class ExponentialMovingAverage(MovingAverage): diff --git a/allennlp/training/no_op_trainer.py b/allennlp/training/no_op_trainer.py index 93ee3542aec..47cec4323d0 100644 --- a/allennlp/training/no_op_trainer.py +++ b/allennlp/training/no_op_trainer.py @@ -1,10 +1,11 @@ import os -from contextlib import contextmanager -from typing import Any, Dict, Iterator, Tuple +from typing import Any, Dict, Optional + +import torch from allennlp.models import Model from allennlp.training.checkpointer import Checkpointer -from allennlp.training.trainer import Trainer +from allennlp.training.trainer import Trainer, TrainerCheckpoint @Trainer.register("no_op") @@ -24,14 +25,24 @@ def __init__(self, serialization_dir: str, model: Model) -> None: super().__init__(serialization_dir, cuda_device=-1) self.model = model + self._best_model_filename: Optional[str] = None def train(self) -> Dict[str, Any]: assert self._serialization_dir is not None self.model.vocab.save_to_files(os.path.join(self._serialization_dir, "vocabulary")) checkpointer = Checkpointer(self._serialization_dir) - checkpointer.save_checkpoint(epoch=0, trainer=self, is_best_so_far=True) + checkpointer.save_checkpoint(self) + + best_model_filename = os.path.join(self._serialization_dir, "best.th") + torch.save(self.model.state_dict(), best_model_filename) + self._best_model_filename = best_model_filename + return {} - @contextmanager - def get_checkpoint_state(self) -> Iterator[Tuple[Dict[str, Any], Dict[str, Any]]]: - yield self.model.state_dict(), {} + def get_checkpoint_state(self) -> TrainerCheckpoint: + return TrainerCheckpoint( + self.model.state_dict(), {"epochs_completed": 0, "batches_in_epoch_completed": 0} + ) + + def get_best_weights_path(self) -> Optional[str]: + return self._best_model_filename diff --git a/allennlp/training/scheduler.py b/allennlp/training/scheduler.py index 26b115b68ed..e9cad0bc9ca 100644 --- a/allennlp/training/scheduler.py +++ b/allennlp/training/scheduler.py @@ -79,5 +79,4 @@ def step_batch(self, batch_num_total: int = None) -> None: By default, a scheduler is assumed to only update every epoch, not every batch. So this does nothing unless it's overriden. """ - return diff --git a/allennlp/training/trainer.py b/allennlp/training/trainer.py index 54d9b59ffb1..797a8f382d2 100644 --- a/allennlp/training/trainer.py +++ b/allennlp/training/trainer.py @@ -1,44 +1,23 @@ -import datetime import logging -import math import os -import re -import time -import traceback -import warnings -from contextlib import contextmanager -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union -from allennlp.common.util import int_to_device - -import torch -import torch.distributed as dist -from torch.cuda import amp import torch.optim.lr_scheduler -from torch.nn.parallel import DistributedDataParallel -from torch.nn.utils import clip_grad_norm_ -from allennlp.common import Lazy, Registrable, Tqdm -from allennlp.common import util as common_util +from allennlp.common import Registrable from allennlp.common.checks import ConfigurationError, check_for_gpu -from allennlp.data import DataLoader, TensorDict -from allennlp.models.model import Model -from allennlp.training import util as training_util -from allennlp.training.callbacks import ( - TrainerCallback, - ConfidenceChecksCallback, - ConsoleLoggerCallback, -) -from allennlp.training.checkpointer import Checkpointer -from allennlp.training.learning_rate_schedulers import LearningRateScheduler -from allennlp.training.metric_tracker import MetricTracker -from allennlp.training.momentum_schedulers import MomentumScheduler -from allennlp.training.moving_average import MovingAverage -from allennlp.training.optimizers import Optimizer +from allennlp.common.util import int_to_device logger = logging.getLogger(__name__) +@dataclass +class TrainerCheckpoint: + model_state: Dict[str, Any] + trainer_state: Dict[str, Any] + + class Trainer(Registrable): """ The base class for an AllenNLP trainer. It can do pretty much @@ -77,7 +56,7 @@ def __init__( if isinstance(cuda_device, list): raise ConfigurationError( - "In allennlp 1.0, the Trainer can only be assigned a single `cuda_device`. " + "In AllenNLP 1.0, the Trainer can only be assigned a single `cuda_device`. " "Instead, we use torch's DistributedDataParallel at the command level, meaning " "our Trainer always uses a single GPU per process." ) @@ -101,1035 +80,13 @@ def train(self) -> Dict[str, Any]: """ raise NotImplementedError - @contextmanager - def get_checkpoint_state(self) -> Iterator[Tuple[Dict[str, Any], Dict[str, Any]]]: + def get_checkpoint_state(self) -> TrainerCheckpoint: """ Returns a tuple of (model state, training state), where training state could have several internal components (e.g., for an, optimizer, learning rate scheduler, etc.). - - This is a context manager, and should be called as `with trainer.get_checkpoint_state() as - state:`, so that the trainer has the opportunity to change and restore its internal state - for checkpointing. This is used, e.g., for moving averages of model weights. """ raise NotImplementedError - -@Trainer.register("gradient_descent", constructor="from_partial_objects") -class GradientDescentTrainer(Trainer): - """ - A trainer for doing supervised learning with gradient descent. It just takes a labeled dataset - and a `DataLoader`, and uses the supplied `Optimizer` to learn the weights for your model over - some fixed number of epochs. You can also pass in a validation data_loader and enable early - stopping. There are many other bells and whistles as well. - - Registered as a `Trainer` with the name "gradient_descent" (and is also the default `Trainer`). - The constructor that is registered is [`from_partial_objects`](#from_partial_objects) - - see the arguments to that function for the exact keys that should be used, if you are using - a configuration file. They largely match the arguments to `__init__`, and we don't repeat their - docstrings in `from_partial_objects`. - - [0]: https://tinyurl.com/y5mv44fw - - # Parameters - - model : `Model`, required. - An AllenNLP model to be optimized. Pytorch Modules can also be optimized if - their `forward` method returns a dictionary with a "loss" key, containing a - scalar tensor representing the loss function to be optimized. - - If you are training your model using GPUs, your model should already be - on the correct device. (If you are using our `train` command this will be - handled for you.) - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - optimizer : `torch.nn.Optimizer`, required. - An instance of a Pytorch Optimizer, instantiated with the parameters of the - model to be optimized. - - data_loader : `DataLoader`, required. - A `DataLoader` containing your `Dataset`, yielding padded indexed batches. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - patience : `Optional[int] > 0`, optional (default=`None`) - Number of epochs to be patient before early stopping: the training is stopped - after `patience` epochs with no improvement. If given, it must be `> 0`. - If None, early stopping is disabled. - - validation_metric : `Union[str, List[str]]`, optional (default=`"-loss"`) - Validation metric to measure for whether to stop training using patience - and whether to serialize an `is_best` model each epoch. The metric name - must be prepended with either "+" or "-", which specifies whether the metric - is an increasing or decreasing function. If you specify more than one metric, - the metrics will be summed to make the `is_best` decision. - - validation_data_loader : `DataLoader`, optional (default=`None`) - A `DataLoader` to use for the validation set. If `None`, then - use the training `DataLoader` with the validation data. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - num_epochs : `int`, optional (default = `20`) - Number of training epochs. - - serialization_dir : `str`, optional (default=`None`) - Path to directory for saving and loading model files. Models will not be saved if - this parameter is not passed. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - checkpointer : `Checkpointer`, optional (default=`None`) - A `Checkpointer` is responsible for periodically saving model weights. If none is given - here, we will construct one with default parameters. - - cuda_device : `Optional[Union[int, torch.device]]`, optional (default = `None`) - An integer or `torch.device` specifying the CUDA device to use for this process. - If -1, the CPU is used. If `None` and you have a GPU available, that GPU will be used. - - !!! Note - If you *don't* intend to use a GPU, but you have one available, you'll need - to explicitly set `cuda_device=-1`. - - !!! Note - If you intend to use a GPU, your model already needs to be on the correct device, - which you can do with `model = model.cuda()`. - - !!! Note - Data parallelism is controlled at the allennlp train level, so each trainer will have a single GPU. - - grad_norm : `float`, optional, (default = `None`). - If provided, gradient norms will be rescaled to have a maximum of this value. - - grad_clipping : `float`, optional (default = `None`). - If provided, gradients will be clipped `during the backward pass` to have an (absolute) - maximum of this value. If you are getting `NaNs` in your gradients during training - that are not solved by using `grad_norm`, you may need this. - - learning_rate_scheduler : `LearningRateScheduler`, optional (default = `None`) - If specified, the learning rate will be decayed with respect to - this schedule at the end of each epoch (or batch, if the scheduler implements - the `step_batch` method). If you use `torch.optim.lr_scheduler.ReduceLROnPlateau`, - this will use the `validation_metric` provided to determine if learning has plateaued. - To support updating the learning rate on every batch, this can optionally implement - `step_batch(batch_num_total)` which updates the learning rate given the batch number. - - momentum_scheduler : `MomentumScheduler`, optional (default = `None`) - If specified, the momentum will be updated at the end of each batch or epoch - according to the schedule. - - moving_average : `MovingAverage`, optional, (default = `None`) - If provided, we will maintain moving averages for all parameters. During training, we - employ a shadow variable for each parameter, which maintains the moving average. During - evaluation, we backup the original parameters and assign the moving averages to corresponding - parameters. Be careful that when saving the checkpoint, we will save the moving averages of - parameters. This is necessary because we want the saved model to perform as well as the validated - model if we load it later. But this may cause problems if you restart the training from checkpoint. - - callbacks : `List[Lazy[TrainerCallback]]`, optional (default = `None`) - A list of callbacks that can be called at certain events: e.g. each batch, epoch, and at the start - and end of training, etc. - - distributed : `bool`, optional, (default = `False`) - If set, PyTorch's `DistributedDataParallel` is used to train the model in multiple GPUs. This also - requires `world_size` to be greater than 1. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately (you need a top-level "distributed" key, next to - the "trainer" entry, that specifies a list of "cuda_devices"). - - local_rank : `int`, optional, (default = `0`) - This is the unique identifier of the `Trainer` in a distributed process group. The GPU device id is - used as the rank. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - world_size : `int`, (default = `1`) - The number of `Trainer` workers participating in the distributed training. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - num_gradient_accumulation_steps : `int`, optional, (default = `1`) - Gradients are accumulated for the given number of steps before doing an optimizer step. This can - be useful to accommodate batches that are larger than the RAM size. Refer [Thomas Wolf's - post][0] for details on Gradient Accumulation. - - use_amp : `bool`, optional, (default = `False`) - If `True`, we'll train using [Automatic Mixed Precision](https://pytorch.org/docs/stable/amp.html). - - enable_default_callbacks : `bool`, optional (default = `True`) - When `True`, the [`DEFAULT_CALLBACKS`](#default_callbacks) will be used in - addition to any other callbacks listed in the `callbacks` parameter. - When set to `False`, `DEFAULT_CALLBACKS` are not used. - - run_confidence_checks : `bool`, optional (default = `True`) - Determines whether model confidence checks, such as - [`NormalizationBiasVerification`](../../confidence_checks/normalization_bias_verification/), - are run. - - run_sanity_checks : `bool`, optional (default = `True`) - This parameter is deprecated. Please use `run_confidence_checks` instead. - - """ - - def __init__( - self, - model: Model, - optimizer: torch.optim.Optimizer, - data_loader: DataLoader, - patience: Optional[int] = None, - validation_metric: Union[str, List[str]] = "-loss", - validation_data_loader: DataLoader = None, - num_epochs: int = 20, - serialization_dir: Optional[str] = None, - checkpointer: Checkpointer = None, - cuda_device: Optional[Union[int, torch.device]] = None, - grad_norm: Optional[float] = None, - grad_clipping: Optional[float] = None, - learning_rate_scheduler: Optional[LearningRateScheduler] = None, - momentum_scheduler: Optional[MomentumScheduler] = None, - moving_average: Optional[MovingAverage] = None, - callbacks: List[TrainerCallback] = None, - distributed: bool = False, - local_rank: int = 0, - world_size: int = 1, - num_gradient_accumulation_steps: int = 1, - use_amp: bool = False, - enable_default_callbacks: bool = True, - run_confidence_checks: bool = True, - **kwargs, - ) -> None: - super().__init__( - serialization_dir=serialization_dir, - cuda_device=cuda_device, - distributed=distributed, - local_rank=local_rank, - world_size=world_size, - ) - - if "run_sanity_checks" in kwargs: - warnings.warn( - "'run_sanity_checks' is deprecated, please use 'run_confidence_checks' instead.", - DeprecationWarning, - ) - run_confidence_checks = kwargs["run_sanity_checks"] - - # I am not calling move_to_gpu here, because if the model is - # not already on the GPU then the optimizer is going to be wrong. - self.model = model - - self.data_loader = data_loader - self.data_loader.set_target_device(self.cuda_device) - self._validation_data_loader = validation_data_loader - if self._validation_data_loader is not None: - self._validation_data_loader.set_target_device(self.cuda_device) - self.optimizer = optimizer - - if patience is None: # no early stopping - if validation_data_loader is not None: - logger.warning( - "You provided a validation dataset but patience was set to None, " - "meaning that early stopping is disabled" - ) - elif (not isinstance(patience, int)) or patience <= 0: - raise ConfigurationError( - '{} is an invalid value for "patience": it must be a positive integer ' - "or None (if you want to disable early stopping)".format(patience) - ) - - # For tracking is_best_so_far and should_stop_early - self._metric_tracker = MetricTracker(validation_metric, patience) - - self._num_epochs = num_epochs - - self._checkpointer: Optional[Checkpointer] = checkpointer - if checkpointer is None and serialization_dir is not None: - self._checkpointer = Checkpointer(serialization_dir) - - self._grad_norm = grad_norm - self._grad_clipping = grad_clipping - - self._learning_rate_scheduler = learning_rate_scheduler - self._momentum_scheduler = momentum_scheduler - self._moving_average = moving_average - - self._callbacks = callbacks or [] - default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else [] - - if run_confidence_checks: - default_callbacks.append(ConfidenceChecksCallback) - for callback_cls in default_callbacks: - for callback in self._callbacks: - if callback.__class__ == callback_cls: - break - else: - self._callbacks.append(callback_cls(self._serialization_dir)) - - self._batch_num_total = 0 - self._last_log = 0.0 # time of last logging - self._num_gradient_accumulation_steps = num_gradient_accumulation_steps - - # Enable automatic mixed precision training. - self._scaler: Optional[amp.GradScaler] = None - self._use_amp = use_amp - if self._use_amp: - if self.cuda_device == torch.device("cpu"): - raise ValueError("Using AMP requires a cuda device") - self._scaler = amp.GradScaler() - - # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its - # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model` - # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc. - # - # Hence a reference to Pytorch's object is maintained in the case of distributed training and in the - # normal case, reference to `Model` is retained. This reference is only used in - # these places: `model.__call__`, `model.train` and `model.eval`. - if self._distributed: - self._pytorch_model = DistributedDataParallel( - self.model, - device_ids=None if self.cuda_device == torch.device("cpu") else [self.cuda_device], - find_unused_parameters=True, - ) - else: - self._pytorch_model = self.model - - def rescale_gradients(self) -> float: - """ - Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled. - - Returns the norm of the gradients. - """ - parameters_to_clip = [p for p in self.model.parameters() if p.grad is not None] - if self._grad_norm: - if self._scaler is not None: - # Need to first unscale gradients in order to clip as usual. - self._scaler.unscale_(self.optimizer) - return clip_grad_norm_(parameters_to_clip, self._grad_norm) - else: - return torch.norm( - torch.stack([torch.norm(p.grad.detach()) for p in parameters_to_clip]) - ) - - def batch_outputs(self, batch: TensorDict, for_training: bool) -> Dict[str, torch.Tensor]: - """ - Does a forward pass on the given batch and returns the output dictionary that the model - returns, after adding any specified regularization penalty to the loss (if training). - """ - output_dict = self._pytorch_model(**batch) - - if for_training: - try: - assert "loss" in output_dict - regularization_penalty = self.model.get_regularization_penalty() - - if regularization_penalty is not None: - output_dict["reg_loss"] = regularization_penalty - output_dict["loss"] += regularization_penalty - - except AssertionError: - if for_training: - raise RuntimeError( - "The model you are trying to optimize does not contain a" - " 'loss' key in the output of model.forward(inputs)." - ) - - return output_dict - - def _train_epoch(self, epoch: int) -> Dict[str, float]: - """ - Trains one epoch and returns metrics. - """ - logger.info("Epoch %d/%d", epoch, self._num_epochs - 1) - cpu_memory_usage = [] - for worker, memory in common_util.peak_cpu_memory().items(): - cpu_memory_usage.append((worker, memory)) - logger.info(f"Worker {worker} memory usage: {common_util.format_size(memory)}") - gpu_memory_usage = [] - for gpu, memory in common_util.peak_gpu_memory().items(): - gpu_memory_usage.append((gpu, memory)) - logger.info(f"GPU {gpu} memory usage: {common_util.format_size(memory)}") - - regularization_penalty = self.model.get_regularization_penalty() - - train_loss = 0.0 - batch_loss = 0.0 - train_reg_loss = None if regularization_penalty is None else 0.0 - batch_reg_loss = None if regularization_penalty is None else 0.0 - - # Set the model to "train" mode. - self._pytorch_model.train() - - # Get tqdm for the training batches - batch_generator = iter(self.data_loader) - batch_group_generator = common_util.lazy_groups_of( - batch_generator, self._num_gradient_accumulation_steps - ) - - logger.info("Training") - - num_training_batches: Union[int, float] - try: - len_data_loader = len(self.data_loader) - num_training_batches = math.ceil( - len_data_loader / self._num_gradient_accumulation_steps - ) - except TypeError: - num_training_batches = float("inf") - - # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the primary's - # progress is shown - if self._primary: - batch_group_generator_tqdm = Tqdm.tqdm( - batch_group_generator, total=num_training_batches - ) - else: - batch_group_generator_tqdm = batch_group_generator - - self._last_log = time.time() - - batches_this_epoch = 0 - if self._batch_num_total is None: - self._batch_num_total = 0 - - done_early = False - for batch_group in batch_group_generator_tqdm: - if done_early: - break - - batches_this_epoch += 1 - self._batch_num_total += 1 - batch_num_total = self._batch_num_total - - # Zero gradients. - # NOTE: this is actually more efficient than calling `self.optimizer.zero_grad()` - # because it avoids a read op when the gradients are first updated below. - for param_group in self.optimizer.param_groups: - for p in param_group["params"]: - p.grad = None - - batch_loss = 0.0 - batch_group_outputs = [] - for batch in batch_group: - if self._distributed: - # Check whether the other workers have stopped already (due to differing amounts of - # data in each). If so, we can't proceed because we would hang when we hit the - # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor - # here because NCCL process groups apparently don't support BoolTensor. - done = torch.tensor(0, device=self.cuda_device) - torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) - if done.item() > 0: - done_early = True - logger.warning( - f"Worker {torch.distributed.get_rank()} finishing training early! " - "This implies that there is an imbalance in your training " - "data across the workers and that some amount of it will be " - "ignored. A small amount of this is fine, but a major imbalance " - "should be avoided. Note: This warning will appear unless your " - "data is perfectly balanced." - ) - break - - with amp.autocast(self._use_amp): - batch_outputs = self.batch_outputs(batch, for_training=True) - batch_group_outputs.append(batch_outputs) - loss = batch_outputs["loss"] - reg_loss = batch_outputs.get("reg_loss") - if torch.isnan(loss): - raise ValueError("nan loss encountered") - loss = loss / len(batch_group) - - batch_loss += loss.item() - if reg_loss is not None: - reg_loss = reg_loss / len(batch_group) - batch_reg_loss = reg_loss.item() - train_reg_loss += batch_reg_loss # type: ignore - - if self._scaler is not None: - self._scaler.scale(loss).backward() - else: - loss.backward() - if len(batch_group_outputs) <= 0: - continue - - train_loss += batch_loss - - batch_grad_norm = self.rescale_gradients() - - # This does nothing if batch_num_total is None or you are using a - # scheduler which doesn't update per batch. - if self._learning_rate_scheduler: - self._learning_rate_scheduler.step_batch(batch_num_total) - if self._momentum_scheduler: - self._momentum_scheduler.step_batch(batch_num_total) - - if self._scaler is not None: - self._scaler.step(self.optimizer) - self._scaler.update() - else: - self.optimizer.step() - - # Update moving averages - if self._moving_average is not None: - self._moving_average.apply(batch_num_total) - - # Update the description with the latest metrics - metrics = training_util.get_metrics( - self.model, - train_loss, - train_reg_loss, - batch_loss, - batch_reg_loss, - batches_this_epoch, - world_size=self._world_size, - cuda_device=self.cuda_device, - ) - - if self._primary: - # Updating tqdm only for the primary as the trainers wouldn't have one - description = training_util.description_from_metrics(metrics) - batch_group_generator_tqdm.set_description(description, refresh=False) - - if self._checkpointer is not None: - self._checkpointer.maybe_save_checkpoint(self, epoch, batches_this_epoch) - - for callback in self._callbacks: - callback.on_batch( - self, - batch_group, - batch_group_outputs, - metrics, - epoch, - batches_this_epoch, - is_training=True, - is_primary=self._primary, - batch_grad_norm=batch_grad_norm, - ) - - if self._distributed and not done_early: - logger.warning( - f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)." - ) - # Indicate that we're done so that any workers that have remaining data stop the epoch early. - done = torch.tensor(1, device=self.cuda_device) - torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) - assert done.item() - - # Let all workers finish their epoch before computing - # the final statistics for the epoch. - if self._distributed: - dist.barrier() - - metrics = training_util.get_metrics( - self.model, - train_loss, - train_reg_loss, - batch_loss=None, - batch_reg_loss=None, - num_batches=batches_this_epoch, - reset=True, - world_size=self._world_size, - cuda_device=self.cuda_device, - ) - - for (worker, memory) in cpu_memory_usage: - metrics["worker_" + str(worker) + "_memory_MB"] = memory / (1024 * 1024) - for (gpu_num, memory) in gpu_memory_usage: - metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory / (1024 * 1024) - return metrics - - def _validation_loss(self, epoch: int) -> Tuple[float, Optional[float], int]: - """ - Computes the validation loss. Returns it and the number of batches. - """ - logger.info("Validating") - - self._pytorch_model.eval() - - # Replace parameter values with the shadow values from the moving averages. - if self._moving_average is not None: - self._moving_average.assign_average_value() - - if self._validation_data_loader is not None: - validation_data_loader = self._validation_data_loader - else: - raise ConfigurationError( - "Validation results cannot be calculated without a validation_data_loader" - ) - - regularization_penalty = self.model.get_regularization_penalty() - - # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the primary's - # progress is shown - if self._primary: - val_generator_tqdm = Tqdm.tqdm(validation_data_loader) - else: - val_generator_tqdm = validation_data_loader - - batches_this_epoch = 0 - val_loss = 0.0 - val_batch_loss = 0.0 - val_reg_loss = None if regularization_penalty is None else 0.0 - val_batch_reg_loss = None if regularization_penalty is None else 0.0 - done_early = False - for batch in val_generator_tqdm: - if self._distributed: - # Check whether the other workers have stopped already (due to differing amounts of - # data in each). If so, we can't proceed because we would hang when we hit the - # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor - # here because NCCL process groups apparently don't support BoolTensor. - done = torch.tensor(0, device=self.cuda_device) - torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) - if done.item() > 0: - done_early = True - logger.warning( - f"Worker {torch.distributed.get_rank()} finishing validation early! " - "This implies that there is an imbalance in your validation " - "data across the workers and that some amount of it will be " - "ignored. A small amount of this is fine, but a major imbalance " - "should be avoided. Note: This warning will appear unless your " - "data is perfectly balanced." - ) - break - - with amp.autocast(self._use_amp): - batch_outputs = self.batch_outputs(batch, for_training=False) - loss = batch_outputs.get("loss") - reg_loss = batch_outputs.get("reg_loss") - if loss is not None: - # You shouldn't necessarily have to compute a loss for validation, so we allow for - # `loss` to be None. We need to be careful, though - `batches_this_epoch` is - # currently only used as the divisor for the loss function, so we can safely only - # count those batches for which we actually have a loss. If this variable ever - # gets used for something else, we might need to change things around a bit. - batches_this_epoch += 1 - val_batch_loss = loss.item() - val_loss += val_batch_loss - if reg_loss is not None: - val_batch_reg_loss = reg_loss.item() - val_reg_loss += val_batch_reg_loss # type: ignore - - # Update the description with the latest metrics - val_metrics = training_util.get_metrics( - self.model, - val_loss, - val_reg_loss, - val_batch_loss, - val_batch_reg_loss, - batches_this_epoch, - world_size=self._world_size, - cuda_device=self.cuda_device, - ) - - description = training_util.description_from_metrics(val_metrics) - if self._primary: - val_generator_tqdm.set_description(description, refresh=False) - - for callback in self._callbacks: - callback.on_batch( - self, - [batch], - [batch_outputs], - val_metrics, - epoch, - batches_this_epoch, - is_training=False, - is_primary=self._primary, - ) - - if self._distributed and not done_early: - logger.warning( - f"Worker {torch.distributed.get_rank()} completed its entire epoch (validation)." - ) - # Indicate that we're done so that any workers that have remaining data stop validation early. - done = torch.tensor(1, device=self.cuda_device) - torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) - assert done.item() - - # Now restore the original parameter values. - if self._moving_average is not None: - self._moving_average.restore() - - return val_loss, val_reg_loss, batches_this_epoch - - def train(self) -> Dict[str, Any]: - """ - Trains the supplied model with the supplied parameters. - """ - - for callback in self._callbacks: - callback.on_start(self, is_primary=self._primary) - - # Set default values in case of failure - epoch = None - metrics = None - - try: - metrics, epoch = self._try_train() - return metrics - finally: - for callback in self._callbacks: - callback.on_end(self, metrics=metrics, epoch=epoch, is_primary=self._primary) - - def _try_train(self) -> Tuple[Dict[str, Any], int]: - try: - epoch_counter = self._restore_checkpoint() - except RuntimeError: - traceback.print_exc() - raise ConfigurationError( - "Could not recover training from the checkpoint. Did you mean to output to " - "a different serialization directory or delete the existing serialization " - "directory?" - ) - - training_util.enable_gradient_clipping(self.model, self._grad_clipping) - - logger.info("Beginning training.") - - val_metrics: Dict[str, float] = {} - metrics: Dict[str, Any] = {} - epochs_trained = 0 - training_start_time = time.time() - - metrics["best_epoch"] = self._metric_tracker.best_epoch - for key, value in self._metric_tracker.best_epoch_metrics.items(): - metrics["best_validation_" + key] = value - - for epoch in range(epoch_counter, self._num_epochs): - epoch_start_time = time.time() - train_metrics = self._train_epoch(epoch) - - # Back up the model now, in case something goes wrong later with the evaluation - if self._primary and self._checkpointer is not None: - self._checkpointer.shelve_model(epoch, self) - # Wait for the primary process to finish saving the model checkpoint - if self._distributed: - dist.barrier() - - # get peak of memory usage - for key, value in train_metrics.items(): - if key.startswith("gpu_") and key.endswith("_memory_MB"): - metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) - elif key.startswith("worker_") and key.endswith("_memory_MB"): - metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) - - this_epoch_val_metric: float = 0.0 - if self._validation_data_loader is not None: - with torch.no_grad(): - # We have a validation set, so compute all the metrics on it. - val_loss, val_reg_loss, num_batches = self._validation_loss(epoch) - - # It is safe again to wait till the validation is done. This is - # important to get the metrics right. - if self._distributed: - dist.barrier() - - val_metrics = training_util.get_metrics( - self.model, - val_loss, - val_reg_loss, - batch_loss=None, - batch_reg_loss=None, - num_batches=num_batches, - reset=True, - world_size=self._world_size, - cuda_device=self.cuda_device, - ) - - # Check validation metric for early stopping - this_epoch_val_metric = self._metric_tracker.combined_score(val_metrics) - self._metric_tracker.add_metrics(val_metrics) - - # Create overall metrics dict - training_elapsed_time = time.time() - training_start_time - metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time)) - metrics["training_start_epoch"] = epoch_counter - metrics["training_epochs"] = epochs_trained - metrics["epoch"] = epoch - - for key, value in train_metrics.items(): - metrics["training_" + key] = value - for key, value in val_metrics.items(): - metrics["validation_" + key] = value - - if self._metric_tracker.is_best_so_far(): - # Update all the best_ metrics. - # (Otherwise they just stay the same as they were.) - metrics["best_epoch"] = epoch - for key, value in val_metrics.items(): - metrics["best_validation_" + key] = value - - self._metric_tracker.best_epoch_metrics = val_metrics - - if self._serialization_dir and self._primary: - common_util.dump_metrics( - os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), - metrics, - ) - - # The Scheduler API is agnostic to whether your schedule requires a validation metric - - # if it doesn't, the validation metric passed here is ignored. - if self._learning_rate_scheduler: - self._learning_rate_scheduler.step(this_epoch_val_metric) - if self._momentum_scheduler: - self._momentum_scheduler.step(this_epoch_val_metric) - - # The checkpointer saves state from the learning rate scheduler and the momentum - # scheduler, so we have to make sure those are updated before we save the checkpoint here. - if self._primary and self._checkpointer is not None: - self._checkpointer.save_checkpoint( - epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far() - ) - # Wait for the primary process to finish saving the checkpoint - if self._distributed: - dist.barrier() - - for callback in self._callbacks: - callback.on_epoch(self, metrics=metrics, epoch=epoch, is_primary=self._primary) - - epoch_elapsed_time = time.time() - epoch_start_time - logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time)) - - if epoch < self._num_epochs - 1: - training_elapsed_time = time.time() - training_start_time - estimated_time_remaining = training_elapsed_time * ( - (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1 - ) - formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining))) - logger.info("Estimated training time remaining: %s", formatted_time) - - epochs_trained += 1 - - if self._metric_tracker.should_stop_early(): - logger.info("Ran out of patience. Stopping training.") - break - else: - epoch = self._num_epochs - 1 - - # Load the best model state before returning - best_model_state = ( - None if self._checkpointer is None else self._checkpointer.best_model_state() - ) - if best_model_state: - self.model.load_state_dict(best_model_state) - - return metrics, epoch - - @contextmanager - def get_checkpoint_state(self) -> Iterator[Tuple[Dict[str, Any], Dict[str, Any]]]: - if self._moving_average is not None: - # Assigning average value to model parameters. The checkpointer will call - # `restore_state_after_checkpointing` when it is done to put this back to what it was. - self._moving_average.assign_average_value() - - model_state = self.model.state_dict() - - # These are the training states we need to persist. - training_states = { - "metric_tracker": self._metric_tracker.state_dict(), - "optimizer": self.optimizer.state_dict(), - "batch_num_total": self._batch_num_total, - } - - # If we have a learning rate or momentum scheduler, we should persist them too. - if self._learning_rate_scheduler is not None: - training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict() - if self._momentum_scheduler is not None: - training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict() - - try: - yield model_state, training_states - finally: - if self._moving_average is not None: - self._moving_average.restore() - - def _restore_checkpoint(self) -> int: - """ - Restores the model and training state from the last saved checkpoint. - This includes an epoch count and optimizer state, which is serialized separately - from model parameters. This function should only be used to continue training - - if you wish to load a model for inference/load parts of a model into a new - computation graph, you should use the native Pytorch functions: - ` model.load_state_dict(torch.load("/path/to/model/weights.th"))` - - If `self._serialization_dir` does not exist or does not contain any checkpointed weights, - this function will do nothing and return 0. - - # Returns - - epoch: `int` - The epoch at which to resume training, which should be one after the epoch - in the saved training state. - """ - if self._checkpointer is None: - return 0 - - model_state, training_state = self._checkpointer.restore_checkpoint() - - if not training_state: - # No checkpoint to restore, start at 0 - return 0 - - self.model.load_state_dict(model_state) - self.optimizer.load_state_dict(training_state["optimizer"]) - if ( - self._learning_rate_scheduler is not None - and "learning_rate_scheduler" in training_state - ): - self._learning_rate_scheduler.load_state_dict(training_state["learning_rate_scheduler"]) - if self._momentum_scheduler is not None and "momentum_scheduler" in training_state: - self._momentum_scheduler.load_state_dict(training_state["momentum_scheduler"]) - training_util.move_optimizer_to_cuda(self.optimizer) - - # Currently the `training_state` contains a serialized `MetricTracker`. - if "metric_tracker" in training_state: - self._metric_tracker.load_state_dict(training_state["metric_tracker"]) - else: - self._metric_tracker.clear() - - if isinstance(training_state["epoch"], int): - epoch_to_return = training_state["epoch"] + 1 - else: - epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1 - - # For older checkpoints with batch_num_total missing, default to old behavior where - # it is unchanged. - batch_num_total = training_state.get("batch_num_total") - if batch_num_total is not None: - self._batch_num_total = batch_num_total - - return epoch_to_return - - @classmethod - def from_partial_objects( - cls, - model: Model, - serialization_dir: str, - data_loader: DataLoader, - validation_data_loader: DataLoader = None, - local_rank: int = 0, - patience: int = None, - validation_metric: Union[str, List[str]] = "-loss", - num_epochs: int = 20, - cuda_device: Optional[Union[int, torch.device]] = None, - grad_norm: float = None, - grad_clipping: float = None, - distributed: bool = False, - world_size: int = 1, - num_gradient_accumulation_steps: int = 1, - use_amp: bool = False, - no_grad: List[str] = None, - optimizer: Lazy[Optimizer] = Lazy(Optimizer.default), - learning_rate_scheduler: Lazy[LearningRateScheduler] = None, - momentum_scheduler: Lazy[MomentumScheduler] = None, - moving_average: Lazy[MovingAverage] = None, - checkpointer: Lazy[Checkpointer] = Lazy(Checkpointer), - callbacks: List[Lazy[TrainerCallback]] = None, - enable_default_callbacks: bool = True, - run_confidence_checks: bool = True, - **kwargs, - ) -> "Trainer": - """ - This method exists so that we can have a documented method to construct this class using - `FromParams`. If you are not using `FromParams` or config files, you can safely ignore this - method. - - The reason we can't just use `__init__` with `FromParams` here is because there are - sequential dependencies to this class's arguments. Anything that has a `Lazy[]` type - annotation needs something from one of the non-`Lazy` arguments. The `Optimizer` needs to - have the parameters from the `Model` before it's constructed, and the `Schedulers` need to - have the `Optimizer`. Because of this, the typical way we construct things `FromParams` - doesn't work, so we use `Lazy` to allow for constructing the objects sequentially. - - If you're not using `FromParams`, you can just construct these arguments in the right order - yourself in your code and call the constructor directly. - """ - if cuda_device is None: - from torch import cuda - - if cuda.device_count() > 0: - cuda_device = 0 - else: - cuda_device = -1 - - check_for_gpu(cuda_device) - if cuda_device >= 0: - # Moving model to GPU here so that the optimizer state gets constructed on - # the right device. - model = model.cuda(cuda_device) - - if no_grad: - for name, parameter in model.named_parameters(): - if any(re.search(regex, name) for regex in no_grad): - parameter.requires_grad_(False) - - parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad] - optimizer_ = optimizer.construct(model_parameters=parameters) - - common_util.log_frozen_and_tunable_parameter_names(model) - - batches_per_epoch: Optional[int] - try: - batches_per_epoch = len(data_loader) - batches_per_epoch = math.ceil(batches_per_epoch / num_gradient_accumulation_steps) - except TypeError: - batches_per_epoch = None - - moving_average_ = ( - None if moving_average is None else moving_average.construct(parameters=parameters) - ) - learning_rate_scheduler_ = ( - None - if learning_rate_scheduler is None - else learning_rate_scheduler.construct( - optimizer=optimizer_, num_epochs=num_epochs, num_steps_per_epoch=batches_per_epoch - ) - ) - momentum_scheduler_ = ( - None - if momentum_scheduler is None - else momentum_scheduler.construct(optimizer=optimizer_) - ) - checkpointer_ = checkpointer.construct(serialization_dir=serialization_dir) - - callbacks_: List[TrainerCallback] = [] - for callback_ in callbacks or []: - callbacks_.append(callback_.construct(serialization_dir=serialization_dir)) - - return cls( - model, - optimizer_, - data_loader, - patience=patience, - validation_metric=validation_metric, - validation_data_loader=validation_data_loader, - num_epochs=num_epochs, - serialization_dir=serialization_dir, - cuda_device=cuda_device, - grad_norm=grad_norm, - grad_clipping=grad_clipping, - learning_rate_scheduler=learning_rate_scheduler_, - momentum_scheduler=momentum_scheduler_, - checkpointer=checkpointer_, - moving_average=moving_average_, - callbacks=callbacks_, - distributed=distributed, - local_rank=local_rank, - world_size=world_size, - num_gradient_accumulation_steps=num_gradient_accumulation_steps, - use_amp=use_amp, - enable_default_callbacks=enable_default_callbacks, - run_confidence_checks=run_confidence_checks, - **kwargs, - ) - - -DEFAULT_CALLBACKS: Tuple[Type[TrainerCallback]] = (ConsoleLoggerCallback,) -""" -The default callbacks used by `GradientDescentTrainer`. -""" + def get_best_weights_path(self) -> Optional[str]: + """Returns the path to file containing the current best weights.""" + return None diff --git a/tests/commands/no_op_train_test.py b/tests/commands/no_op_train_test.py index 948edb39784..c862fccb581 100644 --- a/tests/commands/no_op_train_test.py +++ b/tests/commands/no_op_train_test.py @@ -31,7 +31,7 @@ def test_train_model(self): serialization_dir = self.TEST_DIR / "serialization_directory" train_model(params(), serialization_dir=serialization_dir) - archive = load_archive(str(serialization_dir / "model.tar.gz")) + archive = load_archive(serialization_dir / "model.tar.gz") model = archive.model assert model.forward(torch.tensor([1, 2, 3]))["class"] == torch.tensor(98) assert model.vocab.get_vocab_size() == 9 diff --git a/tests/training/checkpointer_test.py b/tests/training/checkpointer_test.py index 206fa43278b..ac102a3a983 100644 --- a/tests/training/checkpointer_test.py +++ b/tests/training/checkpointer_test.py @@ -1,21 +1,19 @@ import os -import re import time -from contextlib import contextmanager from allennlp.common.testing import AllenNlpTestCase from allennlp.common.params import Params from allennlp.training import Checkpointer, Trainer +from allennlp.training.trainer import TrainerCheckpoint class FakeTrainer(Trainer): - def __init__(self, model_state, training_states): + def __init__(self, model_state, training_state): self._model_state = model_state - self._training_states = training_states + self._training_state = training_state - @contextmanager - def get_checkpoint_state(self): - yield self._model_state, self._training_states + def get_checkpoint_state(self) -> TrainerCheckpoint: + return TrainerCheckpoint(self._model_state, self._training_state) class TestCheckpointer(AllenNlpTestCase): @@ -26,77 +24,72 @@ def retrieve_and_delete_saved(self): and returns the saved epochs as two lists of integers. """ serialization_files = os.listdir(self.TEST_DIR) - model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x] - found_model_epochs = [ - int(re.search(r"model_state_epoch_([0-9\.\-]+)\.th", x).group(1)) - for x in model_checkpoints - ] + + model_checkpoints = [x for x in serialization_files if "model_state_" in x] + found_model_states = [Checkpointer._parse_model_state_path(x) for x in model_checkpoints] for f in model_checkpoints: os.remove(os.path.join(self.TEST_DIR, f)) - training_checkpoints = [x for x in serialization_files if "training_state_epoch" in x] - found_training_epochs = [ - int(re.search(r"training_state_epoch_([0-9\.\-]+)\.th", x).group(1)) - for x in training_checkpoints + + training_checkpoints = [x for x in serialization_files if "training_state_" in x] + found_training_states = [ + Checkpointer._parse_training_state_path(x) for x in training_checkpoints ] for f in training_checkpoints: os.remove(os.path.join(self.TEST_DIR, f)) - return sorted(found_model_epochs), sorted(found_training_epochs) + return sorted(found_model_states), sorted(found_training_states) def test_default(self): """ Tests that the default behavior keeps just the last 2 checkpoints. """ default_num_to_keep = 2 - num_epochs = 30 - target = list(range(num_epochs - default_num_to_keep, num_epochs)) + num_epochs = 5 + target = [(e, 0) for e in range(num_epochs - default_num_to_keep, num_epochs)] checkpointer = Checkpointer(serialization_dir=self.TEST_DIR) - - for e in range(num_epochs): - checkpointer.save_checkpoint( - epoch=e, - trainer=FakeTrainer(model_state={"epoch": e}, training_states={"epoch": e}), - is_best_so_far=False, - ) + for epochs_completed in range(num_epochs): + for batches_completed in [0, 5, 10]: + state = { + "epochs_completed": epochs_completed, + "batches_in_epoch_completed": batches_completed, + } + checkpointer.maybe_save_checkpoint( + FakeTrainer(model_state=state, training_state=state), + epochs_completed, + batches_completed, + ) models, training = self.retrieve_and_delete_saved() assert models == training == target def test_keep_zero(self): - checkpointer = Checkpointer( - serialization_dir=self.TEST_DIR, num_serialized_models_to_keep=0 - ) - for e in range(10): - checkpointer.save_checkpoint( - epoch=e, - trainer=FakeTrainer(model_state={"epoch": e}, training_states={"epoch": e}), - is_best_so_far=True, + checkpointer = Checkpointer(serialization_dir=self.TEST_DIR, keep_most_recent_by_count=0) + for epochs_completed in range(5): + state = {"epochs_completed": epochs_completed, "batches_in_epoch_completed": 0} + checkpointer.maybe_save_checkpoint( + FakeTrainer(model_state=state, training_state=state), epochs_completed, 0 ) files = os.listdir(self.TEST_DIR) - assert "model_state_epoch_1.th" not in files - assert "training_state_epoch_1.th" not in files + assert not any("model_state_" in x for x in files) + assert not any("training_state_" in x for x in files) def test_with_time(self): - """ - Tests that keep_serialized_model_every_num_seconds parameter causes a checkpoint to be saved - after enough time has elapsed between epochs. - """ - num_to_keep = 10 num_epochs = 30 - target = list(range(num_epochs - num_to_keep, num_epochs)) pauses = [5, 18, 26] - target = sorted(set(target + pauses)) + target = [(e, 0) for e in pauses] checkpointer = Checkpointer( serialization_dir=self.TEST_DIR, - num_serialized_models_to_keep=num_to_keep, - keep_serialized_model_every_num_seconds=1, + save_completed_epochs=False, + save_every_num_seconds=1, + keep_most_recent_by_count=3, ) for e in range(num_epochs): if e in pauses: time.sleep(2) - checkpointer.save_checkpoint( - epoch=e, - trainer=FakeTrainer(model_state={"epoch": e}, training_states={"epoch": e}), - is_best_so_far=False, + state = {"epochs_completed": e, "batches_in_epoch_completed": 0} + checkpointer.maybe_save_checkpoint( + trainer=FakeTrainer(model_state=state, training_state=state), + num_epochs_completed=e, + num_batches_in_epoch_completed=0, ) models, training = self.retrieve_and_delete_saved() assert models == training == target diff --git a/tests/training/trainer_test.py b/tests/training/trainer_test.py index 3926adf0ec2..2373caafefd 100644 --- a/tests/training/trainer_test.py +++ b/tests/training/trainer_test.py @@ -2,7 +2,6 @@ import glob import json import os -import re import time from typing import Any, Dict, List, Optional @@ -217,11 +216,11 @@ def test_data_loader_lazy_epoch_size_correct(self): num_epochs=num_epochs, serialization_dir=self.TEST_DIR, ) - assert trainer._batch_num_total == 0 + assert trainer._total_batches_completed == 0 metrics = trainer.train() epoch = metrics["epoch"] assert epoch == num_epochs - 1 - assert trainer._batch_num_total == num_epochs * 2 + assert trainer._total_batches_completed == num_epochs * 2 def test_data_loader_lazy_epoch_size_correct_custom_epoch_size(self): self.data_loader_lazy.batches_per_epoch = 3 @@ -234,11 +233,11 @@ def test_data_loader_lazy_epoch_size_correct_custom_epoch_size(self): num_epochs=num_epochs, serialization_dir=self.TEST_DIR, ) - assert trainer._batch_num_total == 0 + assert trainer._total_batches_completed == 0 metrics = trainer.train() epoch = metrics["epoch"] assert epoch == num_epochs - 1 - assert trainer._batch_num_total == num_epochs * 3 + assert trainer._total_batches_completed == num_epochs * 3 def test_trainer_respects_epoch_size_equals_total(self): batches_per_epoch = 4 @@ -256,11 +255,11 @@ def test_trainer_respects_epoch_size_equals_total(self): num_epochs=num_epochs, serialization_dir=self.TEST_DIR, ) - assert trainer._batch_num_total == 0 + assert trainer._total_batches_completed == 0 metrics = trainer.train() epoch = metrics["epoch"] assert epoch == num_epochs - 1 - assert trainer._batch_num_total == num_epochs * batches_per_epoch + assert trainer._total_batches_completed == num_epochs * batches_per_epoch def test_trainer_respects_epoch_size_larger_tnan_total(self): batches_per_epoch = 7 @@ -278,11 +277,11 @@ def test_trainer_respects_epoch_size_larger_tnan_total(self): num_epochs=num_epochs, serialization_dir=self.TEST_DIR, ) - assert trainer._batch_num_total == 0 + assert trainer._total_batches_completed == 0 metrics = trainer.train() epoch = metrics["epoch"] assert epoch == num_epochs - 1 - assert trainer._batch_num_total == num_epochs * batches_per_epoch + assert trainer._total_batches_completed == num_epochs * batches_per_epoch def test_trainer_respects_epoch_size_smaller_tnan_total(self): batches_per_epoch = 1 @@ -300,11 +299,11 @@ def test_trainer_respects_epoch_size_smaller_tnan_total(self): num_epochs=num_epochs, serialization_dir=self.TEST_DIR, ) - assert trainer._batch_num_total == 0 + assert trainer._total_batches_completed == 0 metrics = trainer.train() epoch = metrics["epoch"] assert epoch == num_epochs - 1 - assert trainer._batch_num_total == num_epochs * batches_per_epoch + assert trainer._total_batches_completed == num_epochs * batches_per_epoch def test_trainer_can_resume_training(self): trainer = GradientDescentTrainer( @@ -316,6 +315,7 @@ def test_trainer_can_resume_training(self): serialization_dir=self.TEST_DIR, ) trainer.train() + new_trainer = GradientDescentTrainer( self.model, self.optimizer, @@ -324,9 +324,9 @@ def test_trainer_can_resume_training(self): num_epochs=3, serialization_dir=self.TEST_DIR, ) + new_trainer._restore_checkpoint() - epoch = new_trainer._restore_checkpoint() - assert epoch == 1 + assert new_trainer._start_after_epochs_completed == 1 tracker = trainer._metric_tracker assert tracker.is_best_so_far() @@ -359,8 +359,8 @@ def test_trainer_can_resume_training_for_exponential_moving_average(self): moving_average=new_moving_average, ) - epoch = new_trainer._restore_checkpoint() - assert epoch == 1 + new_trainer._restore_checkpoint() + assert new_trainer._start_after_epochs_completed == 1 tracker = trainer._metric_tracker assert tracker.is_best_so_far() @@ -605,8 +605,8 @@ def test_trainer_can_run_and_resume_with_momentum_scheduler(self): num_epochs=6, serialization_dir=self.TEST_DIR, ) - epoch = new_trainer._restore_checkpoint() - assert epoch == 4 + new_trainer._restore_checkpoint() + new_trainer._start_after_epochs_completed = 4 assert new_trainer._momentum_scheduler.last_epoch == 3 new_trainer.train() @@ -672,8 +672,8 @@ def test_trainer_can_resume_with_lr_scheduler(self): num_epochs=4, serialization_dir=self.TEST_DIR, ) - epoch = new_trainer._restore_checkpoint() - assert epoch == 2 + new_trainer._restore_checkpoint() + assert new_trainer._start_after_epochs_completed == 2 assert new_trainer._learning_rate_scheduler.last_epoch == 1 new_trainer.train() @@ -719,17 +719,20 @@ def test_trainer_respects_num_serialized_models_to_keep(self): self.data_loader, num_epochs=5, serialization_dir=self.TEST_DIR, - checkpointer=Checkpointer( - serialization_dir=self.TEST_DIR, num_serialized_models_to_keep=3 - ), + checkpointer=Checkpointer(serialization_dir=self.TEST_DIR, keep_most_recent_by_count=3), ) trainer.train() # Now check the serialized files - for prefix in ["model_state_epoch_*", "training_state_epoch_*"]: - file_names = glob.glob(os.path.join(self.TEST_DIR, prefix)) - epochs = [int(re.search(r"_([0-9])\.th", fname).group(1)) for fname in file_names] - assert sorted(epochs) == [2, 3, 4] + expected = [(3, 0), (4, 0), (5, 0)] + + file_names = glob.glob(os.path.join(self.TEST_DIR, "model_state_e*_b*")) + epochs = [Checkpointer._parse_model_state_path(fname) for fname in file_names] + assert sorted(epochs) == expected + + file_names = glob.glob(os.path.join(self.TEST_DIR, "training_state_e*_b*")) + epochs = [Checkpointer._parse_training_state_path(fname) for fname in file_names] + assert sorted(epochs) == expected def test_trainer_saves_metrics_every_epoch(self): trainer = GradientDescentTrainer( @@ -739,9 +742,7 @@ def test_trainer_saves_metrics_every_epoch(self): validation_data_loader=self.validation_data_loader, num_epochs=5, serialization_dir=self.TEST_DIR, - checkpointer=Checkpointer( - serialization_dir=self.TEST_DIR, num_serialized_models_to_keep=3 - ), + checkpointer=Checkpointer(serialization_dir=self.TEST_DIR, keep_most_recent_by_count=3), ) trainer.train() @@ -757,9 +758,6 @@ def test_trainer_respects_keep_serialized_model_every_num_seconds(self): # To test: # Create an fake data loader that sleeps for 2.5 second per epoch, so the total # training time for one epoch is slightly greater then 2.5 seconds. - # Run for 6 epochs, keeping the last 2 models, models also kept every 5 seconds. - # Check the resulting checkpoints. Should then have models at epochs - # 2, 4, plus the last two at 5 and 6. class SlowDataLoader: data_loader = SimpleDataLoader(self.instances, batch_size=2) @@ -781,19 +779,24 @@ def set_target_device(self, _): num_epochs=6, serialization_dir=self.TEST_DIR, checkpointer=Checkpointer( + save_completed_epochs=False, serialization_dir=self.TEST_DIR, - num_serialized_models_to_keep=2, - keep_serialized_model_every_num_seconds=5, + keep_most_recent_by_count=4, + save_every_num_seconds=5, ), ) trainer.train() # Now check the serialized files - for prefix in ["model_state_epoch_*", "training_state_epoch_*"]: - file_names = glob.glob(os.path.join(self.TEST_DIR, prefix)) - epochs = [int(re.search(r"_([0-9])\.th", fname).group(1)) for fname in file_names] - # epoch N has N-1 in file name - assert sorted(epochs) == [1, 3, 4, 5] + expected = [(1, 1), (3, 1), (5, 1)] + + file_names = glob.glob(os.path.join(self.TEST_DIR, "model_state_e*_b*")) + epochs = [Checkpointer._parse_model_state_path(fname) for fname in file_names] + assert sorted(epochs) == expected + + file_names = glob.glob(os.path.join(self.TEST_DIR, "training_state_e*_b*")) + epochs = [Checkpointer._parse_training_state_path(fname) for fname in file_names] + assert sorted(epochs) == expected def test_trainer_can_log_learning_rates_tensorboard(self): data_loader = SimpleDataLoader(self.instances, 4) @@ -853,54 +856,64 @@ def test_confidence_check_default(self): # Check is not run, so no failure. trainer.train() - def test_trainer_saves_models_at_specified_interval(self): - data_loader = SimpleDataLoader(self.instances, 4) + @pytest.mark.parametrize("checkpoint_to_keep", range(20)) + def test_trainer_restores_and_makes_same_results(self, checkpoint_to_keep: int): + batch_size = 2 + data_loader = SimpleDataLoader(self.instances, batch_size) + num_epochs = 10 + num_batches = len(self.instances) // batch_size trainer = GradientDescentTrainer( self.model, self.optimizer, data_loader, - num_epochs=2, + validation_data_loader=data_loader, + num_epochs=num_epochs, serialization_dir=self.TEST_DIR, checkpointer=Checkpointer( serialization_dir=self.TEST_DIR, - model_save_interval=0.0001, - num_serialized_models_to_keep=10, + save_every_num_seconds=0.0001, + keep_most_recent_by_count=20, ), ) - trainer.train() + original_metrics = trainer.train() # Now check the serialized files for models saved during the epoch. - prefix = "model_state_epoch_*" - file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix))) - epochs = [re.search(r"_([0-9\.\-]+)\.th", fname).group(1) for fname in file_names] - # We should have checkpoints at the end of each epoch and during each, e.g. - # [0.timestamp, 0, 1.timestamp, 1] - assert len(epochs) == 4 - assert epochs[3] == "1" - assert "." in epochs[0] - - # Now make certain we can restore from timestamped checkpoint. - # To do so, remove the checkpoint from the end of epoch 1&2, so - # that we are forced to restore from the timestamped checkpoints. - for k in range(2): - os.remove(os.path.join(self.TEST_DIR, "model_state_epoch_{}.th".format(k))) - os.remove(os.path.join(self.TEST_DIR, "training_state_epoch_{}.th".format(k))) + file_names = glob.glob(os.path.join(self.TEST_DIR, "model_state_e*_b*")) + checkpoints = [Checkpointer._parse_model_state_path(fname) for fname in file_names] + checkpoints.sort() + + expected = [(e, b) for e in range(num_epochs) for b in range(num_batches + 1)] + del expected[0] + expected.append((num_epochs, 0)) + expected = expected[-20:] + assert checkpoints == expected + + # Now make certain we can restore from checkpoint in the middle of an epoch. + # To do so, remove the checkpoint at the end of epochs. + for i, checkpoint in enumerate(checkpoints): + if i != checkpoint_to_keep: + os.remove(trainer._checkpointer._model_state_path(*checkpoint)) + os.remove(trainer._checkpointer._training_state_path(*checkpoint)) os.remove(os.path.join(self.TEST_DIR, "best.th")) - restore_trainer = GradientDescentTrainer( + restored_trainer = GradientDescentTrainer( self.model, self.optimizer, self.data_loader, - num_epochs=2, + validation_data_loader=data_loader, + num_epochs=num_epochs, serialization_dir=self.TEST_DIR, - checkpointer=Checkpointer(serialization_dir=self.TEST_DIR, model_save_interval=0.0001), + checkpointer=Checkpointer( + serialization_dir=self.TEST_DIR, + save_every_num_seconds=0.0001, + keep_most_recent_by_count=10, + ), ) - epoch = restore_trainer._restore_checkpoint() - assert epoch == 2 - # One batch per epoch. - assert restore_trainer._batch_num_total == 2 + restored_metrics = restored_trainer.train() + + assert original_metrics["best_validation_loss"] == restored_metrics["best_validation_loss"] def test_trainer_saves_and_loads_best_validation_metrics_correctly_1(self): # Use -loss and run 1 epoch of original-training, and one of restored-training @@ -1033,9 +1046,11 @@ def test_trainer_can_run_gradient_accumulation(self): ) assert trainer._num_gradient_accumulation_steps == steps_to_accumulate - metrics = trainer.train() + trainer.train() - num_batches_trained_per_epoch = trainer._batch_num_total // (metrics["training_epochs"] + 1) + num_batches_trained_per_epoch = ( + trainer._total_batches_completed // trainer._epochs_completed + ) num_batches_expected = math.ceil( math.ceil(len(instances) / self.data_loader.batch_size) / steps_to_accumulate )