Skip to content

Commit

Permalink
Black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
rpautrat committed Feb 11, 2024
1 parent 201905b commit dbf8d47
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 18 deletions.
1 change: 1 addition & 0 deletions gluefactory/datasets/eth3d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
ETH3D multi-view benchmark, used for line matching evaluation.
"""

import logging
import os
import shutil
Expand Down
1 change: 1 addition & 0 deletions gluefactory/datasets/hpatches.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Simply load images from a folder or nested folders (does not have any split).
"""

import argparse
import logging
import tarfile
Expand Down
6 changes: 3 additions & 3 deletions gluefactory/geometry/gt_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,9 @@ def gt_line_matches_from_pose_depth(
all_in_batch = (
torch.arange(b_size)[:, None].repeat(1, assignation.shape[-1]).flatten()
)
positive[
all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten()
] = True
positive[all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten()] = (
True
)

m0 = assignation.new_full((b_size, n_lines0), unmatched, dtype=torch.long)
m0.scatter_(-1, assignation[:, 0], assignation[:, 1])
Expand Down
17 changes: 11 additions & 6 deletions gluefactory/models/cache_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ def pad_line_features(pred, seq_l: int = None):

def recursive_load(grp, pkeys):
return {
k: torch.from_numpy(grp[k].__array__())
if isinstance(grp[k], h5py.Dataset)
else recursive_load(grp[k], list(grp.keys()))
k: (
torch.from_numpy(grp[k].__array__())
if isinstance(grp[k], h5py.Dataset)
else recursive_load(grp[k], list(grp.keys()))
)
for k in pkeys
}

Expand Down Expand Up @@ -108,9 +110,12 @@ def _forward(self, data):
pred = recursive_load(grp, pkeys)
if self.numeric_dtype is not None:
pred = {
k: v
if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v)
else v.to(dtype=self.numeric_dtype)
k: (
v
if not isinstance(v, torch.Tensor)
or not torch.is_floating_point(v)
else v.to(dtype=self.numeric_dtype)
)
for k, v in pred.items()
}
pred = batch_to_device(pred, device)
Expand Down
8 changes: 5 additions & 3 deletions gluefactory/models/extractors/aliked.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,9 +717,11 @@ def _init(self, conf):
radius=conf.nms_radius,
top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
scores_th=conf.detection_threshold,
n_limit=conf.max_num_keypoints
if conf.max_num_keypoints > 0
else self.n_limit_max,
n_limit=(
conf.max_num_keypoints
if conf.max_num_keypoints > 0
else self.n_limit_max
),
)

# load pretrained
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/extractors/superpoint_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
The implementation of this model and its trained weights are made
available under the MIT license.
"""

from collections import OrderedDict
from types import SimpleNamespace

Expand Down
6 changes: 3 additions & 3 deletions gluefactory/models/lines/wireframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ def _forward(self, data):
associativity = torch.eye(
len(all_points[-1]), dtype=torch.bool, device=device
)
associativity[
: n_true_junctions[bs], : n_true_junctions[bs]
] = line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]]
associativity[: n_true_junctions[bs], : n_true_junctions[bs]] = (
line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]]
)
pl_associativity.append(associativity)

all_points = torch.stack(all_points, dim=0)
Expand Down
8 changes: 5 additions & 3 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def create_input_data(cv_img0, cv_img1, device):
data = {"view0": ip(img0), "view1": ip(img1)}
data = map_tensor(
data,
lambda t: t[None].to(device)
if isinstance(t, Tensor)
else torch.from_numpy(t)[None].to(device),
lambda t: (
t[None].to(device)
if isinstance(t, Tensor)
else torch.from_numpy(t)[None].to(device)
),
)
return data

Expand Down

0 comments on commit dbf8d47

Please sign in to comment.