Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Add phase for test.py #625

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion mmyolo/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import urllib
from copy import deepcopy
from typing import List, Union

import numpy as np
import torch
from mmengine.config import ConfigDict
from mmengine.utils import scandir
from prettytable import PrettyTable

Expand Down Expand Up @@ -55,7 +58,7 @@ def auto_arrange_images(image_list: list, image_column: int = 2) -> np.ndarray:
return image_show


def get_file_list(source_root: str) -> [list, dict]:
def get_file_list(source_root: str) -> Union[list, dict]:
"""Get file list.

Args:
Expand Down Expand Up @@ -131,3 +134,33 @@ def judge_keys(dataloader_cfg):
judge_keys(cfg.get('train_dataloader', {}))
judge_keys(cfg.get('val_dataloader', {}))
judge_keys(cfg.get('test_dataloader', {}))


def get_cfg_class_name(cfg_type: str):
return cfg_type.split('.')[-1]


def convert_to_val_pipeline(data_cfg: ConfigDict,
val_pipeline: List) -> ConfigDict:

if get_cfg_class_name(data_cfg.get('type')) in [
'MultiImageMixDataset', 'ClassBalancedDataset'
]:
# While evaluation, there shouldn't use multi image or oversample.
data_cfg = deepcopy(data_cfg.get('dataset'))
data_cfg['pipeline'] = val_pipeline
elif 'dataset' in data_cfg:
data = deepcopy(data_cfg.get('dataset'))
if isinstance(data, list):
# concat dataset
data_cfg['dataset'] = [
convert_to_val_pipeline(d, val_pipeline) for d in data
]
else:
# Deal with other dataset wrapper
data_cfg['dataset'] = convert_to_val_pipeline(data, val_pipeline)
else:
# Regular dataset
data_cfg['pipeline'] = val_pipeline
data_cfg['test_mode'] = True
return data_cfg
34 changes: 33 additions & 1 deletion tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import os
import os.path as osp
import warnings

from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmengine.config import Config, ConfigDict, DictAction
Expand All @@ -10,6 +11,7 @@

from mmyolo.registry import RUNNERS
from mmyolo.utils import is_metainfo_lower
from mmyolo.utils.misc import convert_to_val_pipeline


# TODO: support fuse_conv_bn
Expand All @@ -35,6 +37,13 @@ def parse_args():
'--tta',
action='store_true',
help='Whether to use test time augmentation')
parser.add_argument(
'--phase',
default='test',
type=str,
choices=['train', 'test', 'val'],
help='phase of dataset to test, accept "train" "test" and "val". '
'Defaults to "test".')
parser.add_argument(
'--show', action='store_true', help='show prediction results')
parser.add_argument(
Expand Down Expand Up @@ -109,14 +118,37 @@ def main():
# Determine whether the custom metainfo fields are all lowercase
is_metainfo_lower(cfg)

#
if args.phase == 'train':
# If test on train phase, it will use val pipline
val_data_cfg = cfg.val_dataloader.dataset
while 'dataset' in val_data_cfg:
val_data_cfg = val_data_cfg['dataset']
val_pipeline = val_data_cfg.pipeline
train_dataset = cfg.train_dataloader.dataset
test_data_cfg = convert_to_val_pipeline(train_dataset, val_pipeline)
cfg.test_dataloader.dataset = test_data_cfg
evaluator_cfg = cfg.val_evaluator
if evaluator_cfg.get('ann_file') is not None:
evaluator_cfg.pop('ann_file')
warnings.warn('When use train phase for test, `ann_file` will '
'be removed to use loaded annotation directly.')
cfg.test_evaluator = cfg.val_evaluator
elif args.phase == 'val':
test_data_cfg = cfg.val_dataloader.dataset
cfg.test_dataloader.dataset = cfg.val_dataloader.dataset
cfg.test_evaluator = cfg.val_evaluator
elif args.phase == 'test':
test_data_cfg = cfg.test_dataloader.dataset

if args.tta:
assert 'tta_model' in cfg, 'Cannot find ``tta_model`` in config.' \
" Can't use tta !"
assert 'tta_pipeline' in cfg, 'Cannot find ``tta_pipeline`` ' \
"in config. Can't use tta !"

cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
test_data_cfg = cfg.test_dataloader.dataset
# test_data_cfg = cfg.test_dataloader.dataset
while 'dataset' in test_data_cfg:
test_data_cfg = test_data_cfg['dataset']

Expand Down