diff --git a/mmyolo/utils/misc.py b/mmyolo/utils/misc.py index c90f52b94..f0e4d3863 100644 --- a/mmyolo/utils/misc.py +++ b/mmyolo/utils/misc.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import os import urllib +from copy import deepcopy +from typing import List, Union import numpy as np import torch +from mmengine.config import ConfigDict from mmengine.utils import scandir from prettytable import PrettyTable @@ -55,7 +58,7 @@ def auto_arrange_images(image_list: list, image_column: int = 2) -> np.ndarray: return image_show -def get_file_list(source_root: str) -> [list, dict]: +def get_file_list(source_root: str) -> Union[list, dict]: """Get file list. Args: @@ -131,3 +134,33 @@ def judge_keys(dataloader_cfg): judge_keys(cfg.get('train_dataloader', {})) judge_keys(cfg.get('val_dataloader', {})) judge_keys(cfg.get('test_dataloader', {})) + + +def get_cfg_class_name(cfg_type: str): + return cfg_type.split('.')[-1] + + +def convert_to_val_pipeline(data_cfg: ConfigDict, + val_pipeline: List) -> ConfigDict: + + if get_cfg_class_name(data_cfg.get('type')) in [ + 'MultiImageMixDataset', 'ClassBalancedDataset' + ]: + # While evaluation, there shouldn't use multi image or oversample. + data_cfg = deepcopy(data_cfg.get('dataset')) + data_cfg['pipeline'] = val_pipeline + elif 'dataset' in data_cfg: + data = deepcopy(data_cfg.get('dataset')) + if isinstance(data, list): + # concat dataset + data_cfg['dataset'] = [ + convert_to_val_pipeline(d, val_pipeline) for d in data + ] + else: + # Deal with other dataset wrapper + data_cfg['dataset'] = convert_to_val_pipeline(data, val_pipeline) + else: + # Regular dataset + data_cfg['pipeline'] = val_pipeline + data_cfg['test_mode'] = True + return data_cfg diff --git a/tools/test.py b/tools/test.py index c05defe3c..91f499e5b 100644 --- a/tools/test.py +++ b/tools/test.py @@ -2,6 +2,7 @@ import argparse import os import os.path as osp +import warnings from mmdet.engine.hooks.utils import trigger_visualization_hook from mmengine.config import Config, ConfigDict, DictAction @@ -10,6 +11,7 @@ from mmyolo.registry import RUNNERS from mmyolo.utils import is_metainfo_lower +from mmyolo.utils.misc import convert_to_val_pipeline # TODO: support fuse_conv_bn @@ -35,6 +37,13 @@ def parse_args(): '--tta', action='store_true', help='Whether to use test time augmentation') + parser.add_argument( + '--phase', + default='test', + type=str, + choices=['train', 'test', 'val'], + help='phase of dataset to test, accept "train" "test" and "val". ' + 'Defaults to "test".') parser.add_argument( '--show', action='store_true', help='show prediction results') parser.add_argument( @@ -109,6 +118,29 @@ def main(): # Determine whether the custom metainfo fields are all lowercase is_metainfo_lower(cfg) + # + if args.phase == 'train': + # If test on train phase, it will use val pipline + val_data_cfg = cfg.val_dataloader.dataset + while 'dataset' in val_data_cfg: + val_data_cfg = val_data_cfg['dataset'] + val_pipeline = val_data_cfg.pipeline + train_dataset = cfg.train_dataloader.dataset + test_data_cfg = convert_to_val_pipeline(train_dataset, val_pipeline) + cfg.test_dataloader.dataset = test_data_cfg + evaluator_cfg = cfg.val_evaluator + if evaluator_cfg.get('ann_file') is not None: + evaluator_cfg.pop('ann_file') + warnings.warn('When use train phase for test, `ann_file` will ' + 'be removed to use loaded annotation directly.') + cfg.test_evaluator = cfg.val_evaluator + elif args.phase == 'val': + test_data_cfg = cfg.val_dataloader.dataset + cfg.test_dataloader.dataset = cfg.val_dataloader.dataset + cfg.test_evaluator = cfg.val_evaluator + elif args.phase == 'test': + test_data_cfg = cfg.test_dataloader.dataset + if args.tta: assert 'tta_model' in cfg, 'Cannot find ``tta_model`` in config.' \ " Can't use tta !" @@ -116,7 +148,7 @@ def main(): "in config. Can't use tta !" cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model) - test_data_cfg = cfg.test_dataloader.dataset + # test_data_cfg = cfg.test_dataloader.dataset while 'dataset' in test_data_cfg: test_data_cfg = test_data_cfg['dataset']