diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py
index 597d4b04d7..919de2044a 100644
--- a/doctr/models/detection/differentiable_binarization/pytorch.py
+++ b/doctr/models/detection/differentiable_binarization/pytorch.py
@@ -289,6 +289,7 @@ def _dbnet(
     fpn_layers: List[str],
     backbone_submodule: Optional[str] = None,
     pretrained_backbone: bool = True,
+    ignore_keys: Optional[List[str]] = None,
     **kwargs: Any,
 ) -> DBNet:
     # Starting with Imagenet pretrained params introduces some NaNs in layer3 & layer4 of resnet50
@@ -312,7 +313,12 @@ def _dbnet(
     model = DBNet(feat_extractor, cfg=default_cfgs[arch], **kwargs)
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of class_names is not the same as the number of classes in the pretrained model =>
+        # remove the layer weights
+        _ignore_keys = (
+            ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
+        )
+        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -340,6 +346,12 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
         resnet34,
         ["layer1", "layer2", "layer3", "layer4"],
         None,
+        ignore_keys=[
+            "prob_head.6.weight",
+            "prob_head.6.bias",
+            "thresh_head.6.weight",
+            "thresh_head.6.bias",
+        ],
         **kwargs,
     )
 
@@ -367,6 +379,12 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
         resnet50,
         ["layer1", "layer2", "layer3", "layer4"],
         None,
+        ignore_keys=[
+            "prob_head.6.weight",
+            "prob_head.6.bias",
+            "thresh_head.6.weight",
+            "thresh_head.6.bias",
+        ],
         **kwargs,
     )
 
@@ -394,6 +412,12 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
         mobilenet_v3_large,
         ["3", "6", "12", "16"],
         "features",
+        ignore_keys=[
+            "prob_head.6.weight",
+            "prob_head.6.bias",
+            "thresh_head.6.weight",
+            "thresh_head.6.bias",
+        ],
         **kwargs,
     )
 
@@ -422,5 +446,11 @@ def db_resnet50_rotation(pretrained: bool = False, **kwargs: Any) -> DBNet:
         resnet50,
         ["layer1", "layer2", "layer3", "layer4"],
         None,
+        ignore_keys=[
+            "prob_head.6.weight",
+            "prob_head.6.bias",
+            "thresh_head.6.weight",
+            "thresh_head.6.bias",
+        ],
         **kwargs,
     )
diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py
index c3edcfbf70..f40c4018f6 100644
--- a/doctr/models/detection/linknet/pytorch.py
+++ b/doctr/models/detection/linknet/pytorch.py
@@ -247,6 +247,7 @@ def _linknet(
     backbone_fn: Callable[[bool], nn.Module],
     fpn_layers: List[str],
     pretrained_backbone: bool = True,
+    ignore_keys: Optional[List[str]] = None,
     **kwargs: Any,
 ) -> LinkNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -266,7 +267,12 @@ def _linknet(
     model = LinkNet(feat_extractor, cfg=default_cfgs[arch], **kwargs)
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of class_names is not the same as the number of classes in the pretrained model =>
+        # remove the layer weights
+        _ignore_keys = (
+            ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
+        )
+        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -288,7 +294,17 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
         text detection architecture
     """
 
-    return _linknet("linknet_resnet18", pretrained, resnet18, ["layer1", "layer2", "layer3", "layer4"], **kwargs)
+    return _linknet(
+        "linknet_resnet18",
+        pretrained,
+        resnet18,
+        ["layer1", "layer2", "layer3", "layer4"],
+        ignore_keys=[
+            "classifier.6.weight",
+            "classifier.6.bias",
+        ],
+        **kwargs,
+    )
 
 
 def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
@@ -308,7 +324,17 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
         text detection architecture
     """
 
-    return _linknet("linknet_resnet34", pretrained, resnet34, ["layer1", "layer2", "layer3", "layer4"], **kwargs)
+    return _linknet(
+        "linknet_resnet34",
+        pretrained,
+        resnet34,
+        ["layer1", "layer2", "layer3", "layer4"],
+        ignore_keys=[
+            "classifier.6.weight",
+            "classifier.6.bias",
+        ],
+        **kwargs,
+    )
 
 
 def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
@@ -328,4 +354,14 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
         text detection architecture
     """
 
-    return _linknet("linknet_resnet50", pretrained, resnet50, ["layer1", "layer2", "layer3", "layer4"], **kwargs)
+    return _linknet(
+        "linknet_resnet50",
+        pretrained,
+        resnet50,
+        ["layer1", "layer2", "layer3", "layer4"],
+        ignore_keys=[
+            "classifier.6.weight",
+            "classifier.6.bias",
+        ],
+        **kwargs,
+    )