diff --git a/README.md b/README.md index d598c390..2485cd12 100644 --- a/README.md +++ b/README.md @@ -166,6 +166,31 @@ Here are the results as Area Under the Curve (AUC) of the pose error at 5/10/20 +#### ScanNet-1500 + +Running the evaluation commands automatically downloads the dataset, which takes about 1.1 GB of disk space. + +
+[Evaluating LightGlue] + +To evaluate the pre-trained SuperPoint+LightGlue model on ScanNet-1500, run: +```bash +python -m gluefactory.eval.scannet1500 --conf superpoint+lightglue-official +# or the adaptive variant +python -m gluefactory.eval.scannet1500 --conf superpoint+lightglue-official \ + model.matcher.{depth_confidence=0.95,width_confidence=0.95} +``` + +Here are the results as Area Under the Curve (AUC) of the pose error at 5/10/20 degrees: + +| Methods | [OpenCV](gluefactory/robust_estimators/relative_pose/opencv.py) | +| ------------------------------------------------------------ | ------------------ | +| [SuperPoint + SuperGlue](gluefactory/configs/superpoint+superglue-official.yaml) | 17.4 / 33.9 / 49.5 | +| [SuperPoint + LightGlue](gluefactory/configs/superpoint+lightglue-official.yaml) | 17.7 / 34.6 / 51.2 | +| [ALIKED + LightGlue](gluefactory/configs/aliked+lightglue-official.yaml) | 18.4 / 33.9 / 49.7 | + +
+ #### ETH3D The dataset will be auto-downloaded if it is not found on disk, and will need about 6 GB of free disk space. diff --git a/gluefactory/configs/superpoint+lightglue-official.yaml b/gluefactory/configs/superpoint+lightglue-official.yaml index a03d66f2..da3417fb 100644 --- a/gluefactory/configs/superpoint+lightglue-official.yaml +++ b/gluefactory/configs/superpoint+lightglue-official.yaml @@ -27,3 +27,8 @@ benchmarks: model: extractor: max_num_keypoints: 1024 # overwrite config above + scannet1500: + eval: + estimator: opencv + ransac_th: 1.0 + diff --git a/gluefactory/configs/superpoint+superglue-official.yaml b/gluefactory/configs/superpoint+superglue-official.yaml index 090ff5a1..56dd3524 100644 --- a/gluefactory/configs/superpoint+superglue-official.yaml +++ b/gluefactory/configs/superpoint+superglue-official.yaml @@ -23,4 +23,8 @@ benchmarks: model: extractor: max_num_keypoints: 1024 # overwrite config above + scannet1500: + eval: + estimator: opencv + ransac_th: 1.0 diff --git a/gluefactory/datasets/image_pairs.py b/gluefactory/datasets/image_pairs.py index 08bd7603..8e44e532 100644 --- a/gluefactory/datasets/image_pairs.py +++ b/gluefactory/datasets/image_pairs.py @@ -32,11 +32,16 @@ def parse_camera(calib_elems) -> Camera: def parse_relative_pose(pose_elems) -> Pose: - # assert len(calib_list) == 9 - R, t = pose_elems[:9], pose_elems[9:12] - R = np.array([float(x) for x in R]).reshape(3, 3).astype(np.float32) - t = np.array([float(x) for x in t]).astype(np.float32) - return Pose.from_Rt(R, t) + if len(pose_elems) == 12: # row-major 3x3 rotation, then translation + R, t = pose_elems[:9], pose_elems[9:12] + R = np.array([float(x) for x in R]).reshape(3, 3).astype(np.float32) + t = np.array([float(x) for x in t]).astype(np.float32) + return Pose.from_Rt(R, t) + elif len(pose_elems) == 16: # flattened row-major 4x4 matrix + T = np.array([float(x) for x in pose_elems]).reshape(4, 4).astype(np.float32) + return Pose.from_4x4mat(T) + else: + raise ValueError(f"Can not interpret pose {pose_elems}.") class ImagePairs(BaseDataset, 
torch.utils.data.Dataset): @@ -81,7 +86,8 @@ def __getitem__(self, idx): data["view1"]["camera"] = parse_camera(pair_data[11:20]).scale( data1["scales"] ) - data["T_0to1"] = parse_relative_pose(pair_data[20:32]) + data["T_0to1"] = parse_relative_pose(pair_data[20:]) + data["T_1to0"] = data["T_0to1"].inv() elif self.conf.extra_data == "homography": data["H_0to1"] = ( data1["transform"] diff --git a/gluefactory/eval/scannet1500.py b/gluefactory/eval/scannet1500.py new file mode 100644 index 00000000..8d64d2c7 --- /dev/null +++ b/gluefactory/eval/scannet1500.py @@ -0,0 +1,193 @@ +import logging +import zipfile +from collections import defaultdict +from collections.abc import Iterable +from pathlib import Path +from pprint import pprint + +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import OmegaConf +from tqdm import tqdm + +from ..datasets import get_dataset +from ..models.cache_loader import CacheLoader +from ..settings import DATA_PATH, EVAL_PATH +from ..utils.export_predictions import export_predictions +from ..visualization.viz2d import plot_cumulative +from .eval_pipeline import EvalPipeline +from .io import get_eval_parser, load_model, parse_eval_args +from .utils import eval_matches_epipolar, eval_poses, eval_relative_pose_robust + +logger = logging.getLogger(__name__) + + +class ScanNet1500Pipeline(EvalPipeline): + """Evaluate matches and robust relative poses on the ScanNet-1500 pairs.""" + default_conf = { + "data": { + "name": "image_pairs", + "pairs": "scannet1500/pairs_calibrated.txt", + "root": "scannet1500/", + "extra_data": "relative_pose", + "preprocessing": { + "side": "long", + }, + "num_workers": 14, + }, + "model": { + "ground_truth": { + "name": None, # remove gt matches + } + }, + "eval": { + "estimator": "opencv", + "ransac_th": 1.0, # -1 runs a bunch of thresholds and selects the best + }, + } + + export_keys = [ + "keypoints0", + "keypoints1", + "keypoint_scores0", + "keypoint_scores1", + "matches0", + "matches1", + "matching_scores0", + "matching_scores1", + ] +
optional_export_keys = [] + + def _init(self, conf): + """Download and extract the ScanNet-1500 data if it is not on disk yet.""" + if not (DATA_PATH / "scannet1500").exists(): + logger.info("Downloading the ScanNet-1500 dataset.") + url = "https://cvg-data.inf.ethz.ch/scannet/scannet1500.zip" + zip_path = DATA_PATH / url.rsplit("/", 1)[-1] + zip_path.parent.mkdir(exist_ok=True, parents=True) + torch.hub.download_url_to_file(url, zip_path) + with zipfile.ZipFile(zip_path) as fid: + fid.extractall(DATA_PATH) + zip_path.unlink() # remove the archive once it is extracted + + @classmethod + def get_dataloader(cls, data_conf=None): + """Returns a data loader with samples for each eval datapoint""" + data_conf = data_conf if data_conf else cls.default_conf["data"] + dataset = get_dataset(data_conf["name"])(data_conf) + return dataset.get_data_loader("test") + + def get_predictions(self, experiment_dir, model=None, overwrite=False): + """Export a prediction file for each eval datapoint""" + pred_file = experiment_dir / "predictions.h5" + if not pred_file.exists() or overwrite: + if model is None: + model = load_model(self.conf.model, self.conf.checkpoint) + export_predictions( + self.get_dataloader(self.conf.data), + model, + pred_file, + keys=self.export_keys, + optional_keys=self.optional_export_keys, + ) + return pred_file + + def run_eval(self, loader, pred_file): + """Evaluate cached predictions; returns (summaries, figures, results).""" + conf = self.conf.eval + results = defaultdict(list) + test_thresholds = ( + ([conf.ransac_th] if conf.ransac_th > 0 else [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]) + if not isinstance(conf.ransac_th, Iterable) + else conf.ransac_th + ) + pose_results = defaultdict(lambda: defaultdict(list)) + cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval() + for data in tqdm(loader): + pred = cache_loader(data) + # add custom evaluations here + results_i = eval_matches_epipolar(data, pred) + for th in test_thresholds: + pose_results_i = eval_relative_pose_robust( + data, + pred, + {"estimator": conf.estimator, "ransac_th": th}, + ) +
for k, v in pose_results_i.items(): + pose_results[th][k].append(v) + + # we also store the names for later reference + results_i["names"] = data["name"][0] + if "scene" in data: + results_i["scenes"] = data["scene"][0] + + for k, v in results_i.items(): + results[k].append(v) + + # summarize results as a dict[str, float] + # you can also add your custom evaluations here + summaries = {} + for k, v in results.items(): + arr = np.array(v) + if not np.issubdtype(arr.dtype, np.number): + continue + summaries[f"m{k}"] = round(np.mean(arr), 3) + + best_pose_results, best_th = eval_poses( + pose_results, auc_ths=[5, 10, 20], key="rel_pose_error" + ) + results = {**results, **pose_results[best_th]} + summaries = { + **summaries, + **best_pose_results, + } + + figures = { + "pose_recall": plot_cumulative( + {self.conf.eval.estimator: results["rel_pose_error"]}, + [0, 30], + unit="°", + title="Pose ", + ) + } + + return summaries, figures, results + + +if __name__ == "__main__": + from .. import logger # overwrite the logger + + dataset_name = Path(__file__).stem + parser = get_eval_parser() + args = parser.parse_intermixed_args() + + default_conf = OmegaConf.create(ScanNet1500Pipeline.default_conf) + + # mingle paths + output_dir = Path(EVAL_PATH, dataset_name) + output_dir.mkdir(exist_ok=True, parents=True) + + name, conf = parse_eval_args( + dataset_name, + args, + "configs/", + default_conf, + ) + + experiment_dir = output_dir / name + experiment_dir.mkdir(exist_ok=True) + + pipeline = ScanNet1500Pipeline(conf) + s, f, r = pipeline.run( + experiment_dir, + overwrite=args.overwrite, + overwrite_eval=args.overwrite_eval, + ) + + pprint(s) + + if args.plot: + for name, fig in f.items(): + fig.canvas.manager.set_window_title(name) + plt.show()