[Feature] Add autoanchor for YOLO series #654

Open · wants to merge 20 commits into base: dev
23 changes: 23 additions & 0 deletions configs/_base_/autoanchor.py
@@ -0,0 +1,23 @@
k_means_autoanchor = dict(
    type='YOLOAutoAnchorHook',
    optimizer=dict(
        type='YOLOKMeansAnchorOptimizer',
        iters=1000,
        num_anchor_per_level=[3, 3, 3]))

de_autoanchor = dict(
    type='YOLOAutoAnchorHook',
    optimizer=dict(
        type='YOLODEAnchorOptimizer',
        iters=1000,
        num_anchor_per_level=[3, 3, 3]))

v5_k_means_autoanchor = dict(
    type='YOLOAutoAnchorHook',
    optimizer=dict(
        type='YOLOV5KMeansAnchorOptimizer',
        iters=1000,
        num_anchor_per_level=[3, 3, 3],
        prior_match_thr=4.0,
        mutation_args=[0.9, 0.1],
        augment_args=[0.9, 0.1]))
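All three variants share the same hook and differ only in the anchor optimizer. For readers unfamiliar with the technique, here is a minimal, self-contained sketch of the idea behind k-means anchor clustering (clustering box width/height pairs with a shape-only IoU metric); it is an illustration of the general approach, not the PR's YOLOKMeansAnchorOptimizer:

    import numpy as np

    def kmeans_anchors(boxes_wh, num_anchors, iters=1000, seed=0):
        """Cluster (width, height) pairs into `num_anchors` anchor shapes."""
        rng = np.random.default_rng(seed)
        # Initialize centers from randomly chosen boxes.
        centers = boxes_wh[rng.choice(len(boxes_wh), num_anchors, replace=False)]
        for _ in range(iters):
            # Shape-only IoU between every box and every center.
            inter = np.minimum(boxes_wh[:, None, :], centers[None, :, :]).prod(-1)
            union = boxes_wh.prod(-1)[:, None] + centers.prod(-1)[None, :] - inter
            assign = (inter / union).argmax(1)  # nearest center by IoU
            new_centers = np.array([
                boxes_wh[assign == k].mean(0) if (assign == k).any() else centers[k]
                for k in range(num_anchors)])
            if np.allclose(new_centers, centers):
                break
            centers = new_centers
        return centers[np.argsort(centers.prod(-1))]  # sort by area

    # Example: 9 anchors from 500 random (w, h) pairs on a 640px scale.
    anchors = kmeans_anchors(np.random.rand(500, 2) * 640, num_anchors=9)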
5 changes: 4 additions & 1 deletion configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py
@@ -1,4 +1,7 @@
-_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']
+_base_ = [
+    '../_base_/default_runtime.py', '../_base_/det_p5_tta.py',
+    '../_base_/autoanchor.py'
+]

# ========================Frequently modified parameters======================
# -----data related-----
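This excerpt does not show how the autoanchor dicts are attached to the runner. In mmengine-based configs, hooks defined this way are usually registered through custom_hooks; enabling the k-means variant could look like the sketch below, which reuses the dict from configs/_base_/autoanchor.py verbatim:

    custom_hooks = [
        dict(
            type='YOLOAutoAnchorHook',
            optimizer=dict(
                type='YOLOKMeansAnchorOptimizer',
                iters=1000,
                num_anchor_per_level=[3, 3, 3])),
    ]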
3 changes: 2 additions & 1 deletion mmyolo/engine/hooks/__init__.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .ppyoloe_param_scheduler_hook import PPYOLOEParamSchedulerHook
 from .switch_to_deploy_hook import SwitchToDeployHook
+from .yolo_auto_anchor_hook import YOLOAutoAnchorHook
 from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
 from .yolox_mode_switch_hook import YOLOXModeSwitchHook

 __all__ = [
     'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook',
-    'PPYOLOEParamSchedulerHook'
+    'PPYOLOEParamSchedulerHook', 'YOLOAutoAnchorHook'
 ]
96 changes: 96 additions & 0 deletions mmyolo/engine/hooks/yolo_auto_anchor_hook.py
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.dist import broadcast, get_dist_info
from mmengine.hooks import Hook
from mmengine.logging import MMLogger
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner

from mmyolo.registry import HOOKS, TASK_UTILS


@HOOKS.register_module()
class YOLOAutoAnchorHook(Hook):

    priority = 48

    # YOLOAutoAnchorHook takes priority over EMAHook.

Collaborator: Since we have this requirement, be sure to print a message when this hook is initialized, stating that its priority must be higher than EMAHook's, to give users more information.

    def __init__(self, optimizer):
        self.optimizer = optimizer
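One way to address the review comment above, as a sketch (the message wording is illustrative, not from the PR):

    def __init__(self, optimizer):
        logger = MMLogger.get_current_instance()
        # Surface the ordering requirement so misconfigured priorities are
        # easy to diagnose.
        logger.info('YOLOAutoAnchorHook has priority 48; it must stay higher '
                    'than the priority of EMAHook.')
        self.optimizer = optimizer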

    def before_run(self, runner) -> None:
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        device = next(model.parameters()).device
        anchors = torch.tensor(
            runner.cfg.model.bbox_head.prior_generator.base_sizes,
            device=device)
        model.register_buffer('anchors', anchors)

Collaborator: It would be better to read these values from the model's internal attributes rather than from the config. Also, this should be written more robustly.

Contributor Author (yechenzhi): Can we just use model.bbox_head.prior_generator.base_sizes directly?
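The change the author proposes would look roughly like this sketch (it assumes the built prior generator exposes base_sizes, as in the attribute the author quotes):

        base_sizes = model.bbox_head.prior_generator.base_sizes
        anchors = torch.tensor(base_sizes, device=device)
        model.register_buffer('anchors', anchors)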

    def before_train(self, runner: Runner) -> None:
        if runner.iter > 0:
            return
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        print('begin reloading optimized anchors')

        rank, _ = get_dist_info()

        weights = model.state_dict()
        anchors_tensor = weights['anchors']
        if rank == 0 and not runner._has_loaded:
            runner_dataset = runner.train_dataloader.dataset
            self.optimizer.update(
                dataset=runner_dataset,
                device=runner_dataset[0]['inputs'].device,
                input_shape=runner.cfg['img_scale'],
                logger=MMLogger.get_current_instance())

            optimizer = TASK_UTILS.build(self.optimizer)
            anchors = optimizer.optimize()
            device = next(model.parameters()).device
            anchors_tensor = torch.tensor(anchors, device=device)

        broadcast(anchors_tensor)
        weights['anchors'] = anchors_tensor
        model.load_state_dict(weights)

        self.reinitialize(runner, model)
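The method above follows a common distributed pattern: only rank 0 runs the expensive anchor optimization, and broadcast() then synchronizes the result so every process loads identical weights. In miniature (run_anchor_optimization is a hypothetical placeholder):

    import torch
    from mmengine.dist import broadcast, get_dist_info

    rank, _ = get_dist_info()
    anchors = torch.zeros(3, 3, 2)           # same shape and dtype on every rank
    if rank == 0:
        anchors = run_anchor_optimization()  # hypothetical rank-0-only work
    broadcast(anchors, src=0)                # all ranks now hold rank 0's tensor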

    def before_val(self, runner: Runner) -> None:
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        print('begin reloading optimized anchors')
        self.reinitialize(runner, model)

    def before_test(self, runner: Runner) -> None:
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        print('begin reloading optimized anchors')
        self.reinitialize(runner, model)

    def reinitialize(self, runner: Runner, model) -> None:
        anchors_tensor = model.state_dict()['anchors']
        base_sizes = anchors_tensor.tolist()
        device = anchors_tensor.device
        prior_generator = runner.cfg.model.bbox_head.prior_generator
        prior_generator.update(base_sizes=base_sizes)

        model.bbox_head.prior_generator = TASK_UTILS.build(prior_generator)

        priors_base_sizes = torch.tensor(
            base_sizes, dtype=torch.float, device=device)
        featmap_strides = torch.tensor(
            model.bbox_head.featmap_strides, dtype=torch.float,
            device=device)[:, None, None]
        model.bbox_head.priors_base_sizes = priors_base_sizes / featmap_strides
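A quick shape check of the final division, using the classic YOLO COCO anchors as illustrative values (not taken from this PR): priors_base_sizes has shape (3, 3, 2) for 3 levels of 3 anchors each, and featmap_strides[:, None, None] has shape (3, 1, 1), so broadcasting divides each level's anchors by that level's stride.

    import torch

    priors_base_sizes = torch.tensor([[[10., 13.], [16., 30.], [33., 23.]],
                                      [[30., 61.], [62., 45.], [59., 119.]],
                                      [[116., 90.], [156., 198.], [373., 326.]]])
    strides = torch.tensor([8., 16., 32.])[:, None, None]  # shape (3, 1, 1)
    print((priors_base_sizes / strides).shape)  # torch.Size([3, 3, 2])
    print((priors_base_sizes / strides)[0])     # level-0 anchors in stride units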
6 changes: 5 additions & 1 deletion mmyolo/utils/__init__.py
@@ -1,9 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .anchor_optimizers import (YOLODEAnchorOptimizer,
+                                YOLOKMeansAnchorOptimizer,
+                                YOLOV5KMeansAnchorOptimizer)
 from .collect_env import collect_env
 from .misc import is_metainfo_lower, switch_to_deploy
 from .setup_env import register_all_modules

 __all__ = [
     'register_all_modules', 'collect_env', 'switch_to_deploy',
-    'is_metainfo_lower'
+    'is_metainfo_lower', 'YOLOKMeansAnchorOptimizer',
+    'YOLOV5KMeansAnchorOptimizer', 'YOLODEAnchorOptimizer'
 ]
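Based on how YOLOAutoAnchorHook drives these classes above (update the config dict with runtime arguments, build it through TASK_UTILS, then call optimize()), standalone use could look like this sketch; train_dataset and the device string are placeholders:

    from mmengine.logging import MMLogger
    from mmyolo.registry import TASK_UTILS

    optimizer_cfg = dict(
        type='YOLOKMeansAnchorOptimizer',
        iters=1000,
        num_anchor_per_level=[3, 3, 3],
        dataset=train_dataset,        # placeholder: your training dataset
        device='cuda:0',              # placeholder
        input_shape=(640, 640),       # the hook passes runner.cfg['img_scale']
        logger=MMLogger.get_current_instance())
    anchors = TASK_UTILS.build(optimizer_cfg).optimize()  # optimized anchors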