Skip to content

Commit

Permalink
add ignore keys if classes differ (#1271)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Jul 27, 2023
1 parent efe7ca0 commit 416c639
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
32 changes: 31 additions & 1 deletion doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def _dbnet(
fpn_layers: List[str],
backbone_submodule: Optional[str] = None,
pretrained_backbone: bool = True,
ignore_keys: Optional[List[str]] = None,
**kwargs: Any,
) -> DBNet:
# Starting with Imagenet pretrained params introduces some NaNs in layer3 & layer4 of resnet50
Expand All @@ -312,7 +313,12 @@ def _dbnet(
model = DBNet(feat_extractor, cfg=default_cfgs[arch], **kwargs)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The requested class_names differ from the class_names of the pretrained model =>
# pass ignore_keys so the listed head-layer weights are not taken from the checkpoint
_ignore_keys = (
ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
)
load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

return model

Expand Down Expand Up @@ -340,6 +346,12 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
resnet34,
["layer1", "layer2", "layer3", "layer4"],
None,
ignore_keys=[
"prob_head.6.weight",
"prob_head.6.bias",
"thresh_head.6.weight",
"thresh_head.6.bias",
],
**kwargs,
)

Expand Down Expand Up @@ -367,6 +379,12 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
resnet50,
["layer1", "layer2", "layer3", "layer4"],
None,
ignore_keys=[
"prob_head.6.weight",
"prob_head.6.bias",
"thresh_head.6.weight",
"thresh_head.6.bias",
],
**kwargs,
)

Expand Down Expand Up @@ -394,6 +412,12 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
mobilenet_v3_large,
["3", "6", "12", "16"],
"features",
ignore_keys=[
"prob_head.6.weight",
"prob_head.6.bias",
"thresh_head.6.weight",
"thresh_head.6.bias",
],
**kwargs,
)

Expand Down Expand Up @@ -422,5 +446,11 @@ def db_resnet50_rotation(pretrained: bool = False, **kwargs: Any) -> DBNet:
resnet50,
["layer1", "layer2", "layer3", "layer4"],
None,
ignore_keys=[
"prob_head.6.weight",
"prob_head.6.bias",
"thresh_head.6.weight",
"thresh_head.6.bias",
],
**kwargs,
)
44 changes: 40 additions & 4 deletions doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def _linknet(
backbone_fn: Callable[[bool], nn.Module],
fpn_layers: List[str],
pretrained_backbone: bool = True,
ignore_keys: Optional[List[str]] = None,
**kwargs: Any,
) -> LinkNet:
pretrained_backbone = pretrained_backbone and not pretrained
Expand All @@ -266,7 +267,12 @@ def _linknet(
model = LinkNet(feat_extractor, cfg=default_cfgs[arch], **kwargs)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The requested class_names differ from the class_names of the pretrained model =>
# pass ignore_keys so the listed classifier-layer weights are not taken from the checkpoint
_ignore_keys = (
ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
)
load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

return model

Expand All @@ -288,7 +294,17 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
text detection architecture
"""

return _linknet("linknet_resnet18", pretrained, resnet18, ["layer1", "layer2", "layer3", "layer4"], **kwargs)
return _linknet(
"linknet_resnet18",
pretrained,
resnet18,
["layer1", "layer2", "layer3", "layer4"],
ignore_keys=[
"classifier.6.weight",
"classifier.6.bias",
],
**kwargs,
)


def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
Expand All @@ -308,7 +324,17 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
text detection architecture
"""

return _linknet("linknet_resnet34", pretrained, resnet34, ["layer1", "layer2", "layer3", "layer4"], **kwargs)
return _linknet(
"linknet_resnet34",
pretrained,
resnet34,
["layer1", "layer2", "layer3", "layer4"],
ignore_keys=[
"classifier.6.weight",
"classifier.6.bias",
],
**kwargs,
)


def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
Expand All @@ -328,4 +354,14 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
text detection architecture
"""

return _linknet("linknet_resnet50", pretrained, resnet50, ["layer1", "layer2", "layer3", "layer4"], **kwargs)
return _linknet(
"linknet_resnet50",
pretrained,
resnet50,
["layer1", "layer2", "layer3", "layer4"],
ignore_keys=[
"classifier.6.weight",
"classifier.6.bias",
],
**kwargs,
)

0 comments on commit 416c639

Please sign in to comment.