diff --git a/gluefactory/models/extractors/superpoint_open.py b/gluefactory/models/extractors/superpoint_open.py index 434e0a1d..a9248a66 100644 --- a/gluefactory/models/extractors/superpoint_open.py +++ b/gluefactory/models/extractors/superpoint_open.py @@ -7,6 +7,7 @@ """ from collections import OrderedDict +from pathlib import Path from types import SimpleNamespace import torch @@ -85,6 +86,7 @@ class SuperPoint(BaseModel): "descriptor_dim": 256, "channels": [64, 64, 128, 128, 256], "dense_outputs": None, + "weights": None, # local path of pretrained weights } checkpoint_url = "https://github.com/rpautrat/SuperPoint/raw/master/weights/superpoint_v6_from_tf.pth" # noqa: E501 @@ -112,7 +114,10 @@ def _init(self, conf): VGGBlock(c, self.conf.descriptor_dim, 1, relu=False), ) - state_dict = torch.hub.load_state_dict_from_url(self.checkpoint_url) + if conf.weights is not None and Path(conf.weights).exists(): + state_dict = torch.load(conf.weights, map_location="cpu") + else: + state_dict = torch.hub.load_state_dict_from_url(self.checkpoint_url) self.load_state_dict(state_dict) def _forward(self, data):