Linting, fix giant models, missing imports, ...
Patrick Labatut committed Sep 30, 2023
1 parent f8611e3 commit 841ba94
Showing 6 changed files with 89 additions and 96 deletions.
31 changes: 19 additions & 12 deletions dinov2/hub/depth/decode_heads.py
@@ -3,12 +3,26 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import copy

import torch
import torch.nn as nn

from .ops import resize


# XXX: (Untested) replacement for mmcv.imdenormalize()
def _imdenormalize(img, mean, std, to_bgr=True):
import numpy as np

mean = mean.reshape(1, -1).astype(np.float64)
std = std.reshape(1, -1).astype(np.float64)
img = (img * std) + mean
    if to_bgr:
        img = img[..., ::-1]  # flip the channel axis (RGB -> BGR); a bare img[::-1] would flip image rows instead
return img


class DepthBaseDecodeHead(nn.Module):
"""Base class for BaseDecodeHead.
@@ -59,12 +73,7 @@ def __init__(

self.in_channels = in_channels
self.channels = channels
-        if isinstance(loss_decode, dict):
-            self.loss_decode = build_loss(loss_decode)
-        elif isinstance(loss_decode, (list, tuple)):
-            self.loss_decode = nn.ModuleList()
-            for loss in loss_decode:
-                self.loss_decode.append(build_loss(loss))
+        self.loss_decode = loss_decode
self.align_corners = align_corners
self.min_depth = min_depth
self.max_depth = max_depth
@@ -177,9 +186,11 @@ def losses(self, depth_pred, depth_gt):
return loss

def log_images(self, img_path, depth_pred, depth_gt, img_meta):
import numpy as np

show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
show_img = show_img.numpy().astype(np.float32)
-        show_img = mmcv.imdenormalize(
+        show_img = _imdenormalize(
show_img,
img_meta["img_norm_cfg"]["mean"],
img_meta["img_norm_cfg"]["std"],
@@ -203,11 +214,7 @@ def log_images(self, img_path, depth_pred, depth_gt, img_meta):
class BNHead(DepthBaseDecodeHead):
"""Just a batchnorm."""

-    def __init__(self,
-                 input_transform="resize_concat",
-                 in_index=(0, 1, 2, 3),
-                 upsample=1,
-                 **kwargs):
+    def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
super().__init__(**kwargs)
self.input_transform = input_transform
self.in_index = in_index
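For reference, a minimal sketch of how the new _imdenormalize helper can be exercised; the mean/std values below are illustrative ImageNet-style statistics, not necessarily the ones DINOv2 uses:

import numpy as np

mean = np.array([123.675, 116.28, 103.53])
std = np.array([58.395, 57.12, 57.375])

normalized = np.random.randn(224, 224, 3)  # H x W x C float image
restored = _imdenormalize(normalized, mean, std, to_bgr=False)
assert restored.shape == (224, 224, 3)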
24 changes: 24 additions & 0 deletions dinov2/hub/depth/encoder_decoder.py
@@ -3,12 +3,34 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .ops import resize


def add_prefix(inputs, prefix):
"""Add prefix for dict.
Args:
inputs (dict): The input dict with str keys.
prefix (str): The prefix to add.
Returns:
dict: The dict with keys updated with ``prefix``.
"""

outputs = dict()
for name, value in inputs.items():
outputs[f"{prefix}.{name}"] = value

return outputs


class DepthEncoderDecoder(nn.Module):
"""Encoder Decoder depther.
@@ -294,6 +316,8 @@ def val_step(self, data_batch, **kwargs):

@staticmethod
def _parse_losses(losses):
import torch.distributed as dist

"""Parse the raw outputs (losses) of the network.
Args:
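A quick sanity check of the new add_prefix helper (the loss names here are made up):

losses = {"loss_depth": 0.42, "aux.loss_depth": 0.13}
print(add_prefix(losses, "decode"))
# -> {'decode.loss_depth': 0.42, 'decode.aux.loss_depth': 0.13}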
9 changes: 3 additions & 6 deletions dinov2/hub/depth/ops.py
@@ -3,15 +3,12 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

import torch.nn.functional as F


-def resize(input,
-           size=None,
-           scale_factor=None,
-           mode="nearest",
-           align_corners=None,
-           warning=False):
+def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
if warning:
if size is not None and align_corners:
input_h, input_w = tuple(int(x) for x in input.shape[2:])
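Assuming resize simply forwards its arguments to F.interpolate (as the mmcv helper it replaces did), usage looks like:

import torch

x = torch.randn(1, 3, 224, 224)
y = resize(x, size=(256, 256), mode="bilinear", align_corners=False)
print(y.shape)  # torch.Size([1, 3, 256, 256])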
110 changes: 41 additions & 69 deletions dinov2/hub/depthers.py
@@ -5,10 +5,11 @@

from enum import Enum
from functools import partial
from typing import Union

import torch

from .backbones import _make_dinov2_model
from .depth import BNHead, DepthEncoderDecoder
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding

@@ -55,17 +56,18 @@ def _make_dinov2_linear_depther(
arch_name: str = "vit_large",
layers: int = 4,
pretrained: bool = True,
-    weights: str = Weights.NYU.value,
+    weights: Union[Weights, str] = Weights.NYU,
**kwargs,
):
if layers not in (1, 4):
raise AssertionError(f"Unsupported number of layers: {layers}")
-    if weights not in (weights.value for weights in Weights):
-        raise AssertionError(f"Unsupported weights: {weights}")
+    if isinstance(weights, str):
+        try:
+            weights = Weights[weights]
+        except KeyError:
+            raise AssertionError(f"Unsupported weights: {weights}")

-    backbone = _make_dinov2_model(arch_name=arch_name,
-                                  pretrained=pretrained,
-                                  **kwargs)
+    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

embed_dim = backbone.embed_dim
patch_size = backbone.patch_size
@@ -76,23 +78,20 @@
layers=layers,
)

-    layer_counts = {
+    layer_count = {
        "vit_small": 12,
        "vit_base": 12,
        "vit_large": 24,
        "vit_giant2": 40,
-    }
-    layer_count = layer_counts[arch_name]
-
-    out_indices = {
-        "vit_small": [2, 5, 8, 11],
-        "vit_base": [2, 5, 8, 11],
-        "vit_large": [4, 11, 17, 23],
-        "vit_giant2": [9, 19, 29, 39],
-    }
+    }[arch_name]

    if layers == 4:
-        out_index = out_indices[arch_name]
+        out_index = {
+            "vit_small": [2, 5, 8, 11],
+            "vit_base": [2, 5, 8, 11],
+            "vit_large": [4, 11, 17, 23],
+            "vit_giant2": [9, 19, 29, 39],
+        }[arch_name]
else:
assert layers == 1
out_index = [layer_count - 1]
@@ -105,15 +104,12 @@
return_class_token=True,
norm=False,
)
-    model.backbone.register_forward_pre_hook(
-        lambda _, x: CenterPadding(patch_size)(x[0])
-    )
+    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))

if pretrained:
layers_str = str(layers) if layers == 4 else ""
-        weights_str = weights.lower()
+        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
-        print(url)
checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
@@ -122,49 +118,25 @@
return model


-def dinov2_vits14_ld(*,
-                     layers: int = 4,
-                     pretrained: bool = True,
-                     weights: str = Weights.NYU.value,
-                     **kwargs):
-    return _make_dinov2_linear_depther(arch_name="vit_small",
-                                       layers=layers,
-                                       pretrained=pretrained,
-                                       weights=weights,
-                                       **kwargs)
-
-
-def dinov2_vitb14_ld(*,
-                     layers: int = 4,
-                     pretrained: bool = True,
-                     weights: str = Weights.NYU.value,
-                     **kwargs):
-    return _make_dinov2_linear_depther(arch_name="vit_base",
-                                       layers=layers,
-                                       pretrained=pretrained,
-                                       weights=weights,
-                                       **kwargs)
-
-
-def dinov2_vitl14_ld(*,
-                     layers: int = 4,
-                     pretrained: bool = True,
-                     weights: str = Weights.NYU.value,
-                     **kwargs):
-    return _make_dinov2_linear_depther(arch_name="vit_large",
-                                       layers=layers,
-                                       pretrained=pretrained,
-                                       weights=weights,
-                                       **kwargs)
-
-
-def dinov2_vitg14_ld(*,
-                     layers: int = 4,
-                     pretrained: bool = True,
-                     weights: str = Weights.NYU.value,
-                     **kwargs):
-    return _make_dinov2_linear_depther(arch_name="vit_giant2",
-                                       layers=layers,
-                                       pretrained=pretrained,
-                                       weights=weights,
-                                       **kwargs)
+def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_linear_depther(
+        arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_linear_depther(
+        arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_linear_depther(
+        arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_linear_depther(
+        arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
+    )
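With this change the weights argument accepts either the enum member or its name, so both calls below should build the same NYU-pretrained depther (a sketch; pretrained=True downloads the checkpoint):

model = dinov2_vitl14_ld(layers=4, pretrained=True, weights=Weights.NYU)
model = dinov2_vitl14_ld(layers=4, pretrained=True, weights="NYU")  # resolved via Weights["NYU"], so the name is case-sensitive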
4 changes: 1 addition & 3 deletions dinov2/hub/utils.py
@@ -33,8 +33,6 @@ def _get_pad(self, size):

@torch.inference_mode()
def forward(self, x):
-        pads = list(itertools.chain.from_iterable(
-            self._get_pad(m) for m in x.shape[:1:-1])
-        )
+        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
output = F.pad(x, pads)
return output
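Assuming _get_pad rounds each spatial dimension up to a multiple of the configured size and splits the padding evenly (hence "center" padding), the module keeps arbitrary input resolutions compatible with a patch-14 ViT:

import torch

pad = CenterPadding(14)
x = torch.randn(1, 3, 225, 300)
y = pad(x)
print(y.shape)  # torch.Size([1, 3, 238, 308]): 225 and 300 rounded up to multiples of 14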
7 changes: 1 addition & 6 deletions hubconf.py
@@ -6,12 +6,7 @@

from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14
from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc
-from dinov2.hub.depthers import (
-    dinov2_vitb14_ld,
-    dinov2_vitg14_ld,
-    dinov2_vitl14_ld,
-    dinov2_vits14_ld,
-)
+from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld


dependencies = ["torch"]
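These entrypoints can then be loaded through torch.hub; a sketch against a local checkout (pretrained=False skips the checkpoint download):

import torch

depther = torch.hub.load(".", "dinov2_vitl14_ld", source="local", layers=4, pretrained=False)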
