Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add autoanchor for YOLO series #654

Open
wants to merge 20 commits into
base: dev
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_base_ = '../yolov5_s-v61_syncbn_8xb16-300e_coco.py'

model = dict(
bbox_head=dict(prior_generator=dict(type='YOLOAutoAnchorGenerator')))

custom_hooks = [
dict(
type='YOLOAutoAnchorHook',
optimizer=dict(
type='YOLOKMeansAnchorOptimizer',
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
iters=1000,
num_anchor_per_level=[3, 3, 3]))
]
2 changes: 2 additions & 0 deletions mmyolo/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .anchor import * # noqa: F401, F403
12 changes: 12 additions & 0 deletions mmyolo/core/anchor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .anchor_generator import YOLOAutoAnchorGenerator
from .anchor_optimizer import (YOLODEAnchorOptimizer,
YOLOKMeansAnchorOptimizer,
YOLOV5KMeansAnchorOptimizer)

__all__ = [
'YOLOAutoAnchorGenerator',
'YOLOKMeansAnchorOptimizer',
'YOLOV5KMeansAnchorOptimizer',
'YOLODEAnchorOptimizer',
]
40 changes: 40 additions & 0 deletions mmyolo/core/anchor/anchor_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmdet.models.task_modules import YOLOAnchorGenerator
from torch.nn.modules.utils import _pair

from mmyolo.registry import TASK_UTILS


@TASK_UTILS.register_module()
class YOLOAutoAnchorGenerator(nn.Module, YOLOAnchorGenerator):
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
"""AutoAnchor generator for YOLO.

Args:
strides (list[int] | list[tuple[int, int]]): Strides of anchors
in multiple feature levels.
base_sizes (list[list[tuple[int, int]]]): The basic sizes
of anchors in multiple levels.
"""

def __init__(self, strides, base_sizes, use_box_type: bool = False):
super().__init__()
self.strides = [_pair(stride) for stride in strides]
self.centers = [(stride[0] / 2., stride[1] / 2.)
for stride in self.strides]
self.use_box_type = use_box_type
self.register_buffer('anchors', torch.tensor(base_sizes))

@property
def base_sizes(self):
T = []
num_anchor_per_level = len(self.anchors[0])
for base_sizes_per_level in self.anchors:
assert num_anchor_per_level == len(base_sizes_per_level)
T.append([_pair(base_size) for base_size in base_sizes_per_level])
return T

@property
def base_anchors(self):
return self.gen_base_anchors()
Loading