# Add Scannet1500 dataset #25

**Open** · wants to merge 4 commits into `main`
**README.md** — 25 additions, 0 deletions (`@@ -166,6 +166,31 @@`)

Here are the results as Area Under the Curve (AUC) of the pose error at 5/10/20 degrees:
</details>


#### Scannet-1500

Running the evaluation commands automatically downloads the dataset, which takes about 1.1 GB of disk space.

<details>
<summary>[Evaluating LightGlue]</summary>

To evaluate the pre-trained SuperPoint+LightGlue model on Scannet-1500, run:
```bash
python -m gluefactory.eval.scannet1500 --conf superpoint+lightglue-official
# or the adaptive variant
python -m gluefactory.eval.scannet1500 --conf superpoint+lightglue-official \
model.matcher.{depth_confidence=0.95,width_confidence=0.95}
```

Here are the results as Area Under the Curve (AUC) of the pose error at 5/10/20 degrees:

| Methods | [OpenCV](../gluefactory/robust_estimators/relative_pose/opencv.py) |
| ------------------------------------------------------------ | ------------------ |
| [SuperPoint + SuperGlue](../gluefactory/configs/superpoint+superglue-official.yaml) | 17.4 / 33.9 / 49.5 |
| [SuperPoint + LightGlue](../gluefactory/configs/superpoint+lightglue-official.yaml) | 17.7 / 34.6 / 51.2 |
| [ALIKED + LightGlue](../gluefactory/configs/aliked+lightglue-official.yaml) | 18.4 / 33.9 / 49.7 |

> **Review comment (Member):** TODO: add PoseLib
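The TODO above asks for PoseLib numbers. A sketch of how such a run could be launched, assuming a `poselib` estimator is registered under `gluefactory/robust_estimators/relative_pose/` (only the `opencv` estimator is referenced in this PR, so the name is an assumption) and selectable through the `eval.estimator` key:

```bash
# hypothetical: select PoseLib and sweep RANSAC thresholds (-1 picks the best one)
python -m gluefactory.eval.scannet1500 --conf superpoint+lightglue-official \
    eval.estimator=poselib eval.ransac_th=-1
```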

</details>

#### ETH3D

The dataset will be auto-downloaded if it is not found on disk, and will need about 6 GB of free disk space.
**gluefactory/configs/superpoint+lightglue-official.yaml** — 5 additions, 0 deletions (`@@ -27,3 +27,8 @@ benchmarks:`)
```yaml
    model:
      extractor:
        max_num_keypoints: 1024  # overwrite config above
  scannet1500:
    eval:
      estimator: opencv
      ransac_th: 1.0
```
**gluefactory/configs/superpoint+superglue-official.yaml** — 4 additions, 0 deletions (`@@ -23,4 +23,8 @@ benchmarks:`)
```yaml
    model:
      extractor:
        max_num_keypoints: 1024  # overwrite config above
  scannet1500:
    eval:
      estimator: opencv
      ransac_th: 1.0
```
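For orientation, a sketch of where the added block sits in each file, assuming the existing `benchmarks:` layout of these configs (the sibling benchmark names are abbreviated and may differ per file):

```yaml
benchmarks:
  # ... existing entries such as megadepth1500, hpatches ...
  scannet1500:        # new per-benchmark eval defaults added by this PR
    eval:
      estimator: opencv
      ransac_th: 1.0
```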

**gluefactory/datasets/image_pairs.py** — 12 additions, 6 deletions (`@@ -32,11 +32,16 @@ def parse_camera(calib_elems) -> Camera:`)


```diff
 def parse_relative_pose(pose_elems) -> Pose:
-    # assert len(calib_list) == 9
-    R, t = pose_elems[:9], pose_elems[9:12]
-    R = np.array([float(x) for x in R]).reshape(3, 3).astype(np.float32)
-    t = np.array([float(x) for x in t]).astype(np.float32)
-    return Pose.from_Rt(R, t)
+    if len(pose_elems) == 12:
+        R, t = pose_elems[:9], pose_elems[9:12]
+        R = np.array([float(x) for x in R]).reshape(3, 3).astype(np.float32)
+        t = np.array([float(x) for x in t]).astype(np.float32)
+        return Pose.from_Rt(R, t)
+    elif len(pose_elems) == 16:
+        T = np.array([float(x) for x in pose_elems]).reshape(4, 4).astype(np.float32)
+        return Pose.from_4x4mat(T)
+    else:
+        raise ValueError(f"Cannot interpret pose {pose_elems}.")


 class ImagePairs(BaseDataset, torch.utils.data.Dataset):
```
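The rewritten parser now accepts either a flattened 3×3 rotation plus translation (12 tokens) or a full 4×4 homogeneous matrix (16 tokens), as they would appear whitespace-separated in a pairs file. A minimal sketch with a hypothetical identity pose:

```python
from gluefactory.datasets.image_pairs import parse_relative_pose

# 12 tokens: row-major 3x3 rotation R followed by a 3-vector translation t
rt_tokens = ["1", "0", "0", "0", "1", "0", "0", "0", "1", "0", "0", "0"]
# 16 tokens: row-major 4x4 homogeneous transform [R t; 0 0 0 1]
mat_tokens = ["1", "0", "0", "0",
              "0", "1", "0", "0",
              "0", "0", "1", "0",
              "0", "0", "0", "1"]

# both describe the same pose and yield equivalent Pose objects
pose_a = parse_relative_pose(rt_tokens)
pose_b = parse_relative_pose(mat_tokens)
```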
`@@ -81,7 +86,8 @@ def __getitem__(self, idx):`
data["view1"]["camera"] = parse_camera(pair_data[11:20]).scale(
data1["scales"]
)
data["T_0to1"] = parse_relative_pose(pair_data[20:32])
data["T_0to1"] = parse_relative_pose(pair_data[20:])
data["T_1to0"] = data["T_0to1"].inv()
elif self.conf.extra_data == "homography":
data["H_0to1"] = (
data1["transform"]
**gluefactory/eval/scannet1500.py** — 190 additions, 0 deletions (new file, `@@ -0,0 +1,190 @@`)
```python
import logging
import zipfile
from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

from ..datasets import get_dataset
from ..models.cache_loader import CacheLoader
from ..settings import DATA_PATH, EVAL_PATH
from ..utils.export_predictions import export_predictions
from ..visualization.viz2d import plot_cumulative
from .eval_pipeline import EvalPipeline
from .io import get_eval_parser, load_model, parse_eval_args
from .utils import eval_matches_epipolar, eval_poses, eval_relative_pose_robust

logger = logging.getLogger(__name__)


class ScanNet1500Pipeline(EvalPipeline):
    default_conf = {
        "data": {
            "name": "image_pairs",
            "pairs": "scannet1500/pairs_calibrated.txt",
            "root": "scannet1500/",
            "extra_data": "relative_pose",
            "preprocessing": {
                "side": "long",
            },
            "num_workers": 14,
        },
        "model": {
            "ground_truth": {
                "name": None,  # remove gt matches
            }
        },
        "eval": {
            "estimator": "opencv",
            "ransac_th": 1.0,  # -1 runs a bunch of thresholds and selects the best
        },
    }

    export_keys = [
        "keypoints0",
        "keypoints1",
        "keypoint_scores0",
        "keypoint_scores1",
        "matches0",
        "matches1",
        "matching_scores0",
        "matching_scores1",
    ]
    optional_export_keys = []

    def _init(self, conf):
        if not (DATA_PATH / "scannet1500").exists():
            logger.info("Downloading the MegaDepth-1500 dataset.")
            url = "https://cvg-data.inf.ethz.ch/scannet/scannet1500.zip"
```
> **Review comment (Member), on lines +63 to +64:** Update logging and download link
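The log message above still names MegaDepth-1500. A minimal sketch of the fix the reviewer is presumably asking for; the intended new download link is not specified in this thread, so the URL is left untouched here:

```python
# hypothetical replacement for lines 63-64 inside _init:
logger.info("Downloading the ScanNet-1500 dataset.")
url = "https://cvg-data.inf.ethz.ch/scannet/scannet1500.zip"  # link may also need updating
```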

```python
            zip_path = DATA_PATH / url.rsplit("/", 1)[-1]
            zip_path.parent.mkdir(exist_ok=True, parents=True)
            torch.hub.download_url_to_file(url, zip_path)
            with zipfile.ZipFile(zip_path) as fid:
                fid.extractall(DATA_PATH)
            zip_path.unlink()

    @classmethod
    def get_dataloader(self, data_conf=None):
        """Returns a data loader with samples for each eval datapoint"""
        data_conf = data_conf if data_conf else self.default_conf["data"]
        dataset = get_dataset(data_conf["name"])(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export a prediction file for each eval datapoint"""
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Run the eval on cached predictions"""
        conf = self.conf.eval
        results = defaultdict(list)
        test_thresholds = (
            ([conf.ransac_th] if conf.ransac_th > 0 else [0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
            if not isinstance(conf.ransac_th, Iterable)
            else conf.ransac_th
        )
        pose_results = defaultdict(lambda: defaultdict(list))
        cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
        for i, data in enumerate(tqdm(loader)):
            pred = cache_loader(data)
            # add custom evaluations here
            results_i = eval_matches_epipolar(data, pred)
            for th in test_thresholds:
                pose_results_i = eval_relative_pose_robust(
                    data,
                    pred,
                    {"estimator": conf.estimator, "ransac_th": th},
                )
                [pose_results[th][k].append(v) for k, v in pose_results_i.items()]

            # we also store the names for later reference
            results_i["names"] = data["name"][0]
            if "scene" in data.keys():
                results_i["scenes"] = data["scene"][0]

            for k, v in results_i.items():
                results[k].append(v)

        # summarize results as a dict[str, float]
        # you can also add your custom evaluations here
        summaries = {}
        for k, v in results.items():
            arr = np.array(v)
            if not np.issubdtype(np.array(v).dtype, np.number):
                continue
            summaries[f"m{k}"] = round(np.mean(arr), 3)

        best_pose_results, best_th = eval_poses(
            pose_results, auc_ths=[5, 10, 20], key="rel_pose_error"
        )
        results = {**results, **pose_results[best_th]}
        summaries = {
            **summaries,
            **best_pose_results,
        }

        figures = {
            "pose_recall": plot_cumulative(
                {self.conf.eval.estimator: results["rel_pose_error"]},
                [0, 30],
                unit="°",
                title="Pose ",
            )
        }

        return summaries, figures, results


if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    dataset_name = Path(__file__).stem
    parser = get_eval_parser()
    args = parser.parse_intermixed_args()

    default_conf = OmegaConf.create(ScanNet1500Pipeline.default_conf)

    # mingle paths
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(
        dataset_name,
        args,
        "configs/",
        default_conf,
    )

    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = ScanNet1500Pipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir,
        overwrite=args.overwrite,
        overwrite_eval=args.overwrite_eval,
    )

    pprint(s)

    if args.plot:
        for name, fig in f.items():
            fig.canvas.manager.set_window_title(name)
        plt.show()
```
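As with the other eval scripts, the pipeline can be driven end to end from the command line. A sketch, assuming `get_eval_parser` exposes the `--overwrite` and `--plot` flags that the `__main__` block reads, and using `eval.ransac_th=-1` to trigger the threshold sweep described in `default_conf`:

```bash
# hypothetical end-to-end run: sweep RANSAC thresholds and show the recall plot
python -m gluefactory.eval.scannet1500 --conf superpoint+lightglue-official \
    eval.ransac_th=-1 --plot
```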