From 9a4564ce5ebfe66a37fd16c6a233fb04ffb0a752 Mon Sep 17 00:00:00 2001
From: Patrick Labatut <60359573+patricklabatut@users.noreply.github.com>
Date: Wed, 27 Sep 2023 17:06:03 +0200
Subject: [PATCH] Rework PyTorch Hub support code (#202)

Rework support code for torch.hub.load() to allow reusing shared functions and eventually exposing more models.
---
 dinov2/hub/__init__.py    |   4 +
 dinov2/hub/backbones.py   |  84 ++++++++++++++++++
 dinov2/hub/classifiers.py | 147 ++++++++++++++++++++++++++++++++
 dinov2/hub/utils.py       |  13 +++
 hubconf.py                | 175 +------------------------------------
 5 files changed, 250 insertions(+), 173 deletions(-)
 create mode 100644 dinov2/hub/__init__.py
 create mode 100644 dinov2/hub/backbones.py
 create mode 100644 dinov2/hub/classifiers.py
 create mode 100644 dinov2/hub/utils.py

diff --git a/dinov2/hub/__init__.py b/dinov2/hub/__init__.py
new file mode 100644
index 000000000..b88da6bf8
--- /dev/null
+++ b/dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py
new file mode 100644
index 000000000..17e00981f
--- /dev/null
+++ b/dinov2/hub/backbones.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+    LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+    *,
+    arch_name: str = "vit_large",
+    img_size: int = 518,
+    patch_size: int = 14,
+    init_values: float = 1.0,
+    ffn_layer: str = "mlp",
+    block_chunks: int = 0,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.LVD142M,
+    **kwargs,
+):
+    from ..models import vision_transformer as vits
+
+    if isinstance(weights, str):
+        try:
+            weights = Weights[weights]
+        except KeyError:
+            raise AssertionError(f"Unsupported weights: {weights}")
+
+    model_name = _make_dinov2_model_name(arch_name, patch_size)
+    vit_kwargs = dict(
+        img_size=img_size,
+        patch_size=patch_size,
+        init_values=init_values,
+        ffn_layer=ffn_layer,
+        block_chunks=block_chunks,
+    )
+    vit_kwargs.update(**kwargs)
+    model = vits.__dict__[arch_name](**vit_kwargs)
+
+    if pretrained:
+        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth"
+        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+        model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(
+        arch_name="vit_giant2", ffn_layer="swiglufused", weights=weights, pretrained=pretrained, **kwargs
+    )
diff --git a/dinov2/hub/classifiers.py b/dinov2/hub/classifiers.py
new file mode 100644
index 000000000..636a732c1
--- /dev/null
+++ b/dinov2/hub/classifiers.py
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from .backbones import _make_dinov2_model
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+    IMAGENET1K = "IMAGENET1K"
+
+
+def _make_dinov2_linear_classification_head(
+    *,
+    model_name: str = "dinov2_vitl14",
+    embed_dim: int = 1024,
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.IMAGENET1K,
+    **kwargs,
+):
+    if layers not in (1, 4):
+        raise AssertionError(f"Unsupported number of layers: {layers}")
+    if isinstance(weights, str):
+        try:
+            weights = Weights[weights]
+        except KeyError:
+            raise AssertionError(f"Unsupported weights: {weights}")
+
+    linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
+
+    if pretrained:
+        layers_str = str(layers) if layers == 4 else ""
+        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_linear{layers_str}_head.pth"
+        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+        linear_head.load_state_dict(state_dict, strict=False)
+
+    return linear_head
+
+
+class _LinearClassifierWrapper(nn.Module):
+    def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
+        super().__init__()
+        self.backbone = backbone
+        self.linear_head = linear_head
+        self.layers = layers
+
+    def forward(self, x):
+        if self.layers == 1:
+            x = self.backbone.forward_features(x)
+            cls_token = x["x_norm_clstoken"]
+            patch_tokens = x["x_norm_patchtokens"]
+            # fmt: off
+            linear_input = torch.cat([
+                cls_token,
+                patch_tokens.mean(dim=1),
+            ], dim=1)
+            # fmt: on
+        elif self.layers == 4:
+            x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
+            # fmt: off
+            linear_input = torch.cat([
+                x[0][1],
+                x[1][1],
+                x[2][1],
+                x[3][1],
+                x[3][0].mean(dim=1),
+            ], dim=1)
+            # fmt: on
+        else:
+            assert False, f"Unsupported number of layers: {self.layers}"
+        return self.linear_head(linear_input)
+
+
+def _make_dinov2_linear_classifier(
+    *,
+    arch_name: str = "vit_large",
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.IMAGENET1K,
+    **kwargs,
+):
+    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
+
+    embed_dim = backbone.embed_dim
+    patch_size = backbone.patch_size
+    model_name = _make_dinov2_model_name(arch_name, patch_size)
+    linear_head = _make_dinov2_linear_classification_head(
+        model_name=model_name,
+        embed_dim=embed_dim,
+        layers=layers,
+        pretrained=pretrained,
+        weights=weights,
+    )
+
+    return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
+
+
+def dinov2_vits14_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitb14_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitl14_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitg14_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
+    )
diff --git a/dinov2/hub/utils.py b/dinov2/hub/utils.py
new file mode 100644
index 000000000..f1829b4ce
--- /dev/null
+++ b/dinov2/hub/utils.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
+    compact_arch_name = arch_name.replace("_", "")[:4]
+    return f"dinov2_{compact_arch_name}{patch_size}"
diff --git a/hubconf.py b/hubconf.py
index fc6a4eee7..b3b448373 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -3,179 +3,8 @@
 # This source code is licensed under the Apache License, Version 2.0
 # found in the LICENSE file in the root directory of this source tree.
 
-import torch
-import torch.nn as nn
+from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14
+from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc
 
 
 dependencies = ["torch"]
-
-
-_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
-
-
-def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
-    compact_arch_name = arch_name.replace("_", "")[:4]
-    return f"dinov2_{compact_arch_name}{patch_size}"
-
-
-def _make_dinov2_model(
-    *,
-    arch_name: str = "vit_large",
-    img_size: int = 518,
-    patch_size: int = 14,
-    init_values: float = 1.0,
-    ffn_layer: str = "mlp",
-    block_chunks: int = 0,
-    pretrained: bool = True,
-    **kwargs,
-):
-    from dinov2.models import vision_transformer as vits
-
-    model_name = _make_dinov2_model_name(arch_name, patch_size)
-    vit_kwargs = dict(
-        img_size=img_size,
-        patch_size=patch_size,
-        init_values=init_values,
-        ffn_layer=ffn_layer,
-        block_chunks=block_chunks,
-    )
-    vit_kwargs.update(**kwargs)
-    model = vits.__dict__[arch_name](**vit_kwargs)
-
-    if pretrained:
-        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth"
-        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
-        model.load_state_dict(state_dict, strict=False)
-
-    return model
-
-
-def dinov2_vits14(*, pretrained: bool = True, **kwargs):
-    """
-    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitb14(*, pretrained: bool = True, **kwargs):
-    """
-    DINOv2 ViT-B/14 model pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitl14(*, pretrained: bool = True, **kwargs):
-    """
-    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitg14(*, pretrained: bool = True, **kwargs):
-    """
-    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, **kwargs)
-
-
-def _make_dinov2_linear_head(
-    *,
-    model_name: str = "dinov2_vitl14",
-    embed_dim: int = 1024,
-    layers: int = 4,
-    pretrained: bool = True,
-    **kwargs,
-):
-    assert layers in (1, 4), f"Unsupported number of layers: {layers}"
-    linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
-
-    if pretrained:
-        layers_str = str(layers) if layers == 4 else ""
-        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_linear{layers_str}_head.pth"
-        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
-        linear_head.load_state_dict(state_dict, strict=False)
-
-    return linear_head
-
-
-class _LinearClassifierWrapper(nn.Module):
-    def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
-        super().__init__()
-        self.backbone = backbone
-        self.linear_head = linear_head
-        self.layers = layers
-
-    def forward(self, x):
-        if self.layers == 1:
-            x = self.backbone.forward_features(x)
-            cls_token = x["x_norm_clstoken"]
-            patch_tokens = x["x_norm_patchtokens"]
-            # fmt: off
-            linear_input = torch.cat([
-                cls_token,
-                patch_tokens.mean(dim=1),
-            ], dim=1)
-            # fmt: on
-        elif self.layers == 4:
-            x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
-            # fmt: off
-            linear_input = torch.cat([
-                x[0][1],
-                x[1][1],
-                x[2][1],
-                x[3][1],
-                x[3][0].mean(dim=1),
-            ], dim=1)
-            # fmt: on
-        else:
-            assert False, f"Unsupported number of layers: {self.layers}"
-        return self.linear_head(linear_input)
-
-
-def _make_dinov2_linear_classifier(
-    *,
-    arch_name: str = "vit_large",
-    layers: int = 4,
-    pretrained: bool = True,
-    **kwargs,
-):
-    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
-
-    embed_dim = backbone.embed_dim
-    patch_size = backbone.patch_size
-    model_name = _make_dinov2_model_name(arch_name, patch_size)
-    linear_head = _make_dinov2_linear_head(
-        model_name=model_name, embed_dim=embed_dim, layers=layers, pretrained=pretrained
-    )
-
-    return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
-
-
-def dinov2_vits14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
-    """
-    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
-    """
-    return _make_dinov2_linear_classifier(arch_name="vit_small", layers=layers, pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitb14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
-    """
-    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
-    """
-    return _make_dinov2_linear_classifier(arch_name="vit_base", layers=layers, pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitl14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
-    """
-    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
-    """
-    return _make_dinov2_linear_classifier(arch_name="vit_large", layers=layers, pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitg14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
-    """
-    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
-    """
-    return _make_dinov2_linear_classifier(
-        arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, **kwargs
-    )
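
Usage stays on the standard torch.hub.load() path. A minimal sketch of loading one of the backbone entrypoints exposed above; this assumes the facebookresearch/dinov2 repo slug and network access to fetch code and weights, and the input tensor is a stand-in for a preprocessed image batch:

import torch

# Load a ViT-S/14 backbone with its LVD-142M pretrained weights; the new
# keyword-only `weights` argument accepts a Weights enum member or its name.
backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14", weights="LVD142M")
backbone.eval()

# Spatial dimensions should be multiples of the patch size (14);
# the default img_size is 518 = 14 * 37.
x = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    features = backbone(x)  # (1, 384) class-token features for ViT-S/14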
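The classifier entrypoints wrap the same backbones with _LinearClassifierWrapper: with layers=4 the head sees the class tokens of the last four blocks plus the averaged patch tokens of the last block, which is why the nn.Linear head has an input width of (1 + layers) * embed_dim. A corresponding sketch, under the same assumptions as above:

# ViT-L/14 backbone plus a linear ImageNet-1k head over the last four blocks.
classifier = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_lc", layers=4)
classifier.eval()

with torch.no_grad():
    logits = classifier(torch.randn(1, 3, 518, 518))  # (1, 1000) ImageNet-1k logits
prediction = logits.argmax(dim=1)  # predicted ImageNet-1k class index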