This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Support freezing model anywhere in fine tuning #728

Open
wants to merge 1 commit into base: main
87 changes: 74 additions & 13 deletions classy_vision/tasks/fine_tuning_task.py
@@ -4,7 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict
import warnings
from enum import Enum
from typing import Any, Callable, Dict, Union

from classy_vision.generic.util import (
load_and_broadcast_checkpoint,
@@ -13,15 +15,42 @@
from classy_vision.tasks import ClassificationTask, register_task


class FreezeUntil(Enum):
"""
Enum for a pre-specified point to freeze the classy model until.

Attributes:
HEAD (str): Freeze the model up to the classy head
"""

HEAD = "head"

def __eq__(self, other: str):
return other.lower() == self.value


@register_task("fine_tuning")
class FineTuningTask(ClassificationTask):
"""Finetuning training task.

This task encapsulates all of the components and steps needed to
fine-tune a classifier using a :class:`classy_vision.trainer.ClassyTrainer`.

:var pretrained_checkpoint_path: String path to pretrained model
:var reset_heads: bool. Whether or not to reset the model heads during finetuning.
:var freeze_until: optional string. If specified, must be the name of a module within
the model. Fine-tuning will freeze the model up to this module; model weights will
only be trainable from this module onwards, with the head always trainable. To freeze
the entire trunk, specify 'head' as the unfreeze point.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pretrained_checkpoint_dict = None
self.pretrained_checkpoint_path = None
self.pretrained_checkpoint_load_strict = True
self.reset_heads = False
self.freeze_trunk = False
self.freeze_until = None

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FineTuningTask":
@@ -44,7 +73,13 @@ def from_config(cls, config: Dict[str, Any]) -> "FineTuningTask":
)

task.set_reset_heads(config.get("reset_heads", False))
task.set_freeze_trunk(config.get("freeze_trunk", False))
assert (
"freeze_trunk" not in config or "freeze_until" not in config
), "Config options 'freeze_trunk' and 'freeze_until' cannot both be specified"
if "freeze_trunk" in config:
task.set_freeze_trunk(config.get("freeze_trunk", False))
else:
task.set_freeze_until(config.get("freeze_until", None))
return task

def set_pretrained_checkpoint(self, checkpoint_path: str) -> "FineTuningTask":
@@ -68,22 +103,46 @@ def set_reset_heads(self, reset_heads: bool) -> "FineTuningTask":
return self

def set_freeze_trunk(self, freeze_trunk: bool) -> "FineTuningTask":
self.freeze_trunk = freeze_trunk
if freeze_trunk:
self.freeze_until = FreezeUntil.HEAD.value
warnings.warn(
"Config option freeze_trunk has been deprecated. "
"Use \"freeze_until:'head'\" instead",
DeprecationWarning,
)

return self

def set_freeze_until(self, freeze_until: Union[str, None]) -> "FineTuningTask":
self.freeze_until = freeze_until
return self

def _set_model_train_mode(self):
phase = self.phases[self.phase_idx]
self.loss.train(phase["train"])

if self.freeze_trunk:
if self.freeze_until is not None:
# convert all the sub-modules to eval mode, except the heads and any non-frozen modules
self.base_model.eval()
for heads in self.base_model.get_heads().values():
for h in heads:
h.train(phase["train"])
self._apply_to_nonfrozen(lambda x: x.train(phase["train"]))
else:
self.base_model.train(phase["train"])

def _apply_to_nonfrozen(self, callable: Callable[..., Any]) -> None:
for heads in self.base_model.get_heads().values():
for h in heads:
callable(h)
if not self.freeze_until == FreezeUntil.HEAD:
unfrozen_module = False
for name, module in self.base_model.named_modules():
if name == self.freeze_until:
unfrozen_module = True
if unfrozen_module:
callable(module)
assert (
unfrozen_module
), f"Freeze until point {self.freeze_until} not found in model"

def prepare(self) -> None:
super().prepare()
if self.checkpoint_dict is None:
@@ -109,15 +168,17 @@ def prepare(self) -> None:
state_load_success
), "Update classy state from pretrained checkpoint was unsuccessful."

if self.freeze_trunk:
if self.freeze_until is not None:
# do not track gradients for any parameters in the model except
# those in the heads and in the non-frozen modules
for param in self.base_model.parameters():
param.requires_grad = False
for heads in self.base_model.get_heads().values():
for h in heads:
for param in h.parameters():
param.requires_grad = True

def _set_requires_grad_true(x):
for param in x.parameters():
param.requires_grad = True

self._apply_to_nonfrozen(_set_requires_grad_true)
# re-create ddp model
self.distributed_model = None
self.init_distributed_data_parallel_model()
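For orientation, here is a minimal usage sketch of the new option, mirroring the tests further down. The base config keys (model, dataset, loss, optimizer, meters) are elided, and the module name passed to "freeze_until" is hypothetical and depends on the trunk's named_modules(); only "reset_heads" and "freeze_until" are taken from the from_config() code above.

# Minimal sketch, assuming a fine-tuning config skeleton like the ones built in the tests.
from classy_vision.tasks import build_task

fine_tuning_config = {
    "name": "fine_tuning",
    # ... model / dataset / loss / optimizer / meters config elided ...
    "reset_heads": True,
    # Modules before this point (in named_modules order) stay frozen; this module,
    # everything after it, and the heads remain trainable. Use "head" to freeze the
    # whole trunk, replacing the deprecated "freeze_trunk": True.
    "freeze_until": "blocks.1",  # hypothetical module name
}

task = build_task(fine_tuning_config)
# Equivalent programmatic form, as used in the tests:
# task.set_reset_heads(True).set_freeze_until("head")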
9 changes: 8 additions & 1 deletion test/generic/utils.py
@@ -215,8 +215,15 @@ def recursive_unpack(batch):
raise TypeError("Unexpected type %s passed to unpack" % type(batch))


def compare_model_state(test_fixture, state, state2, check_heads=True):
def compare_model_state(
test_fixture, state, state2, check_heads=True, state_changed_params=()
):
for k in state["model"]["trunk"].keys():
if k in state_changed_params:
test_fixture.assertFalse(
torch.allclose(state["model"]["trunk"][k], state2["model"]["trunk"][k])
)
continue
if not torch.allclose(state["model"]["trunk"][k], state2["model"]["trunk"][k]):
print(k, state["model"]["trunk"][k], state2["model"]["trunk"][k])
test_fixture.assertTrue(
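For clarity, a hypothetical call to the extended helper: trunk parameters named in state_changed_params must differ between the two states, while every other trunk parameter (and, if check_heads=True, the heads) must still match exactly. The parameter name below is assumed, and both states are expected to have the {"model": {"trunk": {...}}} layout compared above.

# Hypothetical usage, mirroring the fine-tuning tests below.
compare_model_state(
    self,                                       # the unittest.TestCase instance
    pre_train_task.model.get_classy_state(),    # state before fine tuning
    fine_tuning_task.model.get_classy_state(),  # state after fine tuning
    check_heads=False,
    state_changed_params=("blocks.0.block0-0.downsample.1.weight",),
)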
91 changes: 82 additions & 9 deletions test/tasks_fine_tuning_task_test.py
@@ -40,13 +40,47 @@ def forward(self, x, target):


class TestFineTuningTask(unittest.TestCase):
def _compare_model_state(self, state_1, state_2, check_heads=True):
return compare_model_state(self, state_1, state_2, check_heads=check_heads)
def _compare_model_state(
self, state_1, state_2, check_heads=True, state_changed_params=()
):
return compare_model_state(
self,
state_1,
state_2,
check_heads=check_heads,
state_changed_params=state_changed_params,
)

def _compare_state_dict(self, state_1, state_2, check_heads=True):
for k in state_1.keys():
self.assertTrue(torch.allclose(state_1[k].cpu(), state_2[k].cpu()))

def _get_unfreeze_points_to_unfrozen_params(self):
return {
"blocks.0.block0-0._module.downsample.1": (
"blocks.0.block0-0.downsample.1.weight",
"blocks.0.block0-0.downsample.1.bias",
"blocks.0.block0-0.downsample.1.running_mean",
"blocks.0.block0-0.downsample.1.running_var",
"blocks.0.block0-0.downsample.1.num_batches_tracked",
),
"blocks.0.block0-0._module.convolutional_block.6": (
"blocks.0.block0-0.convolutional_block.6.weight",
"blocks.0.block0-0.bn.weight",
"blocks.0.block0-0.bn.bias",
"blocks.0.block0-0.bn.running_mean",
"blocks.0.block0-0.bn.running_var",
"blocks.0.block0-0.bn.num_batches_tracked",
"blocks.0.block0-0.downsample.0.weight",
"blocks.0.block0-0.downsample.1.weight",
"blocks.0.block0-0.downsample.1.bias",
"blocks.0.block0-0.downsample.1.running_mean",
"blocks.0.block0-0.downsample.1.running_var",
"blocks.0.block0-0.downsample.1.num_batches_tracked",
),
"head": (),
}

def _get_fine_tuning_config(
self, head_num_classes=100, pretrained_checkpoint=False
):
@@ -152,19 +186,22 @@ def test_train(self):
trainer = LocalTrainer()
trainer.train(pre_train_task)
checkpoint = get_checkpoint_dict(pre_train_task, {})

unfreeze_points = self._get_unfreeze_points_to_unfrozen_params()
unfreeze_options = list(unfreeze_points) + [None]
for reset_heads, heads_num_classes in [(False, 100), (True, 20)]:
for freeze_trunk in [True, False]:
fine_tuning_config = self._get_fine_tuning_config(
head_num_classes=heads_num_classes
for unfreeze_point in unfreeze_options:
fine_tuning_config = copy.deepcopy(
self._get_fine_tuning_config(head_num_classes=heads_num_classes)
)
# Extra epochs help ensure that unfrozen parameters change value
fine_tuning_config["num_epochs"] = 4
fine_tuning_task = build_task(fine_tuning_config)
fine_tuning_task = (
fine_tuning_task._set_pretrained_checkpoint_dict(
copy.deepcopy(checkpoint)
)
.set_reset_heads(reset_heads)
.set_freeze_trunk(freeze_trunk)
.set_freeze_until(unfreeze_point)
)
# run in test mode to compare the model state
fine_tuning_task.set_test_only(True)
@@ -177,12 +214,14 @@
# run in train mode to check accuracy
fine_tuning_task.set_test_only(False)
trainer.train(fine_tuning_task)
if freeze_trunk:
# if trunk is frozen the states should be the same
if unfreeze_point is not None:
# check that the frozen part of the model is unchanged
# and that the unfrozen part has changed
self._compare_model_state(
pre_train_task.model.get_classy_state(),
fine_tuning_task.model.get_classy_state(),
check_heads=False,
state_changed_params=unfreeze_points[unfreeze_point],
)
else:
# trunk isn't frozen, the states should be different
@@ -196,6 +235,40 @@
accuracy = fine_tuning_task.meters[0].value["top_1"]
self.assertAlmostEqual(accuracy, 1.0)

def test_freeze_trunk_backwards_compatibility(self):
pre_train_config = self._get_pre_train_config(head_num_classes=100)
pre_train_task = build_task(pre_train_config)
trainer = LocalTrainer()
trainer.train(pre_train_task)
checkpoint = get_checkpoint_dict(pre_train_task, {})
for reset_heads, heads_num_classes in [(False, 100), (True, 20)]:
fine_tuning_config = copy.deepcopy(
self._get_fine_tuning_config(head_num_classes=heads_num_classes)
)
fine_tuning_config["freeze_trunk"] = True
with self.assertWarns(DeprecationWarning):
fine_tuning_task = build_task(fine_tuning_config)
fine_tuning_task = fine_tuning_task._set_pretrained_checkpoint_dict(
copy.deepcopy(checkpoint)
).set_reset_heads(reset_heads)
fine_tuning_task.set_test_only(True)
trainer.train(fine_tuning_task)
self._compare_model_state(
pre_train_task.model.get_classy_state(),
fine_tuning_task.model.get_classy_state(),
check_heads=not reset_heads,
)
# run in train mode to check accuracy
fine_tuning_task.set_test_only(False)
trainer.train(fine_tuning_task)
self._compare_model_state(
pre_train_task.model.get_classy_state(),
fine_tuning_task.model.get_classy_state(),
check_heads=False,
)
accuracy = fine_tuning_task.meters[0].value["top_1"]
self.assertAlmostEqual(accuracy, 1.0)

def test_train_parametric_loss(self):
heads_num_classes = 100
pre_train_config = self._get_pre_train_config(
Expand Down