Rework PyTorch Hub support code (#202)
Rework the support code for torch.hub.load() to allow reusing shared functions and eventually exposing more models.
patricklabatut authored Sep 27, 2023
1 parent 6a62615 commit 9a4564c
Showing 5 changed files with 250 additions and 173 deletions.
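For context, these entrypoints are meant to be consumed through torch.hub.load(). A minimal usage sketch; the facebookresearch/dinov2 repository path and the ViT-S/14 embedding size come from the DINOv2 documentation rather than from this diff:

import torch

# Load a pretrained DINOv2 backbone through the reworked hub entrypoints
# (entrypoint names are defined in dinov2/hub/backbones.py below).
backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
backbone.eval()

# Input side lengths must be multiples of the patch size (14).
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = backbone(x)  # (1, 384) image-level embedding for ViT-S/14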
4 changes: 4 additions & 0 deletions dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
84 changes: 84 additions & 0 deletions dinov2/hub/backbones.py
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
LVD142M = "LVD142M"


def _make_dinov2_model(
*,
arch_name: str = "vit_large",
img_size: int = 518,
patch_size: int = 14,
init_values: float = 1.0,
ffn_layer: str = "mlp",
block_chunks: int = 0,
pretrained: bool = True,
weights: Union[Weights, str] = Weights.LVD142M,
**kwargs,
):
from ..models import vision_transformer as vits

    # Accept either a Weights enum member or its name as a string.
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

model_name = _make_dinov2_model_name(arch_name, patch_size)
vit_kwargs = dict(
img_size=img_size,
patch_size=patch_size,
init_values=init_values,
ffn_layer=ffn_layer,
block_chunks=block_chunks,
)
vit_kwargs.update(**kwargs)
model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        # Checkpoints are fetched from _DINOV2_BASE_URL as <name>/<name>_pretrain.pth.
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)

return model


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
"""
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
"""
return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
"""
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
"""
return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
"""
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
"""
return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
"""
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
"""
return _make_dinov2_model(
arch_name="vit_giant2", ffn_layer="swiglufused", weights=weights, pretrained=pretrained, **kwargs
)
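As a quick sketch of the shared plumbing above (assuming a checkout where dinov2.hub is importable): weights may be passed as the enum or its string name, and pretrained=False skips the checkpoint download:

from dinov2.hub.backbones import Weights, dinov2_vitb14

# Equivalent calls: strings are resolved via Weights[...] inside _make_dinov2_model,
# and an unknown name raises AssertionError.
model_a = dinov2_vitb14(weights=Weights.LVD142M)
model_b = dinov2_vitb14(weights="LVD142M")

# Randomly initialized backbone, no download. With a single LVD142M variant,
# `weights` is only validated here and does not change the checkpoint URL.
random_init = dinov2_vitb14(pretrained=False)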
147 changes: 147 additions & 0 deletions dinov2/hub/classifiers.py
@@ -0,0 +1,147 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch
import torch.nn as nn

from .backbones import _make_dinov2_model
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
IMAGENET1K = "IMAGENET1K"


def _make_dinov2_linear_classification_head(
*,
model_name: str = "dinov2_vitl14",
embed_dim: int = 1024,
layers: int = 4,
pretrained: bool = True,
weights: Union[Weights, str] = Weights.IMAGENET1K,
**kwargs,
):
if layers not in (1, 4):
raise AssertionError(f"Unsupported number of layers: {layers}")
if isinstance(weights, str):
try:
weights = Weights[weights]
except KeyError:
raise AssertionError(f"Unsupported weights: {weights}")

    # The head consumes `layers` class tokens plus one averaged patch token,
    # i.e. (1 + layers) * embed_dim input features, and predicts 1000 ImageNet-1k classes.
    linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)

if pretrained:
        # The 4-layer head checkpoint is <name>_linear4_head.pth; the 1-layer one is <name>_linear_head.pth.
        layers_str = str(layers) if layers == 4 else ""
url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_linear{layers_str}_head.pth"
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
linear_head.load_state_dict(state_dict, strict=False)

return linear_head


class _LinearClassifierWrapper(nn.Module):
def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
super().__init__()
self.backbone = backbone
self.linear_head = linear_head
self.layers = layers

def forward(self, x):
        if self.layers == 1:
            # Single-layer head: concatenate the final class token with the mean patch token.
x = self.backbone.forward_features(x)
cls_token = x["x_norm_clstoken"]
patch_tokens = x["x_norm_patchtokens"]
# fmt: off
linear_input = torch.cat([
cls_token,
patch_tokens.mean(dim=1),
], dim=1)
# fmt: on
        elif self.layers == 4:
            # Four-layer head: class tokens from the last 4 blocks plus the mean patch
            # token of the last block; each x[i] is a (patch_tokens, class_token) pair.
x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
# fmt: off
linear_input = torch.cat([
x[0][1],
x[1][1],
x[2][1],
x[3][1],
x[3][0].mean(dim=1),
], dim=1)
# fmt: on
else:
assert False, f"Unsupported number of layers: {self.layers}"
return self.linear_head(linear_input)


def _make_dinov2_linear_classifier(
*,
arch_name: str = "vit_large",
layers: int = 4,
pretrained: bool = True,
weights: Union[Weights, str] = Weights.IMAGENET1K,
**kwargs,
):
backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

embed_dim = backbone.embed_dim
patch_size = backbone.patch_size
model_name = _make_dinov2_model_name(arch_name, patch_size)
linear_head = _make_dinov2_linear_classification_head(
model_name=model_name,
embed_dim=embed_dim,
layers=layers,
pretrained=pretrained,
weights=weights,
)

return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)


def dinov2_vits14_lc(
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
"""
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
"""
return _make_dinov2_linear_classifier(
arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
)


def dinov2_vitb14_lc(
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
"""
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
"""
return _make_dinov2_linear_classifier(
arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
)


def dinov2_vitl14_lc(
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
"""
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
"""
return _make_dinov2_linear_classifier(
arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
)


def dinov2_vitg14_lc(
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
"""
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
"""
return _make_dinov2_linear_classifier(
arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
)
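Tying the pieces together, a hedged sketch of the classifier path (again assuming an importable checkout; with ViT-S/14, embed_dim is 384, so the 4-layer head consumes (1 + 4) * 384 = 1920 features):

import torch

from dinov2.hub.classifiers import dinov2_vits14_lc

classifier = dinov2_vits14_lc(layers=4)
classifier.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = classifier(x)  # (1, 1000) ImageNet-1k logits
pred = logits.softmax(dim=-1).argmax(dim=-1)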
13 changes: 13 additions & 0 deletions dinov2/hub/utils.py
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch.nn as nn

_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
compact_arch_name = arch_name.replace("_", "")[:4]
return f"dinov2_{compact_arch_name}{patch_size}"