diff --git a/README.md b/README.md
index 19a6b2b..9b6b5f4 100644
--- a/README.md
+++ b/README.md
@@ -26,8 +26,8 @@
 This repository hosts the inference code of LightGlue, a lightweight feature matcher with high accuracy and blazing fast inference. It takes as input a set of keypoints and descriptors for each image and returns the indices of corresponding points. The architecture is based on adaptive pruning techniques, in both network width and depth - [check out the paper for more details](https://arxiv.org/pdf/2306.13643.pdf).
-We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629) and [DISK](https://arxiv.org/abs/2006.13566) local features.
-The training and evaluation code will be released in July in a separate repo. To be notified, subscribe to [issue #6](https://github.com/cvg/LightGlue/issues/6).
+We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629), [DISK](https://arxiv.org/abs/2006.13566), [ALIKED](https://arxiv.org/abs/2304.03608) and [SIFT](https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf) local features.
+The training and evaluation code can be found in our training library [glue-factory](https://github.com/cvg/glue-factory/).
 
 ## Installation and demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb)
 
@@ -43,14 +43,14 @@ We provide a [demo notebook](demo.ipynb) which shows how to perform feature extraction and matching on an image pair.
 
 Here is a minimal script to match two images:
 
 ```python
-from lightglue import LightGlue, SuperPoint, DISK
+from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED
 from lightglue.utils import load_image, rbd
 
 # SuperPoint+LightGlue
 extractor = SuperPoint(max_num_keypoints=2048).eval().cuda()  # load the extractor
 matcher = LightGlue(features='superpoint').eval().cuda()  # load the matcher
 
-# or DISK+LightGlue
+# or DISK+LightGlue, ALIKED+LightGlue or SIFT+LightGlue
 extractor = DISK(max_num_keypoints=2048).eval().cuda()  # load the extractor
 matcher = LightGlue(features='disk').eval().cuda()  # load the matcher
@@ -177,4 +177,4 @@ If you use any ideas from the paper or code from this repo, please consider citing:
 
 ## License
 
-The pre-trained weights of LightGlue and the code provided in this repository are released under the [Apache-2.0 license](./LICENSE). [DISK](https://github.com/cvlab-epfl/disk) follows this license as well but SuperPoint follows [a different, restrictive license](https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/LICENSE) (this includes its pre-trained weights and its [inference file](./lightglue/superpoint.py)).
+The pre-trained weights of LightGlue and the code provided in this repository are released under the [Apache-2.0 license](./LICENSE). [DISK](https://github.com/cvlab-epfl/disk) follows this license as well but SuperPoint follows [a different, restrictive license](https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/LICENSE) (this includes its pre-trained weights and its [inference file](./lightglue/superpoint.py)). [ALIKED](https://github.com/Shiaoming/ALIKED) was published under a BSD-3-Clause license.
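The README hunk above only shows how the extractors and the matcher are constructed. For context, a minimal sketch of the full matching flow these objects support, using the `Extractor.extract` helper and `rbd` utility added by this patch; the image paths are placeholders and a CUDA device is assumed:

```python
# A sketch of end-to-end matching, mirroring the README's SuperPoint+LightGlue setup.
from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image, rbd

extractor = SuperPoint(max_num_keypoints=2048).eval().cuda()
matcher = LightGlue(features="superpoint").eval().cuda()

image0 = load_image("path/to/image0.jpg").cuda()  # CxHxW float tensor in [0, 1]
image1 = load_image("path/to/image1.jpg").cuda()

feats0 = extractor.extract(image0)  # Extractor.extract handles resizing and batching
feats1 = extractor.extract(image1)
matches01 = matcher({"image0": feats0, "image1": feats1})
feats0, feats1, matches01 = [rbd(x) for x in (feats0, feats1, matches01)]  # drop batch dim

matches = matches01["matches"]                 # (K, 2) indices into the two keypoint sets
points0 = feats0["keypoints"][matches[:, 0]]   # matched keypoints in image 0, (K, 2)
points1 = feats1["keypoints"][matches[:, 1]]   # matched keypoints in image 1, (K, 2)
```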
diff --git a/lightglue/__init__.py b/lightglue/__init__.py index 7a20d06..42719c9 100644 --- a/lightglue/__init__.py +++ b/lightglue/__init__.py @@ -1,4 +1,6 @@ +from .aliked import ALIKED # noqa from .disk import DISK # noqa from .lightglue import LightGlue # noqa +from .sift import SIFT # noqa from .superpoint import SuperPoint # noqa from .utils import match_pair # noqa diff --git a/lightglue/aliked.py b/lightglue/aliked.py new file mode 100644 index 0000000..1161e1f --- /dev/null +++ b/lightglue/aliked.py @@ -0,0 +1,758 @@ +# BSD 3-Clause License + +# Copyright (c) 2022, Zhao Xiaoming +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Authors: +# Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. 
Chen, Qingsong Xu, and Zhengguo Li +# Code from https://github.com/Shiaoming/ALIKED + +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +import torchvision +from kornia.color import grayscale_to_rgb +from torch import nn +from torch.nn.modules.utils import _pair +from torchvision.models import resnet + +from .utils import Extractor + + +def get_patches( + tensor: torch.Tensor, required_corners: torch.Tensor, ps: int +) -> torch.Tensor: + c, h, w = tensor.shape + corner = (required_corners - ps / 2 + 1).long() + corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps) + corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps) + offset = torch.arange(0, ps) + + kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {} + x, y = torch.meshgrid(offset, offset, **kw) + patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2) + patches = patches.to(corner) + corner[None, None] + pts = patches.reshape(-1, 2) + sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]] + sampled = sampled.reshape(ps, ps, -1, c) + assert sampled.shape[:3] == patches.shape[:3] + return sampled.permute(2, 3, 0, 1) + + +def simple_nms(scores: torch.Tensor, nms_radius: int): + """Fast Non-maximum suppression to remove nearby points""" + + zeros = torch.zeros_like(scores) + max_mask = scores == torch.nn.functional.max_pool2d( + scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius + ) + + for _ in range(2): + supp_mask = ( + torch.nn.functional.max_pool2d( + max_mask.float(), + kernel_size=nms_radius * 2 + 1, + stride=1, + padding=nms_radius, + ) + > 0 + ) + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == torch.nn.functional.max_pool2d( + supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius + ) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +class DKD(nn.Module): + def __init__( + self, + radius: int = 2, + top_k: int = 0, + scores_th: float = 0.2, + n_limit: int = 20000, + ): + """ + Args: + radius: soft detection radius, kernel size is (2 * radius + 1) + top_k: top_k > 0: return top k keypoints + scores_th: top_k <= 0 threshold mode: + scores_th > 0: return keypoints with scores>scores_th + else: return keypoints with scores > scores.mean() + n_limit: max number of keypoint in threshold mode + """ + super().__init__() + self.radius = radius + self.top_k = top_k + self.scores_th = scores_th + self.n_limit = n_limit + self.kernel_size = 2 * self.radius + 1 + self.temperature = 0.1 # tuned temperature + self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius) + # local xy grid + x = torch.linspace(-self.radius, self.radius, self.kernel_size) + # (kernel_size*kernel_size) x 2 : (w,h) + kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {} + self.hw_grid = ( + torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]] + ) + + def forward( + self, + scores_map: torch.Tensor, + sub_pixel: bool = True, + image_size: Optional[torch.Tensor] = None, + ): + """ + :param scores_map: Bx1xHxW + :param descriptor_map: BxCxHxW + :param sub_pixel: whether to use sub-pixel keypoint detection + :return: kpts: list[Nx2,...]; kptscores: list[N,....] 
normalised position: -1~1 + """ + b, c, h, w = scores_map.shape + scores_nograd = scores_map.detach() + nms_scores = simple_nms(scores_nograd, self.radius) + + # remove border + nms_scores[:, :, : self.radius, :] = 0 + nms_scores[:, :, :, : self.radius] = 0 + if image_size is not None: + for i in range(scores_map.shape[0]): + w, h = image_size[i].long() + nms_scores[i, :, h.item() - self.radius :, :] = 0 + nms_scores[i, :, :, w.item() - self.radius :] = 0 + else: + nms_scores[:, :, -self.radius :, :] = 0 + nms_scores[:, :, :, -self.radius :] = 0 + + # detect keypoints without grad + if self.top_k > 0: + topk = torch.topk(nms_scores.view(b, -1), self.top_k) + indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k + else: + if self.scores_th > 0: + masks = nms_scores > self.scores_th + if masks.sum() == 0: + th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th + masks = nms_scores > th.reshape(b, 1, 1, 1) + else: + th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th + masks = nms_scores > th.reshape(b, 1, 1, 1) + masks = masks.reshape(b, -1) + + indices_keypoints = [] # list, B x (any size) + scores_view = scores_nograd.reshape(b, -1) + for mask, scores in zip(masks, scores_view): + indices = mask.nonzero()[:, 0] + if len(indices) > self.n_limit: + kpts_sc = scores[indices] + sort_idx = kpts_sc.sort(descending=True)[1] + sel_idx = sort_idx[: self.n_limit] + indices = indices[sel_idx] + indices_keypoints.append(indices) + + wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device) + + keypoints = [] + scoredispersitys = [] + kptscores = [] + if sub_pixel: + # detect soft keypoints with grad backpropagation + patches = self.unfold(scores_map) # B x (kernel**2) x (H*W) + self.hw_grid = self.hw_grid.to(scores_map) # to device + for b_idx in range(b): + patch = patches[b_idx].t() # (H*W) x (kernel**2) + indices_kpt = indices_keypoints[ + b_idx + ] # one dimension vector, say its size is M + patch_scores = patch[indices_kpt] # M x (kernel**2) + keypoints_xy_nms = torch.stack( + [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")], + dim=1, + ) # Mx2 + + # max is detached to prevent undesired backprop loops in the graph + max_v = patch_scores.max(dim=1).values.detach()[:, None] + x_exp = ( + (patch_scores - max_v) / self.temperature + ).exp() # M * (kernel**2), in [0, 1] + + # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} } + xy_residual = ( + x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] + ) # Soft-argmax, Mx2 + + hw_grid_dist2 = ( + torch.norm( + (self.hw_grid[None, :, :] - xy_residual[:, None, :]) + / self.radius, + dim=-1, + ) + ** 2 + ) + scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1) + + # compute result keypoints + keypoints_xy = keypoints_xy_nms + xy_residual + keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1) + + kptscore = torch.nn.functional.grid_sample( + scores_map[b_idx].unsqueeze(0), + keypoints_xy.view(1, 1, -1, 2), + mode="bilinear", + align_corners=True, + )[ + 0, 0, 0, : + ] # CxN + + keypoints.append(keypoints_xy) + scoredispersitys.append(scoredispersity) + kptscores.append(kptscore) + else: + for b_idx in range(b): + indices_kpt = indices_keypoints[ + b_idx + ] # one dimension vector, say its size is M + # To avoid warning: UserWarning: __floordiv__ is deprecated + keypoints_xy_nms = torch.stack( + [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")], + dim=1, + ) # Mx2 + keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1) + kptscore 
= torch.nn.functional.grid_sample( + scores_map[b_idx].unsqueeze(0), + keypoints_xy.view(1, 1, -1, 2), + mode="bilinear", + align_corners=True, + )[ + 0, 0, 0, : + ] # CxN + keypoints.append(keypoints_xy) + scoredispersitys.append(kptscore) # for jit.script compatability + kptscores.append(kptscore) + + return keypoints, scoredispersitys, kptscores + + +class InputPadder(object): + """Pads images such that dimensions are divisible by 8""" + + def __init__(self, h: int, w: int, divis_by: int = 8): + self.ht = h + self.wd = w + pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by + pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by + self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] + + def pad(self, x: torch.Tensor): + assert x.ndim == 4 + return F.pad(x, self._pad, mode="replicate") + + def unpad(self, x: torch.Tensor): + assert x.ndim == 4 + ht = x.shape[-2] + wd = x.shape[-1] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + + +class DeformableConv2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + mask=False, + ): + super(DeformableConv2d, self).__init__() + + self.padding = padding + self.mask = mask + + self.channel_num = ( + 3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size + ) + self.offset_conv = nn.Conv2d( + in_channels, + self.channel_num, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + bias=True, + ) + + self.regular_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + bias=bias, + ) + + def forward(self, x): + h, w = x.shape[2:] + max_offset = max(h, w) / 4.0 + + out = self.offset_conv(x) + if self.mask: + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + else: + offset = out + mask = None + offset = offset.clamp(-max_offset, max_offset) + x = torchvision.ops.deform_conv2d( + input=x, + offset=offset, + weight=self.regular_conv.weight, + bias=self.regular_conv.bias, + padding=self.padding, + mask=mask, + ) + return x + + +def get_conv( + inplanes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False, + conv_type="conv", + mask=False, +): + if conv_type == "conv": + conv = nn.Conv2d( + inplanes, + planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + elif conv_type == "dcn": + conv = DeformableConv2d( + inplanes, + planes, + kernel_size=kernel_size, + stride=stride, + padding=_pair(padding), + bias=bias, + mask=mask, + ) + else: + raise TypeError + return conv + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + conv_type: str = "conv", + mask: bool = False, + ): + super().__init__() + if gate is None: + self.gate = nn.ReLU(inplace=True) + else: + self.gate = gate + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = get_conv( + in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask + ) + self.bn1 = norm_layer(out_channels) + self.conv2 = get_conv( + out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask + ) + self.bn2 = norm_layer(out_channels) + + def forward(self, x): + x = self.gate(self.bn1(self.conv1(x))) # B x 
in_channels x H x W + x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W + return x + + +# modified based on torchvision\models\resnet.py#27->BasicBlock +class ResBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + conv_type: str = "conv", + mask: bool = False, + ) -> None: + super(ResBlock, self).__init__() + if gate is None: + self.gate = nn.ReLU(inplace=True) + else: + self.gate = gate + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("ResBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in ResBlock") + # Both self.conv1 and self.downsample layers + # downsample the input when stride != 1 + self.conv1 = get_conv( + inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask + ) + self.bn1 = norm_layer(planes) + self.conv2 = get_conv( + planes, planes, kernel_size=3, conv_type=conv_type, mask=mask + ) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.gate(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.gate(out) + + return out + + +class SDDH(nn.Module): + def __init__( + self, + dims: int, + kernel_size: int = 3, + n_pos: int = 8, + gate=nn.ReLU(), + conv2D=False, + mask=False, + ): + super(SDDH, self).__init__() + self.kernel_size = kernel_size + self.n_pos = n_pos + self.conv2D = conv2D + self.mask = mask + + self.get_patches_func = get_patches + + # estimate offsets + self.channel_num = 3 * n_pos if mask else 2 * n_pos + self.offset_conv = nn.Sequential( + nn.Conv2d( + dims, + self.channel_num, + kernel_size=kernel_size, + stride=1, + padding=0, + bias=True, + ), + gate, + nn.Conv2d( + self.channel_num, + self.channel_num, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ), + ) + + # sampled feature conv + self.sf_conv = nn.Conv2d( + dims, dims, kernel_size=1, stride=1, padding=0, bias=False + ) + + # convM + if not conv2D: + # deformable desc weights + agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims)) + self.register_parameter("agg_weights", agg_weights) + else: + self.convM = nn.Conv2d( + dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False + ) + + def forward(self, x, keypoints): + # x: [B,C,H,W] + # keypoints: list, [[N_kpts,2], ...] 
(w,h) + b, c, h, w = x.shape + wh = torch.tensor([[w - 1, h - 1]], device=x.device) + max_offset = max(h, w) / 4.0 + + offsets = [] + descriptors = [] + # get offsets for each keypoint + for ib in range(b): + xi, kptsi = x[ib], keypoints[ib] + kptsi_wh = (kptsi / 2 + 0.5) * wh + N_kpts = len(kptsi) + + if self.kernel_size > 1: + patch = self.get_patches_func( + xi, kptsi_wh.long(), self.kernel_size + ) # [N_kpts, C, K, K] + else: + kptsi_wh_long = kptsi_wh.long() + patch = ( + xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]] + .permute(1, 0) + .reshape(N_kpts, c, 1, 1) + ) + + offset = self.offset_conv(patch).clamp( + -max_offset, max_offset + ) # [N_kpts, 2*n_pos, 1, 1] + if self.mask: + offset = ( + offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1) + ) # [N_kpts, n_pos, 3] + offset = offset[:, :, :-1] # [N_kpts, n_pos, 2] + mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos] + else: + offset = ( + offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1) + ) # [N_kpts, n_pos, 2] + offsets.append(offset) # for visualization + + # get sample positions + pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2] + pos = 2.0 * pos / wh[None] - 1 + pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2) + + # sample features + features = F.grid_sample( + xi.unsqueeze(0), pos, mode="bilinear", align_corners=True + ) # [1,C,(N_kpts*n_pos),1] + features = features.reshape(c, N_kpts, self.n_pos, 1).permute( + 1, 0, 2, 3 + ) # [N_kpts, C, n_pos, 1] + if self.mask: + features = torch.einsum("ncpo,np->ncpo", features, mask_weight) + + features = torch.selu_(self.sf_conv(features)).squeeze( + -1 + ) # [N_kpts, C, n_pos] + # convM + if not self.conv2D: + descs = torch.einsum( + "ncp,pcd->nd", features, self.agg_weights + ) # [N_kpts, C] + else: + features = features.reshape(N_kpts, -1)[ + :, :, None, None + ] # [N_kpts, C*n_pos, 1, 1] + descs = self.convM(features).squeeze() # [N_kpts, C] + + # normalize + descs = F.normalize(descs, p=2.0, dim=1) + descriptors.append(descs) + + return descriptors, offsets + + +class ALIKED(Extractor): + default_conf = { + "model_name": "aliked-n16", + "max_num_keypoints": -1, + "detection_threshold": 0.2, + "nms_radius": 2, + } + + checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth" + + n_limit_max = 20000 + + # c1, c2, c3, c4, dim, K, M + cfgs = { + "aliked-t16": [8, 16, 32, 64, 64, 3, 16], + "aliked-n16": [16, 32, 64, 128, 128, 3, 16], + "aliked-n16rot": [16, 32, 64, 128, 128, 3, 16], + "aliked-n32": [16, 32, 64, 128, 128, 3, 32], + } + preprocess_conf = { + "resize": 1024, + } + + required_data_keys = ["image"] + + def __init__(self, **conf): + super().__init__(**conf) # Update with default configuration. 
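The `cfgs` table above fixes, per released checkpoint, the encoder channel widths (c1–c4), the descriptor dimension `dim`, the SDDH kernel size `K`, and the number of deformable sample positions `M`. As a hedged usage sketch of that configuration surface (model names come from the `cfgs` keys above; the input tensor is a placeholder):

```python
import torch
from lightglue import ALIKED

# "aliked-n16" is the default_conf model; "aliked-t16" is the tiny variant
# (see the cfgs table above: dim = 128 for aliked-n16, dim = 64 for aliked-t16).
extractor = ALIKED(model_name="aliked-n16", max_num_keypoints=1024).eval()

image = torch.rand(3, 480, 640)   # placeholder RGB image, values in [0, 1]
feats = extractor.extract(image)  # inherited from Extractor in lightglue/utils.py
# For aliked-n16: keypoints are (1, N, 2) and descriptors are (1, N, 128).
print(feats["keypoints"].shape, feats["descriptors"].shape)
```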
+ conf = self.conf + c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name] + conv_types = ["conv", "conv", "dcn", "dcn"] + conv2D = False + mask = False + + # build model + self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) + self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4) + self.norm = nn.BatchNorm2d + self.gate = nn.SELU(inplace=True) + self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0]) + self.block2 = self.get_resblock(c1, c2, conv_types[1], mask) + self.block3 = self.get_resblock(c2, c3, conv_types[2], mask) + self.block4 = self.get_resblock(c3, c4, conv_types[3], mask) + + self.conv1 = resnet.conv1x1(c1, dim // 4) + self.conv2 = resnet.conv1x1(c2, dim // 4) + self.conv3 = resnet.conv1x1(c3, dim // 4) + self.conv4 = resnet.conv1x1(dim, dim // 4) + self.upsample2 = nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=True + ) + self.upsample4 = nn.Upsample( + scale_factor=4, mode="bilinear", align_corners=True + ) + self.upsample8 = nn.Upsample( + scale_factor=8, mode="bilinear", align_corners=True + ) + self.upsample32 = nn.Upsample( + scale_factor=32, mode="bilinear", align_corners=True + ) + self.score_head = nn.Sequential( + resnet.conv1x1(dim, 8), + self.gate, + resnet.conv3x3(8, 4), + self.gate, + resnet.conv3x3(4, 4), + self.gate, + resnet.conv3x3(4, 1), + ) + self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask) + self.dkd = DKD( + radius=conf.nms_radius, + top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints, + scores_th=conf.detection_threshold, + n_limit=conf.max_num_keypoints + if conf.max_num_keypoints > 0 + else self.n_limit_max, + ) + + state_dict = torch.hub.load_state_dict_from_url( + self.checkpoint_url.format(conf.model_name), map_location="cpu" + ) + self.load_state_dict(state_dict, strict=True) + + def get_resblock(self, c_in, c_out, conv_type, mask): + return ResBlock( + c_in, + c_out, + 1, + nn.Conv2d(c_in, c_out, 1), + gate=self.gate, + norm_layer=self.norm, + conv_type=conv_type, + mask=mask, + ) + + def extract_dense_map(self, image): + # Pads images such that dimensions are divisible by + div_by = 2**5 + padder = InputPadder(image.shape[-2], image.shape[-1], div_by) + image = padder.pad(image) + + # ================================== feature encoder + x1 = self.block1(image) # B x c1 x H x W + x2 = self.pool2(x1) + x2 = self.block2(x2) # B x c2 x H/2 x W/2 + x3 = self.pool4(x2) + x3 = self.block3(x3) # B x c3 x H/8 x W/8 + x4 = self.pool4(x3) + x4 = self.block4(x4) # B x dim x H/32 x W/32 + # ================================== feature aggregation + x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W + x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2 + x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8 + x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32 + x2_up = self.upsample2(x2) # B x dim//4 x H x W + x3_up = self.upsample8(x3) # B x dim//4 x H x W + x4_up = self.upsample32(x4) # B x dim//4 x H x W + x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1) + # ================================== score head + score_map = torch.sigmoid(self.score_head(x1234)) + feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1) + + # Unpads images + feature_map = padder.unpad(feature_map) + score_map = padder.unpad(score_map) + + return feature_map, score_map + + def forward(self, data: dict) -> dict: + image = data["image"] + if image.shape[1] == 1: + image = grayscale_to_rgb(image) + feature_map, score_map = self.extract_dense_map(image) + keypoints, kptscores, 
scoredispersitys = self.dkd( + score_map, image_size=data.get("image_size") + ) + descriptors, offsets = self.desc_head(feature_map, keypoints) + + _, _, h, w = image.shape + wh = torch.tensor([w - 1, h - 1], device=image.device) + # no padding required + # we can set detection_threshold=-1 and conf.max_num_keypoints > 0 + return { + "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2 + "descriptors": torch.stack(descriptors), # B x N x D + "keypoint_scores": torch.stack(kptscores), # B x N + } diff --git a/lightglue/disk.py b/lightglue/disk.py index 9632578..8cb2195 100644 --- a/lightglue/disk.py +++ b/lightglue/disk.py @@ -1,13 +1,10 @@ -from types import SimpleNamespace - import kornia import torch -import torch.nn as nn -from .utils import ImagePreprocessor +from .utils import Extractor -class DISK(nn.Module): +class DISK(Extractor): default_conf = { "weights": "depth", "max_num_keypoints": None, @@ -18,7 +15,6 @@ class DISK(nn.Module): } preprocess_conf = { - **ImagePreprocessor.default_conf, "resize": 1024, "grayscale": False, } @@ -26,9 +22,7 @@ class DISK(nn.Module): required_data_keys = ["image"] def __init__(self, **conf) -> None: - super().__init__() - self.conf = {**self.default_conf, **conf} - self.conf = SimpleNamespace(**self.conf) + super().__init__(**conf) # Update with default configuration. self.model = kornia.feature.DISK.from_pretrained(self.conf.weights) def forward(self, data: dict) -> dict: @@ -36,6 +30,8 @@ def forward(self, data: dict) -> dict: for key in self.required_data_keys: assert key in data, f"Missing key {key} in data" image = data["image"] + if image.shape[1] == 1: + image = kornia.color.grayscale_to_rgb(image) features = self.model( image, n=self.conf.max_num_keypoints, @@ -57,15 +53,3 @@ def forward(self, data: dict) -> dict: "keypoint_scores": scores.to(image).contiguous(), "descriptors": descriptors.to(image).contiguous(), } - - def extract(self, img: torch.Tensor, **conf) -> dict: - """Perform extraction with online resizing""" - if img.dim() == 3: - img = img[None] # add batch dim - assert img.dim() == 4 and img.shape[0] == 1 - shape = img.shape[-2:][::-1] - img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img) - feats = self.forward({"image": img}) - feats["image_size"] = torch.tensor(shape)[None].to(img).float() - feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5 - return feats diff --git a/lightglue/lightglue.py b/lightglue/lightglue.py index 9ddcda4..fcbb7ee 100644 --- a/lightglue/lightglue.py +++ b/lightglue/lightglue.py @@ -314,6 +314,7 @@ class LightGlue(nn.Module): "name": "lightglue", # just for interfacing "input_dim": 256, # input descriptor dimension (autoselected from weights) "descriptor_dim": 256, + "add_scale_ori": False, "n_layers": 9, "num_heads": 4, "flash": True, # enable FlashAttention if available. 
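The new `add_scale_ori` option (enabled by the SIFT entry in the next hunk) widens the positional-encoding input from 2 to 4 channels, so keypoint scale and orientation are encoded together with the normalized (x, y) position. A minimal sketch of the concatenation performed in `_forward`; the shapes are illustrative:

```python
import torch

B, N = 2, 512               # illustrative batch size and keypoint count
kpts = torch.rand(B, N, 2)  # normalized keypoint positions
data = {"scales": torch.rand(B, N), "oris": torch.rand(B, N)}  # from the SIFT extractor

# Same concatenation as LightGlue._forward when conf.add_scale_ori is True:
kpts = torch.cat([kpts] + [data[k].unsqueeze(-1) for k in ("scales", "oris")], -1)
assert kpts.shape == (B, N, 2 + 2)  # matches LearnableFourierPositionalEncoding(2 + 2 * add_scale_ori, ...)
```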
@@ -339,17 +340,36 @@ class LightGlue(nn.Module): url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth" features = { - "superpoint": ("superpoint_lightglue", 256), - "disk": ("disk_lightglue", 128), + "superpoint": { + "weights": "superpoint_lightglue", + "input_dim": 256, + }, + "disk": { + "weights": "disk_lightglue", + "input_dim": 128, + }, + "aliked": { + "weights": "aliked_lightglue", + "input_dim": 128, + }, + "sift": { + "weights": "sift_lightglue", + "input_dim": 128, + "add_scale_ori": True, + }, } def __init__(self, features="superpoint", **conf) -> None: super().__init__() - self.conf = {**self.default_conf, **conf} + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) if features is not None: - assert features in list(self.features.keys()) - self.conf["weights"], self.conf["input_dim"] = self.features[features] - self.conf = conf = SimpleNamespace(**self.conf) + if features not in self.features: + raise ValueError( + f"Unsupported features: {features} not in " + f"{{{','.join(self.features)}}}" + ) + for k, v in self.features[features].items(): + setattr(conf, k, v) if conf.input_dim != conf.descriptor_dim: self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True) @@ -357,7 +377,9 @@ def __init__(self, features="superpoint", **conf) -> None: self.input_proj = nn.Identity() head_dim = conf.descriptor_dim // conf.num_heads - self.posenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim) + self.posenc = LearnableFourierPositionalEncoding( + 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim + ) h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim @@ -378,7 +400,7 @@ def __init__(self, features="superpoint", **conf) -> None: state_dict = None if features is not None: - fname = f"{conf.weights}_{self.version}.pth".replace(".", "-") + fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth" state_dict = torch.hub.load_state_dict_from_url( self.url.format(self.version, features), file_name=fname ) @@ -452,6 +474,13 @@ def _forward(self, data: dict) -> dict: kpts0 = normalize_keypoints(kpts0, size0).clone() kpts1 = normalize_keypoints(kpts1, size1).clone() + if self.conf.add_scale_ori: + kpts0 = torch.cat( + [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1 + ) + kpts1 = torch.cat( + [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1 + ) desc0 = data0["descriptors"].detach().contiguous() desc1 = data1["descriptors"].detach().contiguous() diff --git a/lightglue/sift.py b/lightglue/sift.py new file mode 100644 index 0000000..802fc1c --- /dev/null +++ b/lightglue/sift.py @@ -0,0 +1,216 @@ +import warnings + +import cv2 +import numpy as np +import torch +from kornia.color import rgb_to_grayscale +from packaging import version + +try: + import pycolmap +except ImportError: + pycolmap = None + +from .utils import Extractor + + +def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None): + h, w = image_shape + ij = np.round(points - 0.5).astype(int).T[::-1] + + # Remove duplicate points (identical coordinates). + # Pick highest scale or score + s = scales if scores is None else scores + buffer = np.zeros((h, w)) + np.maximum.at(buffer, tuple(ij), s) + keep = np.where(buffer[tuple(ij)] == s)[0] + + # Pick lowest angle (arbitrary). 
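The duplicate-removal step above relies on a scatter-max idiom: `np.maximum.at` writes the best score (or scale) per pixel into a buffer, and a detection survives only if it still equals the buffer value at its own pixel. A small self-contained illustration of the idiom used in `filter_dog_point`, with toy data:

```python
import numpy as np

# Toy data: three detections; the first two share the pixel (row 1, col 2).
ij = np.array([[1, 1, 0], [2, 2, 3]])  # (row, col) per point, as in filter_dog_point
s = np.array([0.3, 0.9, 0.5])          # score (or scale) per point

buffer = np.zeros((4, 4))
np.maximum.at(buffer, tuple(ij), s)          # scatter-max: best value per pixel
keep = np.where(buffer[tuple(ij)] == s)[0]   # a point survives iff it holds the max
print(keep)  # [1 2] -- the 0.3 detection at (1, 2) loses to the 0.9 one
```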
+ ij = ij[:, keep] + buffer[:] = np.inf + o_abs = np.abs(angles[keep]) + np.minimum.at(buffer, tuple(ij), o_abs) + mask = buffer[tuple(ij)] == o_abs + ij = ij[:, mask] + keep = keep[mask] + + if nms_radius > 0: + # Apply NMS on the remaining points + buffer[:] = 0 + buffer[tuple(ij)] = s[keep] # scores or scale + + local_max = torch.nn.functional.max_pool2d( + torch.from_numpy(buffer).unsqueeze(0), + kernel_size=nms_radius * 2 + 1, + stride=1, + padding=nms_radius, + ).squeeze(0) + is_local_max = buffer == local_max.numpy() + keep = keep[is_local_max[tuple(ij)]] + return keep + + +def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor: + x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps) + x.clip_(min=eps).sqrt_() + return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps) + + +def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray: + """ + Detect keypoints using OpenCV Detector. + Optionally, perform description. + Args: + features: OpenCV based keypoints detector and descriptor + image: Grayscale image of uint8 data type + Returns: + keypoints: 1D array of detected cv2.KeyPoint + scores: 1D array of responses + descriptors: 1D array of descriptors + """ + detections, descriptors = features.detectAndCompute(image, None) + points = np.array([k.pt for k in detections], dtype=np.float32) + scores = np.array([k.response for k in detections], dtype=np.float32) + scales = np.array([k.size for k in detections], dtype=np.float32) + angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32)) + return points, scores, scales, angles, descriptors + + +class SIFT(Extractor): + default_conf = { + "rootsift": True, + "nms_radius": 0, # None to disable filtering entirely. + "max_num_keypoints": 4096, + "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda} + "detection_threshold": 0.0066667, # from COLMAP + "edge_threshold": 10, + "first_octave": -1, # only used by pycolmap, the default of COLMAP + "num_octaves": 4, + } + + preprocess_conf = { + "resize": 1024, + } + + required_data_keys = ["image"] + + def __init__(self, **conf): + super().__init__(**conf) # Update with default configuration. + backend = self.conf.backend + if backend.startswith("pycolmap"): + if pycolmap is None: + raise ImportError( + "Cannot find module pycolmap: install it with pip" + "or use backend=opencv." + ) + options = { + "peak_threshold": self.conf.detection_threshold, + "edge_threshold": self.conf.edge_threshold, + "first_octave": self.conf.first_octave, + "num_octaves": self.conf.num_octaves, + "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy. + } + device = ( + "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "") + ) + if ( + backend == "pycolmap_cpu" or not pycolmap.has_cuda + ) and pycolmap.__version__ < "0.5.0": + warnings.warn( + "The pycolmap CPU SIFT is buggy in version < 0.5.0, " + "consider upgrading pycolmap or use the CUDA version.", + stacklevel=1, + ) + else: + options["max_num_features"] = self.conf.max_num_keypoints + self.sift = pycolmap.Sift(options=options, device=device) + elif backend == "opencv": + self.sift = cv2.SIFT_create( + contrastThreshold=self.conf.detection_threshold, + nfeatures=self.conf.max_num_keypoints, + edgeThreshold=self.conf.edge_threshold, + nOctaveLayers=self.conf.num_octaves, + ) + else: + backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"} + raise ValueError( + f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}." 
+ ) + + def extract_single_image(self, image: torch.Tensor): + image_np = image.cpu().numpy().squeeze(0) + + if self.conf.backend.startswith("pycolmap"): + if version.parse(pycolmap.__version__) >= version.parse("0.5.0"): + detections, descriptors = self.sift.extract(image_np) + scores = None # Scores are not exposed by COLMAP anymore. + else: + detections, scores, descriptors = self.sift.extract(image_np) + keypoints = detections[:, :2] # Keep only (x, y). + scales, angles = detections[:, -2:].T + if scores is not None and ( + self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda + ): + # Set the scores as a combination of abs. response and scale. + scores = np.abs(scores) * scales + elif self.conf.backend == "opencv": + # TODO: Check if opencv keypoints are already in corner convention + keypoints, scores, scales, angles, descriptors = run_opencv_sift( + self.sift, (image_np * 255.0).astype(np.uint8) + ) + pred = { + "keypoints": keypoints, + "scales": scales, + "oris": angles, + "descriptors": descriptors, + } + if scores is not None: + pred["keypoint_scores"] = scores + + # sometimes pycolmap returns points outside the image. We remove them + if self.conf.backend.startswith("pycolmap"): + is_inside = ( + pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]]) + ).all(-1) + pred = {k: v[is_inside] for k, v in pred.items()} + + if self.conf.nms_radius is not None: + keep = filter_dog_point( + pred["keypoints"], + pred["scales"], + pred["oris"], + image_np.shape, + self.conf.nms_radius, + scores=pred.get("keypoint_scores"), + ) + pred = {k: v[keep] for k, v in pred.items()} + + pred = {k: torch.from_numpy(v) for k, v in pred.items()} + if scores is not None: + # Keep the k keypoints with highest score + num_points = self.conf.max_num_keypoints + if num_points is not None and len(pred["keypoints"]) > num_points: + indices = torch.topk(pred["keypoint_scores"], num_points).indices + pred = {k: v[indices] for k, v in pred.items()} + + return pred + + def forward(self, data: dict) -> dict: + image = data["image"] + if image.shape[1] == 3: + image = rgb_to_grayscale(image) + device = image.device + image = image.cpu() + pred = [] + for k in range(len(image)): + img = image[k] + if "image_size" in data.keys(): + # avoid extracting points in padded areas + w, h = data["image_size"][k] + img = img[:, :h, :w] + p = self.extract_single_image(img) + pred.append(p) + pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]} + if self.conf.rootsift: + pred["descriptors"] = sift_to_rootsift(pred["descriptors"]) + return pred diff --git a/lightglue/superpoint.py b/lightglue/superpoint.py index 999df9e..d6d380e 100644 --- a/lightglue/superpoint.py +++ b/lightglue/superpoint.py @@ -43,9 +43,10 @@ # Adapted by Remi Pautrat, Philipp Lindenberger import torch +from kornia.color import rgb_to_grayscale from torch import nn -from .utils import ImagePreprocessor +from .utils import Extractor def simple_nms(scores, nms_radius: int): @@ -94,7 +95,7 @@ def sample_descriptors(keypoints, descriptors, s: int = 8): return descriptors -class SuperPoint(nn.Module): +class SuperPoint(Extractor): """SuperPoint Convolutional Detector and Descriptor SuperPoint: Self-Supervised Interest Point Detection and @@ -112,17 +113,13 @@ class SuperPoint(nn.Module): } preprocess_conf = { - **ImagePreprocessor.default_conf, "resize": 1024, - "grayscale": True, } required_data_keys = ["image"] def __init__(self, **conf): - super().__init__() - self.conf = {**self.default_conf, **conf} - + 
super().__init__(**conf) # Update with default configuration. self.relu = nn.ReLU(inplace=True) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 @@ -141,14 +138,13 @@ def __init__(self, **conf): self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) self.convDb = nn.Conv2d( - c5, self.conf["descriptor_dim"], kernel_size=1, stride=1, padding=0 + c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0 ) url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" # noqa self.load_state_dict(torch.hub.load_state_dict_from_url(url)) - mk = self.conf["max_num_keypoints"] - if mk is not None and mk <= 0: + if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0: raise ValueError("max_num_keypoints must be positive or None") def forward(self, data: dict) -> dict: @@ -156,9 +152,9 @@ def forward(self, data: dict) -> dict: for key in self.required_data_keys: assert key in data, f"Missing key {key} in data" image = data["image"] - if image.shape[1] == 3: # RGB - scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1) - image = (image * scale).sum(1, keepdim=True) + if image.shape[1] == 3: + image = rgb_to_grayscale(image) + # Shared Encoder x = self.relu(self.conv1a(image)) x = self.relu(self.conv1b(x)) @@ -179,18 +175,18 @@ def forward(self, data: dict) -> dict: b, _, h, w = scores.shape scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) - scores = simple_nms(scores, self.conf["nms_radius"]) + scores = simple_nms(scores, self.conf.nms_radius) # Discard keypoints near the image borders - if self.conf["remove_borders"]: - pad = self.conf["remove_borders"] + if self.conf.remove_borders: + pad = self.conf.remove_borders scores[:, :pad] = -1 scores[:, :, :pad] = -1 scores[:, -pad:] = -1 scores[:, :, -pad:] = -1 # Extract keypoints - best_kp = torch.where(scores > self.conf["detection_threshold"]) + best_kp = torch.where(scores > self.conf.detection_threshold) scores = scores[best_kp] # Separate into batches @@ -200,11 +196,11 @@ def forward(self, data: dict) -> dict: scores = [scores[best_kp[0] == i] for i in range(b)] # Keep the k keypoints with highest score - if self.conf["max_num_keypoints"] is not None: + if self.conf.max_num_keypoints is not None: keypoints, scores = list( zip( *[ - top_k_keypoints(k, s, self.conf["max_num_keypoints"]) + top_k_keypoints(k, s, self.conf.max_num_keypoints) for k, s in zip(keypoints, scores) ] ) @@ -229,15 +225,3 @@ def forward(self, data: dict) -> dict: "keypoint_scores": torch.stack(scores, 0), "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(), } - - def extract(self, img: torch.Tensor, **conf) -> dict: - """Perform extraction with online resizing""" - if img.dim() == 3: - img = img[None] # add batch dim - assert img.dim() == 4 and img.shape[0] == 1 - shape = img.shape[-2:][::-1] - img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img) - feats = self.forward({"image": img}) - feats["image_size"] = torch.tensor(shape)[None].to(img).float() - feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5 - return feats diff --git a/lightglue/utils.py b/lightglue/utils.py index ab7fc0e..d1c1ab2 100644 --- a/lightglue/utils.py +++ b/lightglue/utils.py @@ -16,7 +16,6 @@ class ImagePreprocessor: "interpolation": "bilinear", "align_corners": None, "antialias": True, - "grayscale": False, # convert rgb to grayscale } def 
__init__(self, **conf) -> None: @@ -36,10 +35,6 @@ def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: align_corners=self.conf.align_corners, ) scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img) - if self.conf.grayscale and img.shape[-3] == 3: - img = kornia.color.rgb_to_grayscale(img) - elif not self.conf.grayscale and img.shape[-3] == 1: - img = kornia.color.grayscale_to_rgb(img) return img, scale @@ -133,6 +128,25 @@ def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor: return numpy_image_to_torch(image) +class Extractor(torch.nn.Module): + def __init__(self, **conf): + super().__init__() + self.conf = SimpleNamespace(**{**self.default_conf, **conf}) + + @torch.no_grad() + def extract(self, img: torch.Tensor, **conf) -> dict: + """Perform extraction with online resizing""" + if img.dim() == 3: + img = img[None] # add batch dim + assert img.dim() == 4 and img.shape[0] == 1 + shape = img.shape[-2:][::-1] + img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img) + feats = self.forward({"image": img}) + feats["image_size"] = torch.tensor(shape)[None].to(img).float() + feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5 + return feats + + def match_pair( extractor, matcher,
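The keypoint rescaling in `Extractor.extract` above follows the pixel-center convention: integer index i names the center of pixel i, so coordinates are shifted by +0.5 before dividing by the resize scale and shifted back afterwards. A small numeric round-trip check, with illustrative values (note that `scales` is resized size over original size, as computed in `ImagePreprocessor.__call__`):

```python
import torch

# Suppose extraction ran on an image resized to half resolution: scales = new / old = 0.5.
scales = torch.tensor([0.5, 0.5])

kpts_resized = torch.tensor([[0.0, 0.0], [249.5, 249.5]])  # keypoints in resized coordinates
# Same formula as Extractor.extract: map back to original-resolution coordinates.
kpts_original = (kpts_resized + 0.5) / scales - 0.5
print(kpts_original)  # tensor([[  0.5000,   0.5000], [499.5000, 499.5000]])
```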