diff --git a/classy_vision/tasks/fine_tuning_task.py b/classy_vision/tasks/fine_tuning_task.py
index bd30e29c7d..62eb53a5f4 100644
--- a/classy_vision/tasks/fine_tuning_task.py
+++ b/classy_vision/tasks/fine_tuning_task.py
@@ -4,7 +4,9 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict
+import warnings
+from enum import Enum
+from typing import Any, Callable, Dict, Union
 
 from classy_vision.generic.util import (
     load_and_broadcast_checkpoint,
@@ -13,15 +15,42 @@
 from classy_vision.tasks import ClassificationTask, register_task
 
 
+class FreezeUntil(Enum):
+    """
+    Enum for a pre-specified point to freeze the classy model until.
+
+    Attributes:
+        HEAD (str): Freeze the model up to the classy head
+    """
+
+    HEAD = "head"
+
+    def __eq__(self, other: str):
+        return other.lower() == self.value
+
+
 @register_task("fine_tuning")
 class FineTuningTask(ClassificationTask):
+    """Fine-tuning training task.
+
+    This task encapsulates all of the components and steps needed to
+    fine-tune a classifier using a :class:`classy_vision.trainer.ClassyTrainer`.
+
+    :var pretrained_checkpoint_path: String path to the pretrained model checkpoint.
+    :var reset_heads: bool. Whether or not to reset the model heads during fine-tuning.
+    :var freeze_until: optional string. If specified, must be the name of a module within
+        the model. Fine-tuning will freeze the model up to this module; model weights are
+        only trainable from this module onwards, always including the head. To freeze the
+        entire trunk, specify 'head' as the unfreeze point.
+    """
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.pretrained_checkpoint_dict = None
         self.pretrained_checkpoint_path = None
         self.pretrained_checkpoint_load_strict = True
         self.reset_heads = False
-        self.freeze_trunk = False
+        self.freeze_until = None
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "FineTuningTask":
@@ -44,7 +73,13 @@ def from_config(cls, config: Dict[str, Any]) -> "FineTuningTask":
         )
 
         task.set_reset_heads(config.get("reset_heads", False))
-        task.set_freeze_trunk(config.get("freeze_trunk", False))
+        assert (
+            "freeze_trunk" not in config or "freeze_until" not in config
+        ), "Config options 'freeze_trunk' and 'freeze_until' cannot both be specified"
+        if "freeze_trunk" in config:
+            task.set_freeze_trunk(config.get("freeze_trunk", False))
+        else:
+            task.set_freeze_until(config.get("freeze_until", None))
         return task
 
     def set_pretrained_checkpoint(self, checkpoint_path: str) -> "FineTuningTask":
@@ -68,22 +103,46 @@ def set_reset_heads(self, reset_heads: bool) -> "FineTuningTask":
         return self
 
     def set_freeze_trunk(self, freeze_trunk: bool) -> "FineTuningTask":
-        self.freeze_trunk = freeze_trunk
+        if freeze_trunk:
+            self.freeze_until = FreezeUntil.HEAD.value
+            warnings.warn(
+                "Config option freeze_trunk has been deprecated. "
" + "Use \"freeze_until:'head'\" instead", + DeprecationWarning, + ) + + return self + + def set_freeze_until(self, freeze_until: Union[str, None]) -> "FineTuningTask": + self.freeze_until = freeze_until return self def _set_model_train_mode(self): phase = self.phases[self.phase_idx] self.loss.train(phase["train"]) - if self.freeze_trunk: + if self.freeze_until is not None: # convert all the sub-modules to the eval mode, except the heads self.base_model.eval() - for heads in self.base_model.get_heads().values(): - for h in heads: - h.train(phase["train"]) + self._apply_to_nonfrozen(lambda x: x.train(phase["train"])) else: self.base_model.train(phase["train"]) + def _apply_to_nonfrozen(self, callable: Callable[..., Any]) -> None: + for heads in self.base_model.get_heads().values(): + for h in heads: + callable(h) + if not self.freeze_until == FreezeUntil.HEAD: + unfrozen_module = False + for name, module in self.base_model.named_modules(): + if name == self.freeze_until: + unfrozen_module = True + if unfrozen_module: + callable(module) + assert ( + unfrozen_module + ), f"Freeze until point {self.freeze_until} not found in model" + def prepare(self) -> None: super().prepare() if self.checkpoint_dict is None: @@ -109,15 +168,17 @@ def prepare(self) -> None: state_load_success ), "Update classy state from pretrained checkpoint was unsuccessful." - if self.freeze_trunk: + if self.freeze_until is not None: # do not track gradients for all the parameters in the model except # for the parameters in the heads for param in self.base_model.parameters(): param.requires_grad = False - for heads in self.base_model.get_heads().values(): - for h in heads: - for param in h.parameters(): - param.requires_grad = True + + def _set_requires_grad_true(x): + for param in x.parameters(): + param.requires_grad = True + + self._apply_to_nonfrozen(_set_requires_grad_true) # re-create ddp model self.distributed_model = None self.init_distributed_data_parallel_model() diff --git a/test/generic/utils.py b/test/generic/utils.py index c05418cba9..d126748f65 100644 --- a/test/generic/utils.py +++ b/test/generic/utils.py @@ -215,8 +215,15 @@ def recursive_unpack(batch): raise TypeError("Unexpected type %s passed to unpack" % type(batch)) -def compare_model_state(test_fixture, state, state2, check_heads=True): +def compare_model_state( + test_fixture, state, state2, check_heads=True, state_changed_params=() +): for k in state["model"]["trunk"].keys(): + if k in state_changed_params: + test_fixture.assertFalse( + torch.allclose(state["model"]["trunk"][k], state2["model"]["trunk"][k]) + ) + continue if not torch.allclose(state["model"]["trunk"][k], state2["model"]["trunk"][k]): print(k, state["model"]["trunk"][k], state2["model"]["trunk"][k]) test_fixture.assertTrue( diff --git a/test/tasks_fine_tuning_task_test.py b/test/tasks_fine_tuning_task_test.py index b4a444d64b..7b8f48f795 100644 --- a/test/tasks_fine_tuning_task_test.py +++ b/test/tasks_fine_tuning_task_test.py @@ -40,13 +40,47 @@ def forward(self, x, target): class TestFineTuningTask(unittest.TestCase): - def _compare_model_state(self, state_1, state_2, check_heads=True): - return compare_model_state(self, state_1, state_2, check_heads=check_heads) + def _compare_model_state( + self, state_1, state_2, check_heads=True, state_changed_params=() + ): + return compare_model_state( + self, + state_1, + state_2, + check_heads=check_heads, + state_changed_params=state_changed_params, + ) def _compare_state_dict(self, state_1, state_2, check_heads=True): for k in 
             self.assertTrue(torch.allclose(state_1[k].cpu(), state_2[k].cpu()))
 
+    def _get_unfreeze_points_to_unfrozen_params(self):
+        return {
+            "blocks.0.block0-0._module.downsample.1": (
+                "blocks.0.block0-0.downsample.1.weight",
+                "blocks.0.block0-0.downsample.1.bias",
+                "blocks.0.block0-0.downsample.1.running_mean",
+                "blocks.0.block0-0.downsample.1.running_var",
+                "blocks.0.block0-0.downsample.1.num_batches_tracked",
+            ),
+            "blocks.0.block0-0._module.convolutional_block.6": (
+                "blocks.0.block0-0.convolutional_block.6.weight",
+                "blocks.0.block0-0.bn.weight",
+                "blocks.0.block0-0.bn.bias",
+                "blocks.0.block0-0.bn.running_mean",
+                "blocks.0.block0-0.bn.running_var",
+                "blocks.0.block0-0.bn.num_batches_tracked",
+                "blocks.0.block0-0.downsample.0.weight",
+                "blocks.0.block0-0.downsample.1.weight",
+                "blocks.0.block0-0.downsample.1.bias",
+                "blocks.0.block0-0.downsample.1.running_mean",
+                "blocks.0.block0-0.downsample.1.running_var",
+                "blocks.0.block0-0.downsample.1.num_batches_tracked",
+            ),
+            "head": (),
+        }
+
     def _get_fine_tuning_config(
         self, head_num_classes=100, pretrained_checkpoint=False
     ):
@@ -152,19 +186,22 @@ def test_train(self):
         trainer = LocalTrainer()
         trainer.train(pre_train_task)
         checkpoint = get_checkpoint_dict(pre_train_task, {})
-
+        unfreeze_points = self._get_unfreeze_points_to_unfrozen_params()
+        unfreeze_options = list(unfreeze_points) + [None]
         for reset_heads, heads_num_classes in [(False, 100), (True, 20)]:
-            for freeze_trunk in [True, False]:
-                fine_tuning_config = self._get_fine_tuning_config(
-                    head_num_classes=heads_num_classes
+            for unfreeze_point in unfreeze_options:
+                fine_tuning_config = copy.deepcopy(
+                    self._get_fine_tuning_config(head_num_classes=heads_num_classes)
                 )
+                # Extra epochs help ensure that unfrozen parameters change value
+                fine_tuning_config["num_epochs"] = 4
                 fine_tuning_task = build_task(fine_tuning_config)
                 fine_tuning_task = (
                     fine_tuning_task._set_pretrained_checkpoint_dict(
                         copy.deepcopy(checkpoint)
                     )
                     .set_reset_heads(reset_heads)
-                    .set_freeze_trunk(freeze_trunk)
+                    .set_freeze_until(unfreeze_point)
                 )
                 # run in test mode to compare the model state
                 fine_tuning_task.set_test_only(True)
@@ -177,12 +214,14 @@
                 # run in train mode to check accuracy
                 fine_tuning_task.set_test_only(False)
                 trainer.train(fine_tuning_task)
-                if freeze_trunk:
-                    # if trunk is frozen the states should be the same
+                if unfreeze_point is not None:
+                    # check that the expected part of the model is frozen
+                    # and the unfrozen part isn't
                     self._compare_model_state(
                         pre_train_task.model.get_classy_state(),
                         fine_tuning_task.model.get_classy_state(),
                         check_heads=False,
+                        state_changed_params=unfreeze_points[unfreeze_point],
                     )
                 else:
                     # trunk isn't frozen, the states should be different
@@ -196,6 +235,40 @@
             accuracy = fine_tuning_task.meters[0].value["top_1"]
             self.assertAlmostEqual(accuracy, 1.0)
 
+    def test_freeze_trunk_backwards_compatibility(self):
+        pre_train_config = self._get_pre_train_config(head_num_classes=100)
+        pre_train_task = build_task(pre_train_config)
+        trainer = LocalTrainer()
+        trainer.train(pre_train_task)
+        checkpoint = get_checkpoint_dict(pre_train_task, {})
+        for reset_heads, heads_num_classes in [(False, 100), (True, 20)]:
+            fine_tuning_config = copy.deepcopy(
+                self._get_fine_tuning_config(head_num_classes=heads_num_classes)
+            )
+            fine_tuning_config["freeze_trunk"] = True
+            with self.assertWarns(DeprecationWarning):
+                fine_tuning_task = build_task(fine_tuning_config)
+            fine_tuning_task = fine_tuning_task._set_pretrained_checkpoint_dict(
+                copy.deepcopy(checkpoint)
+            ).set_reset_heads(reset_heads)
+            fine_tuning_task.set_test_only(True)
+            trainer.train(fine_tuning_task)
+            self._compare_model_state(
+                pre_train_task.model.get_classy_state(),
+                fine_tuning_task.model.get_classy_state(),
+                check_heads=not reset_heads,
+            )
+            # run in train mode to check accuracy
+            fine_tuning_task.set_test_only(False)
+            trainer.train(fine_tuning_task)
+            self._compare_model_state(
+                pre_train_task.model.get_classy_state(),
+                fine_tuning_task.model.get_classy_state(),
+                check_heads=False,
+            )
+            accuracy = fine_tuning_task.meters[0].value["top_1"]
+            self.assertAlmostEqual(accuracy, 1.0)
+
     def test_train_parametric_loss(self):
         heads_num_classes = 100
         pre_train_config = self._get_pre_train_config(
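
For reference, below is a minimal standalone sketch of the freeze-from-a-named-module pattern that the new _apply_to_nonfrozen helper and prepare() implement in the diff above: freeze every parameter first, then walk named_modules() and re-enable gradients from the configured freeze-until module onwards. The toy model and the module names "0", "1", "2" are illustrative assumptions, not Classy Vision code.

import torch.nn as nn

# Toy model; "0", "1", "2" are nn.Sequential's default child names (illustrative only).
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)),  # stays frozen
    nn.Sequential(nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8)),  # trainable from here on
    nn.Linear(8, 10),  # plays the role of the head
)

freeze_until = "1"  # name of the first module that should remain trainable

# Freeze everything first, mirroring the prepare() step in the diff.
for param in model.parameters():
    param.requires_grad = False

# Unfreeze from the named module onwards; named_modules() yields parents before
# their children in registration order, so the flag flips at the first match
# and stays on for every module after it.
unfrozen_module = False
for name, module in model.named_modules():
    if name == freeze_until:
        unfrozen_module = True
    if unfrozen_module:
        for param in module.parameters():
            param.requires_grad = True

assert unfrozen_module, f"Freeze until point {freeze_until} not found in model"
print(sorted(name for name, p in model.named_parameters() if p.requires_grad))

The version in the diff applies the same walk via _apply_to_nonfrozen, but always visits the heads first, so they stay trainable even when freeze_until is 'head' (the case where the whole trunk is frozen).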