Skip to content

Commit

Permalink
Fix typehints, improve error raising for encoders
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Dec 13, 2020
1 parent aa705da commit bdc58db
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 5 deletions.
14 changes: 12 additions & 2 deletions segmentation_models_pytorch/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,23 @@


def get_encoder(name, in_channels=3, depth=5, weights=None):
Encoder = encoders[name]["encoder"]

try:
Encoder = encoders[name]["encoder"]
except KeyError:
raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))

params = encoders[name]["params"]
params.update(depth=depth)
encoder = Encoder(**params)

if weights is not None:
settings = encoders[name]["pretrained_settings"][weights]
try:
settings = encoders[name]["pretrained_settings"][weights]
except KeyError:
raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Avaliable options are: {}".format(
weights, name, list(encoders[name]["pretrained_settings"].keys()),
))
encoder.load_state_dict(model_zoo.load_url(settings["url"]))

encoder.set_in_channels(in_channels)
Expand Down
2 changes: 1 addition & 1 deletion segmentation_models_pytorch/pan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class PAN(SegmentationModel):
def __init__(
self,
encoder_name: str = "resnet34",
encoder_weights: str = "imagenet",
encoder_weights: Optional[str] = "imagenet",
encoder_dilation: bool = True,
decoder_channels: int = 32,
in_channels: int = 3,
Expand Down
2 changes: 1 addition & 1 deletion segmentation_models_pytorch/unet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: str = "imagenet",
encoder_weights: Optional[str] = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_channels: List[int] = (256, 128, 64, 32, 16),
decoder_attention_type: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion segmentation_models_pytorch/unetplusplus/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: str = "imagenet",
encoder_weights: Optional[str] = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_channels: List[int] = (256, 128, 64, 32, 16),
decoder_attention_type: Optional[str] = None,
Expand Down

0 comments on commit bdc58db

Please sign in to comment.