diff --git a/.flake8 b/.flake8 index 8dd399ab..899119f2 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,4 @@ [flake8] max-line-length = 88 extend-ignore = E203 +exclude = .git,__pycache__,build,.venv/ diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 41be302f..368b225f 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -8,24 +8,17 @@ on: pull_request: types: [ assigned, opened, synchronize, reopened ] jobs: - formatting-check: - name: Formatting Check - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: psf/black@stable - with: - jupyter: true - linting-check: - name: Linting Check + check: + name: Format and Lint Checks runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: '3.10' cache: 'pip' - run: python -m pip install --upgrade pip - - run: python -m pip install . - - run: python -m pip install --upgrade flake8 - - run: python -m flake8 . --exclude build/ + - run: python -m pip install .[dev] + - run: python -m flake8 . + - run: python -m isort . --check-only --diff + - run: python -m black . --check --diff diff --git a/gluefactory/__init__.py b/gluefactory/__init__.py index b3d01152..0d83f92d 100644 --- a/gluefactory/__init__.py +++ b/gluefactory/__init__.py @@ -1,4 +1,5 @@ import logging + from .utils.experiments import load_experiment # noqa: F401 formatter = logging.Formatter( diff --git a/gluefactory/datasets/__init__.py b/gluefactory/datasets/__init__.py index 2941a4c1..ce05e9a6 100644 --- a/gluefactory/datasets/__init__.py +++ b/gluefactory/datasets/__init__.py @@ -1,6 +1,7 @@ import importlib.util -from .base_dataset import BaseDataset + from ..utils.tools import get_class +from .base_dataset import BaseDataset def get_dataset(name): diff --git a/gluefactory/datasets/augmentations.py b/gluefactory/datasets/augmentations.py index ea726a0e..bd391294 100644 --- a/gluefactory/datasets/augmentations.py +++ b/gluefactory/datasets/augmentations.py @@ -1,11 +1,11 @@ from typing import Union import albumentations as A +import cv2 import numpy as np import torch from albumentations.pytorch.transforms import ToTensorV2 from omegaconf import OmegaConf -import cv2 class IdentityTransform(A.ImageOnlyTransform): diff --git a/gluefactory/datasets/base_dataset.py b/gluefactory/datasets/base_dataset.py index aeb316a9..ef622cbc 100644 --- a/gluefactory/datasets/base_dataset.py +++ b/gluefactory/datasets/base_dataset.py @@ -3,12 +3,13 @@ See mnist.py for an example of dataset. """ -from abc import ABCMeta, abstractmethod import collections import logging -from omegaconf import OmegaConf +from abc import ABCMeta, abstractmethod + import omegaconf import torch +from omegaconf import OmegaConf from torch.utils.data import DataLoader, Sampler, get_worker_info from torch.utils.data._utils.collate import ( default_collate_err_msg_format, diff --git a/gluefactory/datasets/eth3d.py b/gluefactory/datasets/eth3d.py index e0cdf14e..ca5e2648 100644 --- a/gluefactory/datasets/eth3d.py +++ b/gluefactory/datasets/eth3d.py @@ -4,18 +4,18 @@ import logging import os import shutil +import zipfile +from pathlib import Path -import numpy as np import cv2 +import numpy as np import torch -from pathlib import Path -import zipfile -from .base_dataset import BaseDataset -from .utils import scale_intrinsics from ..geometry.wrappers import Camera, Pose from ..settings import DATA_PATH from ..utils.image import ImagePreprocessor, load_image +from .base_dataset import BaseDataset +from .utils import scale_intrinsics logger = logging.getLogger(__name__) diff --git a/gluefactory/datasets/homographies.py b/gluefactory/datasets/homographies.py index f5a21310..08f7563c 100644 --- a/gluefactory/datasets/homographies.py +++ b/gluefactory/datasets/homographies.py @@ -11,25 +11,25 @@ from pathlib import Path import cv2 +import matplotlib.pyplot as plt import numpy as np import omegaconf import torch -import matplotlib.pyplot as plt from omegaconf import OmegaConf from tqdm import tqdm -from .augmentations import IdentityAugmentation, augmentations -from .base_dataset import BaseDataset -from ..settings import DATA_PATH -from ..models.cache_loader import CacheLoader, pad_local_features -from ..utils.image import read_image from ..geometry.homography import ( - sample_homography_corners, compute_homography, + sample_homography_corners, warp_points, ) +from ..models.cache_loader import CacheLoader, pad_local_features +from ..settings import DATA_PATH +from ..utils.image import read_image from ..utils.tools import fork_rng from ..visualization.viz2d import plot_image_grid +from .augmentations import IdentityAugmentation, augmentations +from .base_dataset import BaseDataset logger = logging.getLogger(__name__) diff --git a/gluefactory/datasets/hpatches.py b/gluefactory/datasets/hpatches.py index d3054cd9..baf4ac8e 100644 --- a/gluefactory/datasets/hpatches.py +++ b/gluefactory/datasets/hpatches.py @@ -4,16 +4,17 @@ import argparse import logging import tarfile + import matplotlib.pyplot as plt import numpy as np import torch from omegaconf import OmegaConf -from .base_dataset import BaseDataset from ..settings import DATA_PATH -from ..utils.image import load_image, ImagePreprocessor +from ..utils.image import ImagePreprocessor, load_image from ..utils.tools import fork_rng from ..visualization.viz2d import plot_image_grid +from .base_dataset import BaseDataset logger = logging.getLogger(__name__) diff --git a/gluefactory/datasets/image_folder.py b/gluefactory/datasets/image_folder.py index 474a6c17..ecbd3abf 100644 --- a/gluefactory/datasets/image_folder.py +++ b/gluefactory/datasets/image_folder.py @@ -2,13 +2,14 @@ Simply load images from a folder or nested folders (does not have any split). """ -from pathlib import Path -import torch import logging +from pathlib import Path + import omegaconf +import torch +from ..utils.image import ImagePreprocessor, load_image from .base_dataset import BaseDataset -from ..utils.image import load_image, ImagePreprocessor class ImageFolder(BaseDataset, torch.utils.data.Dataset): diff --git a/gluefactory/datasets/image_pairs.py b/gluefactory/datasets/image_pairs.py index da0706a2..08bd7603 100644 --- a/gluefactory/datasets/image_pairs.py +++ b/gluefactory/datasets/image_pairs.py @@ -3,13 +3,14 @@ """ from pathlib import Path -import torch + import numpy as np -from .base_dataset import BaseDataset -from ..utils.image import load_image, ImagePreprocessor +import torch -from ..settings import DATA_PATH from ..geometry.wrappers import Camera, Pose +from ..settings import DATA_PATH +from ..utils.image import ImagePreprocessor, load_image +from .base_dataset import BaseDataset def names_to_pair(name0, name1, separator="/"): diff --git a/gluefactory/datasets/megadepth.py b/gluefactory/datasets/megadepth.py index d4b60020..19a7586c 100644 --- a/gluefactory/datasets/megadepth.py +++ b/gluefactory/datasets/megadepth.py @@ -1,9 +1,9 @@ import argparse import logging -from pathlib import Path -from collections.abc import Iterable -import tarfile import shutil +import tarfile +from collections.abc import Iterable +from pathlib import Path import h5py import matplotlib.pyplot as plt @@ -12,18 +12,14 @@ import torch from omegaconf import OmegaConf -from .base_dataset import BaseDataset -from .utils import ( - scale_intrinsics, - rotate_intrinsics, - rotate_pose_inplane, -) from ..geometry.wrappers import Camera, Pose from ..models.cache_loader import CacheLoader -from ..utils.tools import fork_rng -from ..utils.image import load_image, ImagePreprocessor from ..settings import DATA_PATH -from ..visualization.viz2d import plot_image_grid, plot_heatmaps +from ..utils.image import ImagePreprocessor, load_image +from ..utils.tools import fork_rng +from ..visualization.viz2d import plot_heatmaps, plot_image_grid +from .base_dataset import BaseDataset +from .utils import rotate_intrinsics, rotate_pose_inplane, scale_intrinsics logger = logging.getLogger(__name__) scene_lists_path = Path(__file__).parent / "megadepth_scene_lists" diff --git a/gluefactory/eval/__init__.py b/gluefactory/eval/__init__.py index e072cf9f..0d451e06 100644 --- a/gluefactory/eval/__init__.py +++ b/gluefactory/eval/__init__.py @@ -1,4 +1,5 @@ import torch + from ..utils.tools import get_class from .eval_pipeline import EvalPipeline diff --git a/gluefactory/eval/eth3d.py b/gluefactory/eval/eth3d.py index ef2b3a79..d2fe3a5d 100644 --- a/gluefactory/eval/eth3d.py +++ b/gluefactory/eval/eth3d.py @@ -1,23 +1,18 @@ +from collections import defaultdict from pathlib import Path -from omegaconf import OmegaConf + import matplotlib.pyplot as plt -from collections import defaultdict -from tqdm import tqdm import numpy as np +from omegaconf import OmegaConf +from tqdm import tqdm -from .io import ( - parse_eval_args, - load_model, - get_eval_parser, -) - -from .eval_pipeline import EvalPipeline, load_eval - -from ..utils.export_predictions import export_predictions -from .utils import get_tp_fp_pts, aggregate_pr_results -from ..settings import EVAL_PATH -from ..models.cache_loader import CacheLoader from ..datasets import get_dataset +from ..models.cache_loader import CacheLoader +from ..settings import EVAL_PATH +from ..utils.export_predictions import export_predictions +from .eval_pipeline import EvalPipeline, load_eval +from .io import get_eval_parser, load_model, parse_eval_args +from .utils import aggregate_pr_results, get_tp_fp_pts def eval_dataset(loader, pred_file, suffix=""): diff --git a/gluefactory/eval/eval_pipeline.py b/gluefactory/eval/eval_pipeline.py index 750969af..ac562377 100644 --- a/gluefactory/eval/eval_pipeline.py +++ b/gluefactory/eval/eval_pipeline.py @@ -1,7 +1,8 @@ -from omegaconf import OmegaConf -import numpy as np import json + import h5py +import numpy as np +from omegaconf import OmegaConf def load_eval(dir): diff --git a/gluefactory/eval/hpatches.py b/gluefactory/eval/hpatches.py index c714bf11..8be7b704 100644 --- a/gluefactory/eval/hpatches.py +++ b/gluefactory/eval/hpatches.py @@ -1,31 +1,27 @@ +from collections import defaultdict +from collections.abc import Iterable from pathlib import Path -from omegaconf import OmegaConf from pprint import pprint + import matplotlib.pyplot as plt -from collections import defaultdict -from collections.abc import Iterable -from tqdm import tqdm import numpy as np -from ..visualization.viz2d import plot_cumulative +from omegaconf import OmegaConf +from tqdm import tqdm -from .io import ( - parse_eval_args, - load_model, - get_eval_parser, -) -from ..utils.export_predictions import export_predictions -from ..settings import EVAL_PATH -from ..models.cache_loader import CacheLoader from ..datasets import get_dataset +from ..models.cache_loader import CacheLoader +from ..settings import EVAL_PATH +from ..utils.export_predictions import export_predictions +from ..utils.tools import AUCMetric +from ..visualization.viz2d import plot_cumulative +from .eval_pipeline import EvalPipeline +from .io import get_eval_parser, load_model, parse_eval_args from .utils import ( + eval_homography_dlt, eval_homography_robust, - eval_poses, eval_matches_homography, - eval_homography_dlt, + eval_poses, ) -from ..utils.tools import AUCMetric - -from .eval_pipeline import EvalPipeline class HPatchesPipeline(EvalPipeline): diff --git a/gluefactory/eval/inspect.py b/gluefactory/eval/inspect.py index 913371b2..1b7a3929 100644 --- a/gluefactory/eval/inspect.py +++ b/gluefactory/eval/inspect.py @@ -1,9 +1,10 @@ import argparse +from collections import defaultdict from pathlib import Path -import matplotlib.pyplot as plt -import matplotlib from pprint import pprint -from collections import defaultdict + +import matplotlib +import matplotlib.pyplot as plt from ..settings import EVAL_PATH from ..visualization.global_frame import GlobalFrame @@ -11,7 +12,6 @@ from . import get_benchmark from .eval_pipeline import load_eval - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("benchmark", type=str) diff --git a/gluefactory/eval/io.py b/gluefactory/eval/io.py index 93b72593..067e8456 100644 --- a/gluefactory/eval/io.py +++ b/gluefactory/eval/io.py @@ -1,13 +1,14 @@ -import pkg_resources +import argparse from pathlib import Path +from pprint import pprint from typing import Optional + +import pkg_resources from omegaconf import OmegaConf -import argparse -from pprint import pprint from ..models import get_model -from ..utils.experiments import load_experiment from ..settings import TRAINING_PATH +from ..utils.experiments import load_experiment def parse_config_path(name_or_path: Optional[str], defaults: str) -> Path: diff --git a/gluefactory/eval/megadepth1500.py b/gluefactory/eval/megadepth1500.py index d9eb3377..e3593612 100644 --- a/gluefactory/eval/megadepth1500.py +++ b/gluefactory/eval/megadepth1500.py @@ -1,26 +1,23 @@ -import torch +import zipfile +from collections import defaultdict +from collections.abc import Iterable from pathlib import Path -from omegaconf import OmegaConf from pprint import pprint + import matplotlib.pyplot as plt -from collections import defaultdict -from collections.abc import Iterable -from tqdm import tqdm -import zipfile import numpy as np -from ..visualization.viz2d import plot_cumulative -from .io import ( - parse_eval_args, - load_model, - get_eval_parser, -) -from ..utils.export_predictions import export_predictions -from ..settings import EVAL_PATH, DATA_PATH -from ..models.cache_loader import CacheLoader +import torch +from omegaconf import OmegaConf +from tqdm import tqdm + from ..datasets import get_dataset +from ..models.cache_loader import CacheLoader +from ..settings import DATA_PATH, EVAL_PATH +from ..utils.export_predictions import export_predictions +from ..visualization.viz2d import plot_cumulative from .eval_pipeline import EvalPipeline - -from .utils import eval_relative_pose_robust, eval_poses, eval_matches_epipolar +from .io import get_eval_parser, load_model, parse_eval_args +from .utils import eval_matches_epipolar, eval_poses, eval_relative_pose_robust class MegaDepth1500Pipeline(EvalPipeline): diff --git a/gluefactory/eval/utils.py b/gluefactory/eval/utils.py index 77adb8df..c6e6f006 100644 --- a/gluefactory/eval/utils.py +++ b/gluefactory/eval/utils.py @@ -1,11 +1,12 @@ +import kornia import numpy as np import torch -import kornia -from ..geometry.epipolar import relative_pose_error, generalized_epi_dist -from ..geometry.homography import sym_homography_error, homography_corner_error + +from ..geometry.epipolar import generalized_epi_dist, relative_pose_error from ..geometry.gt_generation import IGNORE_FEATURE -from ..utils.tools import AUCMetric +from ..geometry.homography import homography_corner_error, sym_homography_error from ..robust_estimators import load_estimator +from ..utils.tools import AUCMetric def check_keys_recursive(d, pattern): diff --git a/gluefactory/geometry/depth.py b/gluefactory/geometry/depth.py index ea2da60f..ca68bc5f 100644 --- a/gluefactory/geometry/depth.py +++ b/gluefactory/geometry/depth.py @@ -1,5 +1,5 @@ -import torch import kornia +import torch from .utils import get_image_coords from .wrappers import Camera diff --git a/gluefactory/geometry/epipolar.py b/gluefactory/geometry/epipolar.py index d7c71296..7e1507c0 100644 --- a/gluefactory/geometry/epipolar.py +++ b/gluefactory/geometry/epipolar.py @@ -1,7 +1,8 @@ +import numpy as np import torch + from .utils import skew_symmetric, to_homogeneous -from .wrappers import Pose, Camera -import numpy as np +from .wrappers import Camera, Pose def T_to_E(T: Pose): diff --git a/gluefactory/geometry/gt_generation.py b/gluefactory/geometry/gt_generation.py index 52acc0fe..21390cd7 100644 --- a/gluefactory/geometry/gt_generation.py +++ b/gluefactory/geometry/gt_generation.py @@ -2,9 +2,9 @@ import torch from scipy.optimize import linear_sum_assignment -from .homography import warp_points_torch +from .depth import project, sample_depth from .epipolar import T_to_E, sym_epipolar_distance_all -from .depth import sample_depth, project +from .homography import warp_points_torch IGNORE_FEATURE = -2 UNMATCHED_FEATURE = -1 diff --git a/gluefactory/geometry/homography.py b/gluefactory/geometry/homography.py index 7679bf97..3acb9307 100644 --- a/gluefactory/geometry/homography.py +++ b/gluefactory/geometry/homography.py @@ -1,9 +1,10 @@ -from typing import Tuple import math +from typing import Tuple + import numpy as np import torch -from .utils import to_homogeneous, from_homogeneous +from .utils import from_homogeneous, to_homogeneous def flat2mat(H): diff --git a/gluefactory/geometry/wrappers.py b/gluefactory/geometry/wrappers.py index 0886f58e..9d4a1b10 100644 --- a/gluefactory/geometry/wrappers.py +++ b/gluefactory/geometry/wrappers.py @@ -6,13 +6,14 @@ import functools import inspect import math -from typing import Union, Tuple, List, Dict, NamedTuple, Optional -import torch +from typing import Dict, List, NamedTuple, Optional, Tuple, Union + import numpy as np +import torch from .utils import ( - distort_points, J_distort_points, + distort_points, skew_symmetric, so3exp_map, to_homogeneous, diff --git a/gluefactory/models/__init__.py b/gluefactory/models/__init__.py index 5d3f71b9..a9d1a05c 100644 --- a/gluefactory/models/__init__.py +++ b/gluefactory/models/__init__.py @@ -1,6 +1,7 @@ import importlib.util -from .base_model import BaseModel + from ..utils.tools import get_class +from .base_model import BaseModel def get_model(name): diff --git a/gluefactory/models/base_model.py b/gluefactory/models/base_model.py index ed4e1078..7313d986 100644 --- a/gluefactory/models/base_model.py +++ b/gluefactory/models/base_model.py @@ -3,10 +3,11 @@ """ from abc import ABCMeta, abstractmethod +from copy import copy + import omegaconf from omegaconf import OmegaConf from torch import nn -from copy import copy class MetaModel(ABCMeta): diff --git a/gluefactory/models/cache_loader.py b/gluefactory/models/cache_loader.py index 40cc55d6..3fbf0f71 100644 --- a/gluefactory/models/cache_loader.py +++ b/gluefactory/models/cache_loader.py @@ -1,11 +1,12 @@ -import torch import string + import h5py +import torch -from .base_model import BaseModel -from ..settings import DATA_PATH from ..datasets.base_dataset import collate +from ..settings import DATA_PATH from ..utils.tensor import batch_to_device +from .base_model import BaseModel from .utils.misc import pad_to_length diff --git a/gluefactory/models/extractors/aliked.py b/gluefactory/models/extractors/aliked.py index 45bc46f3..80cd348a 100644 --- a/gluefactory/models/extractors/aliked.py +++ b/gluefactory/models/extractors/aliked.py @@ -1,10 +1,11 @@ +from typing import Callable, Optional + import torch -from torch import nn import torch.nn.functional as F -from torchvision.models import resnet -from typing import Optional, Callable -from torch.nn.modules.utils import _pair import torchvision +from torch import nn +from torch.nn.modules.utils import _pair +from torchvision.models import resnet from gluefactory.models.base_model import BaseModel diff --git a/gluefactory/models/extractors/disk_kornia.py b/gluefactory/models/extractors/disk_kornia.py index b403b04a..4d60973d 100644 --- a/gluefactory/models/extractors/disk_kornia.py +++ b/gluefactory/models/extractors/disk_kornia.py @@ -1,5 +1,5 @@ -import torch import kornia +import torch from ..base_model import BaseModel from ..utils.misc import pad_and_stack diff --git a/gluefactory/models/extractors/grid_extractor.py b/gluefactory/models/extractors/grid_extractor.py index 882a125d..dd221d97 100644 --- a/gluefactory/models/extractors/grid_extractor.py +++ b/gluefactory/models/extractors/grid_extractor.py @@ -1,6 +1,7 @@ -import torch import math +import torch + from ..base_model import BaseModel diff --git a/gluefactory/models/extractors/keynet_affnet_hardnet.py b/gluefactory/models/extractors/keynet_affnet_hardnet.py index 15f1dca2..b9091ea4 100644 --- a/gluefactory/models/extractors/keynet_affnet_hardnet.py +++ b/gluefactory/models/extractors/keynet_affnet_hardnet.py @@ -1,5 +1,5 @@ -import torch import kornia +import torch from ..base_model import BaseModel from ..utils.misc import pad_to_length diff --git a/gluefactory/models/extractors/mixed.py b/gluefactory/models/extractors/mixed.py index 3bef2a4e..5524cb6e 100644 --- a/gluefactory/models/extractors/mixed.py +++ b/gluefactory/models/extractors/mixed.py @@ -1,10 +1,8 @@ -from omegaconf import OmegaConf import torch.nn.functional as F +from omegaconf import OmegaConf -from ..base_model import BaseModel from .. import get_model - -# from ...geometry.depth import sample_fmap +from ..base_model import BaseModel to_ctr = OmegaConf.to_container # convert DictConfig to dict diff --git a/gluefactory/models/extractors/sift.py b/gluefactory/models/extractors/sift.py index 24d7b7bb..5eb0c956 100644 --- a/gluefactory/models/extractors/sift.py +++ b/gluefactory/models/extractors/sift.py @@ -1,12 +1,11 @@ +import cv2 import numpy as np -import torch import pycolmap -from scipy.spatial import KDTree +import torch from omegaconf import OmegaConf -import cv2 +from scipy.spatial import KDTree from ..base_model import BaseModel - from ..utils.misc import pad_to_length EPS = 1e-6 diff --git a/gluefactory/models/extractors/superpoint_open.py b/gluefactory/models/extractors/superpoint_open.py index 8da32a49..1f960407 100644 --- a/gluefactory/models/extractors/superpoint_open.py +++ b/gluefactory/models/extractors/superpoint_open.py @@ -5,11 +5,12 @@ The implementation of this model and its trained weights are made available under the MIT license. """ -import torch.nn as nn -import torch from collections import OrderedDict from types import SimpleNamespace +import torch +import torch.nn as nn + from ..base_model import BaseModel from ..utils.misc import pad_and_stack diff --git a/gluefactory/models/lines/deeplsd.py b/gluefactory/models/lines/deeplsd.py index 72fb5323..c35aa01e 100644 --- a/gluefactory/models/lines/deeplsd.py +++ b/gluefactory/models/lines/deeplsd.py @@ -1,9 +1,9 @@ +import deeplsd.models.deeplsd_inference as deeplsd_inference import numpy as np import torch -import deeplsd.models.deeplsd_inference as deeplsd_inference -from ..base_model import BaseModel from ...settings import DATA_PATH +from ..base_model import BaseModel class DeepLSD(BaseModel): diff --git a/gluefactory/models/lines/wireframe.py b/gluefactory/models/lines/wireframe.py index c2d086c5..ac0d0b5a 100644 --- a/gluefactory/models/lines/wireframe.py +++ b/gluefactory/models/lines/wireframe.py @@ -1,8 +1,8 @@ import torch from sklearn.cluster import DBSCAN -from ..base_model import BaseModel from .. import get_model +from ..base_model import BaseModel def sample_descriptors_corner_conv(keypoints, descriptors, s: int = 8): diff --git a/gluefactory/models/matchers/depth_matcher.py b/gluefactory/models/matchers/depth_matcher.py index 1d223655..125ded2b 100644 --- a/gluefactory/models/matchers/depth_matcher.py +++ b/gluefactory/models/matchers/depth_matcher.py @@ -1,9 +1,10 @@ -from ..base_model import BaseModel +import torch + from ...geometry.gt_generation import ( - gt_matches_from_pose_depth, gt_line_matches_from_pose_depth, + gt_matches_from_pose_depth, ) -import torch +from ..base_model import BaseModel class DepthMatcher(BaseModel): diff --git a/gluefactory/models/matchers/gluestick.py b/gluefactory/models/matchers/gluestick.py index 1df19b5f..0187e0c3 100644 --- a/gluefactory/models/matchers/gluestick.py +++ b/gluefactory/models/matchers/gluestick.py @@ -7,9 +7,9 @@ import torch.utils.checkpoint from torch import nn +from ...settings import DATA_PATH from ..base_model import BaseModel from ..utils.metrics import matcher_metrics -from ...settings import DATA_PATH warnings.filterwarnings("ignore", category=UserWarning) ETH_EPS = 1e-8 diff --git a/gluefactory/models/matchers/homography_matcher.py b/gluefactory/models/matchers/homography_matcher.py index 3ef346ee..d3642fb7 100644 --- a/gluefactory/models/matchers/homography_matcher.py +++ b/gluefactory/models/matchers/homography_matcher.py @@ -1,8 +1,8 @@ -from ..base_model import BaseModel from ...geometry.gt_generation import ( - gt_matches_from_homography, gt_line_matches_from_homography, + gt_matches_from_homography, ) +from ..base_model import BaseModel class HomographyMatcher(BaseModel): diff --git a/gluefactory/models/matchers/lightglue.py b/gluefactory/models/matchers/lightglue.py index 8589fa16..7671f609 100644 --- a/gluefactory/models/matchers/lightglue.py +++ b/gluefactory/models/matchers/lightglue.py @@ -1,15 +1,17 @@ import warnings +from pathlib import Path +from typing import Callable, List, Optional + import numpy as np import torch -from torch import nn import torch.nn.functional as F -from typing import Optional, List, Callable -from torch.utils.checkpoint import checkpoint from omegaconf import OmegaConf +from torch import nn +from torch.utils.checkpoint import checkpoint + from ...settings import DATA_PATH from ..utils.losses import NLLLoss from ..utils.metrics import matcher_metrics -from pathlib import Path FLASH_AVAILABLE = hasattr(F, "scaled_dot_product_attention") diff --git a/gluefactory/models/matchers/lightglue_pretrained.py b/gluefactory/models/matchers/lightglue_pretrained.py index 034684a4..2e7c71b6 100644 --- a/gluefactory/models/matchers/lightglue_pretrained.py +++ b/gluefactory/models/matchers/lightglue_pretrained.py @@ -1,7 +1,8 @@ -from ..base_model import BaseModel from lightglue import LightGlue as LightGlue_ from omegaconf import OmegaConf +from ..base_model import BaseModel + class LightGlue(BaseModel): default_conf = {"features": "superpoint", **LightGlue_.default_conf} diff --git a/gluefactory/models/matchers/nearest_neighbor_matcher.py b/gluefactory/models/matchers/nearest_neighbor_matcher.py index b3ad4270..7bbc8ae5 100644 --- a/gluefactory/models/matchers/nearest_neighbor_matcher.py +++ b/gluefactory/models/matchers/nearest_neighbor_matcher.py @@ -3,8 +3,9 @@ Optionally apply the mutual check and threshold the distance or ratio. """ -import torch import logging + +import torch import torch.nn.functional as F from ..base_model import BaseModel diff --git a/gluefactory/models/triplet_pipeline.py b/gluefactory/models/triplet_pipeline.py index 9bcc8daa..25385177 100644 --- a/gluefactory/models/triplet_pipeline.py +++ b/gluefactory/models/triplet_pipeline.py @@ -9,9 +9,10 @@ If no triplet is found, this falls back to two_view_pipeline.py """ -from .two_view_pipeline import TwoViewPipeline import torch + from ..utils.misc import get_twoview, stack_twoviews, unstack_twoviews +from .two_view_pipeline import TwoViewPipeline def has_triplet(data): diff --git a/gluefactory/models/two_view_pipeline.py b/gluefactory/models/two_view_pipeline.py index 2f521e98..9c517dc7 100644 --- a/gluefactory/models/two_view_pipeline.py +++ b/gluefactory/models/two_view_pipeline.py @@ -11,9 +11,9 @@ """ from omegaconf import OmegaConf -from .base_model import BaseModel -from . import get_model +from . import get_model +from .base_model import BaseModel to_ctr = OmegaConf.to_container # convert DictConfig to dict diff --git a/gluefactory/models/utils/misc.py b/gluefactory/models/utils/misc.py index 2cb03d65..e86d1add 100644 --- a/gluefactory/models/utils/misc.py +++ b/gluefactory/models/utils/misc.py @@ -1,5 +1,6 @@ import math from typing import List, Optional, Tuple + import torch diff --git a/gluefactory/robust_estimators/__init__.py b/gluefactory/robust_estimators/__init__.py index f5a85cd8..a9d9c9b9 100644 --- a/gluefactory/robust_estimators/__init__.py +++ b/gluefactory/robust_estimators/__init__.py @@ -1,4 +1,5 @@ import inspect + from .base_estimator import BaseEstimator diff --git a/gluefactory/robust_estimators/base_estimator.py b/gluefactory/robust_estimators/base_estimator.py index a94e35b5..29f8dd45 100644 --- a/gluefactory/robust_estimators/base_estimator.py +++ b/gluefactory/robust_estimators/base_estimator.py @@ -1,6 +1,7 @@ -from omegaconf import OmegaConf from copy import copy +from omegaconf import OmegaConf + class BaseEstimator: base_default_conf = { diff --git a/gluefactory/robust_estimators/homography/poselib.py b/gluefactory/robust_estimators/homography/poselib.py index 0edfe10f..e99e9493 100644 --- a/gluefactory/robust_estimators/homography/poselib.py +++ b/gluefactory/robust_estimators/homography/poselib.py @@ -1,6 +1,6 @@ import poselib -from omegaconf import OmegaConf import torch +from omegaconf import OmegaConf from ..base_estimator import BaseEstimator diff --git a/gluefactory/robust_estimators/relative_pose/opencv.py b/gluefactory/robust_estimators/relative_pose/opencv.py index b212ea32..34442a0f 100644 --- a/gluefactory/robust_estimators/relative_pose/opencv.py +++ b/gluefactory/robust_estimators/relative_pose/opencv.py @@ -1,9 +1,9 @@ import cv2 import numpy as np import torch -from ...geometry.wrappers import Pose -from ...geometry.utils import from_homogeneous +from ...geometry.utils import from_homogeneous +from ...geometry.wrappers import Pose from ..base_estimator import BaseEstimator diff --git a/gluefactory/robust_estimators/relative_pose/poselib.py b/gluefactory/robust_estimators/relative_pose/poselib.py index 35ab87cc..6c736e4e 100644 --- a/gluefactory/robust_estimators/relative_pose/poselib.py +++ b/gluefactory/robust_estimators/relative_pose/poselib.py @@ -1,8 +1,8 @@ import poselib -from omegaconf import OmegaConf import torch -from ...geometry.wrappers import Pose +from omegaconf import OmegaConf +from ...geometry.wrappers import Pose from ..base_estimator import BaseEstimator diff --git a/gluefactory/robust_estimators/relative_pose/pycolmap.py b/gluefactory/robust_estimators/relative_pose/pycolmap.py index c7d09460..21cb2720 100644 --- a/gluefactory/robust_estimators/relative_pose/pycolmap.py +++ b/gluefactory/robust_estimators/relative_pose/pycolmap.py @@ -1,8 +1,8 @@ import pycolmap -from omegaconf import OmegaConf import torch -from ...geometry.wrappers import Pose +from omegaconf import OmegaConf +from ...geometry.wrappers import Pose from ..base_estimator import BaseEstimator diff --git a/gluefactory/scripts/export_local_features.py b/gluefactory/scripts/export_local_features.py index 892f3333..7f3f0a94 100644 --- a/gluefactory/scripts/export_local_features.py +++ b/gluefactory/scripts/export_local_features.py @@ -1,14 +1,14 @@ +import argparse import logging from pathlib import Path -import argparse + import torch from omegaconf import OmegaConf +from ..datasets import get_dataset +from ..models import get_model from ..settings import DATA_PATH from ..utils.export_predictions import export_predictions -from ..models import get_model -from ..datasets import get_dataset - resize = 1600 diff --git a/gluefactory/scripts/export_megadepth.py b/gluefactory/scripts/export_megadepth.py index c94caeca..95e89d81 100644 --- a/gluefactory/scripts/export_megadepth.py +++ b/gluefactory/scripts/export_megadepth.py @@ -1,14 +1,15 @@ +import argparse import logging from pathlib import Path -import argparse + import torch from omegaconf import OmegaConf -from ..settings import DATA_PATH -from ..utils.export_predictions import export_predictions -from ..models import get_model from ..datasets import get_dataset from ..geometry.depth import sample_depth +from ..models import get_model +from ..settings import DATA_PATH +from ..utils.export_predictions import export_predictions resize = 1024 n_kpts = 2048 diff --git a/gluefactory/train.py b/gluefactory/train.py index 2d5b639a..08895d72 100644 --- a/gluefactory/train.py +++ b/gluefactory/train.py @@ -5,37 +5,37 @@ """ import argparse -from pathlib import Path -import signal -import re import copy -from collections import defaultdict +import re import shutil -import numpy as np +import signal +from collections import defaultdict +from pathlib import Path +from pydoc import locate -from omegaconf import OmegaConf -from tqdm import tqdm +import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter +from omegaconf import OmegaConf from torch.cuda.amp import GradScaler, autocast -from pydoc import locate +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm -from .models import get_model +from . import __module_name__, logger from .datasets import get_dataset +from .eval import run_benchmark +from .models import get_model +from .settings import EVAL_PATH, TRAINING_PATH +from .utils.experiments import get_best_checkpoint, get_last_checkpoint, save_experiment from .utils.stdout_capturing import capture_outputs +from .utils.tensor import batch_to_device from .utils.tools import ( AverageMetric, MedianMetric, - RecallMetric, PRMetric, - set_seed, + RecallMetric, fork_rng, + set_seed, ) -from .utils.tensor import batch_to_device -from .utils.experiments import get_last_checkpoint, get_best_checkpoint, save_experiment -from .eval import run_benchmark -from .settings import TRAINING_PATH, EVAL_PATH -from . import __module_name__, logger # @TODO: Fix pbar pollution in logs # @TODO: add plotting during evaluation diff --git a/gluefactory/utils/benchmark.py b/gluefactory/utils/benchmark.py index 401578bc..99b4f85f 100644 --- a/gluefactory/utils/benchmark.py +++ b/gluefactory/utils/benchmark.py @@ -1,7 +1,8 @@ -import torch -import numpy as np import time +import numpy as np +import torch + def benchmark(model, data, device, r=100): timings = np.zeros((r, 1)) diff --git a/gluefactory/utils/experiments.py b/gluefactory/utils/experiments.py index 849d0bc7..7723fcea 100644 --- a/gluefactory/utils/experiments.py +++ b/gluefactory/utils/experiments.py @@ -4,16 +4,17 @@ Author: Paul-Edouard Sarlin (skydes) """ -from pathlib import Path import logging +import os import re import shutil -from omegaconf import OmegaConf +from pathlib import Path + import torch -import os +from omegaconf import OmegaConf -from ..settings import TRAINING_PATH from ..models import get_model +from ..settings import TRAINING_PATH logger = logging.getLogger(__name__) diff --git a/gluefactory/utils/export_predictions.py b/gluefactory/utils/export_predictions.py index 084227f2..1157a520 100644 --- a/gluefactory/utils/export_predictions.py +++ b/gluefactory/utils/export_predictions.py @@ -4,11 +4,12 @@ or call from another script. """ -import torch -import numpy as np from pathlib import Path -from tqdm import tqdm + import h5py +import numpy as np +import torch +from tqdm import tqdm from .tensor import batch_to_device diff --git a/gluefactory/utils/image.py b/gluefactory/utils/image.py index 1e6a7e29..1a9b1250 100644 --- a/gluefactory/utils/image.py +++ b/gluefactory/utils/image.py @@ -1,10 +1,11 @@ +import collections.abc as collections from pathlib import Path -import torch -import kornia +from typing import Optional, Tuple + import cv2 +import kornia import numpy as np -from typing import Tuple, Optional -import collections.abc as collections +import torch from omegaconf import OmegaConf diff --git a/gluefactory/utils/stdout_capturing.py b/gluefactory/utils/stdout_capturing.py index 9baef920..bfa2b832 100644 --- a/gluefactory/utils/stdout_capturing.py +++ b/gluefactory/utils/stdout_capturing.py @@ -6,11 +6,12 @@ """ from __future__ import division, print_function, unicode_literals + import os -import sys import subprocess -from threading import Timer +import sys from contextlib import contextmanager +from threading import Timer def apply_backspaces_and_linefeeds(text): diff --git a/gluefactory/utils/tensor.py b/gluefactory/utils/tensor.py index a20c6412..f31bb580 100644 --- a/gluefactory/utils/tensor.py +++ b/gluefactory/utils/tensor.py @@ -3,8 +3,9 @@ """ import collections.abc as collections -import torch + import numpy as np +import torch string_classes = (str, bytes) diff --git a/gluefactory/utils/tools.py b/gluefactory/utils/tools.py index 21541e68..6a27f4a4 100644 --- a/gluefactory/utils/tools.py +++ b/gluefactory/utils/tools.py @@ -4,13 +4,14 @@ Author: Paul-Edouard Sarlin (skydes) """ -import time -import numpy as np import os -import torch import random -from contextlib import contextmanager +import time from collections.abc import Iterable +from contextlib import contextmanager + +import numpy as np +import torch class AverageMetric: diff --git a/gluefactory/visualization/global_frame.py b/gluefactory/visualization/global_frame.py index 41d33ec5..a403c9c9 100644 --- a/gluefactory/visualization/global_frame.py +++ b/gluefactory/visualization/global_frame.py @@ -1,14 +1,16 @@ +import functools import traceback -import numpy as np +from copy import deepcopy + import matplotlib.pyplot as plt -from omegaconf import OmegaConf +import numpy as np from matplotlib.widgets import Button -from copy import deepcopy -import functools +from omegaconf import OmegaConf + +from ..datasets.base_dataset import collate # from ..eval.export_predictions import load_predictions from ..models.cache_loader import CacheLoader -from ..datasets.base_dataset import collate from .tools import RadioHideTool diff --git a/gluefactory/visualization/tools.py b/gluefactory/visualization/tools.py index 1415807a..a095d06e 100644 --- a/gluefactory/visualization/tools.py +++ b/gluefactory/visualization/tools.py @@ -1,25 +1,25 @@ +import inspect +import sys +import warnings + import matplotlib.pyplot as plt +import torch from matplotlib.backend_tools import ToolToggleBase from matplotlib.widgets import RadioButtons, Slider -import warnings -import torch +from ..geometry.epipolar import T_to_F, generalized_epi_dist +from ..geometry.homography import sym_homography_error from ..visualization.viz2d import ( + cm_ranking, + cm_RdGn, + draw_epipolar_line, + get_line, + plot_color_line_matches, plot_heatmaps, plot_keypoints, plot_lines, plot_matches, - plot_color_line_matches, - cm_RdGn, - cm_ranking, - get_line, - draw_epipolar_line, ) -from ..geometry.homography import sym_homography_error -from ..geometry.epipolar import generalized_epi_dist, T_to_F - -import inspect -import sys with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/gluefactory/visualization/two_view_frame.py b/gluefactory/visualization/two_view_frame.py index fac2222c..3461eb0e 100644 --- a/gluefactory/visualization/two_view_frame.py +++ b/gluefactory/visualization/two_view_frame.py @@ -1,10 +1,9 @@ -import numpy as np import pprint -from . import viz2d -from .tools import __plot_dict__ +import numpy as np -from .tools import RadioHideTool, ToggleTool +from . import viz2d +from .tools import RadioHideTool, ToggleTool, __plot_dict__ class FormatPrinter(pprint.PrettyPrinter): diff --git a/gluefactory/visualization/visualize_batch.py b/gluefactory/visualization/visualize_batch.py index 09bdcbf7..3bd3f7b6 100644 --- a/gluefactory/visualization/visualize_batch.py +++ b/gluefactory/visualization/visualize_batch.py @@ -1,13 +1,7 @@ import torch from ..utils.tensor import batch_to_device -from .viz2d import ( - plot_image_grid, - plot_keypoints, - plot_matches, - cm_RdGn, - plot_heatmaps, -) +from .viz2d import cm_RdGn, plot_heatmaps, plot_image_grid, plot_keypoints, plot_matches def make_match_figures(pred_, data_, n_pairs=2): diff --git a/gluefactory/visualization/viz2d.py b/gluefactory/visualization/viz2d.py index 4a3a636b..42a000a3 100644 --- a/gluefactory/visualization/viz2d.py +++ b/gluefactory/visualization/viz2d.py @@ -6,8 +6,8 @@ """ import matplotlib -import matplotlib.pyplot as plt import matplotlib.patheffects as path_effects +import matplotlib.pyplot as plt import numpy as np import seaborn as sns diff --git a/pyproject.toml b/pyproject.toml index b0cc6d78..5185a753 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,10 +43,14 @@ extra = [ "deeplsd @ git+https://github.com/cvg/DeepLSD.git", "homography_est @ git+https://github.com/rpautrat/homography_est.git", ] -dev = ["black", "flake8", "jupyter"] +dev = ["black", "flake8", "isort"] [tool.setuptools.packages.find] include = ["gluefactory*"] [tool.setuptools.package-data] gluefactory = ["datasets/megadepth_scene_lists/*.txt", "configs/*.yaml"] + +[tool.isort] +profile = "black" +extend_skip = ["gluefactory_nonfree/"]