Skip to content

Commit

Permalink
完成yolov4 tiny
Browse files Browse the repository at this point in the history
  • Loading branch information
Steven committed Aug 27, 2024
1 parent 0b5538c commit c532847
Show file tree
Hide file tree
Showing 10 changed files with 1,524 additions and 28 deletions.
19 changes: 14 additions & 5 deletions Note.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,27 @@
- 安裝albumentations可以試看看不要移除opencv,因為移除會讓opencv的RGB2BGR故障
- albumentations >= 1.14.11 要直接手動幫mmdetection寫補丁https://github.com/open-mmlab/mmdetection/pull/11870/files

* 訓練指令
# 常用指令

- 訓練指令
python tools/train.py configs/custom_dataset/yolov4_l_mish_1xb16-300e_5car.py

* inference 指令
- inference 指令
configs/custom_dataset/yolov4_l_mish_1xb16-300e_5car.py weights/last.pt

configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py weights/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth
- 列印config
python tools/misc/print_config.py configs/custom_dataset/yolov4_l_mish_1xb16-300e_5car.py --save-path /home/ai_server/2005013/mmyolo/mmyolo/my_exp/check_config.json

- 切割訓練集
python tools/misc/coco_split.py --json data/eastcoast_5car/annotations/result.json \
--out-dir ./data/debug_eastcoast_5car/annotations \
--ratios 0.99 0.005 0.005\
--shuffle

- 命名規則
* 命名規則
https://mmyolo.readthedocs.io/zh-cn/latest/tutorials/config.html?highlight=syncbn#id13

- Hook說明
* Hook說明
https://mmengine.readthedocs.io/zh-cn/latest/design/hook.html

```python
Expand Down
15 changes: 13 additions & 2 deletions configs/custom_dataset/yolov4_l_mish_1xb16-300e_5car.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,26 @@
]

# 根据自己的 GPU 情况,修改 batch size,YOLOv5-s 默认为 8卡 x 16bs
train_batch_size_per_gpu = 16
train_batch_size_per_gpu = 64
train_num_workers = 4 # 推荐使用 train_num_workers = nGPU x 4

save_epoch_intervals = 4 # 每 interval 轮迭代进行一次保存一次权重


# 根据自己的 GPU 情况,修改 base_lr,修改的比例是 base_lr_default * (your_bs / default_bs)
base_lr = _base_.base_lr * train_batch_size_per_gpu / (8 * 16)

model = dict(
backbone=dict(frozen_stages=4),
bbox_head=dict(
head_module=dict(num_classes=num_classes),
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=_base_.loss_cls_weight *
(num_classes / 80 * 3 / _base_.num_det_layers))),
)

train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
Expand Down
283 changes: 283 additions & 0 deletions configs/yolov4/yolov4_tiny_1x16b-300e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']

# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/' # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/' # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/' # Prefix of val image path

num_classes = 80 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True

# -----model related-----
# Basic size of multi-scale prior box
anchors = [[(10, 14), (23, 27), (37, 58)], [(81, 82), (135, 169), (344, 319)]]

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300 # Maximum training epochs

model_test_cfg = dict(
# The config of multi-label for multi-class prediction.
multi_label=True,
# The number of boxes before NMS
nms_pre=30000,
score_thr=0.001, # Threshold to filter out boxes.
nms=dict(type='nms', iou_threshold=0.65), # NMS type and threshold
max_per_img=300) # Max number of detections of each image

# ========================Possible modified parameters========================
# -----data related-----
img_scale = (416, 416) # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2

# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
# The image scale of padding should be divided by pad_size_divisor
size_divisor=32,
# Additional paddings for pixel scale
extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 1.0
# The scaling factor that controls the width of the network structure
widen_factor = 1.0
# Strides of multi-scale prior box
strides = [8, 16]
num_det_layers = 3 # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001) # Normalization config

# -----train val related-----
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4. # Priori box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1.]
lr_factor = 0.01 # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

#randomness = dict(seed=0)
# ===============================Unmodified in most cases====================
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='mmdet.DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
backbone=dict(
type='YOLOv4TinyBackbone',
norm_cfg=norm_cfg,
act_cfg=dict(type='LeakyReLU', inplace=True)),
neck=dict(
type='V4TinyUpSampleBlock',
in_channels=[512, 256],
out_channels=[256, 128],
norm_cfg=norm_cfg,
act_cfg=dict(type='LeakyReLU', inplace=True)),
bbox_head=dict(
type='YOLOv3Head',
head_module=dict(
type='YOLOv3HeadModule',
num_classes=num_classes,
in_channels=[256, 384],
featmap_strides=strides,
num_base_priors=3),
prior_generator=dict(
type='mmdet.YOLOAnchorGenerator',
base_sizes=anchors,
strides=strides),
# scaled based on number of detection layers
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=loss_cls_weight *
(num_classes / 80 * 3 / num_det_layers)),
loss_bbox=dict(
type='IoULoss',
iou_mode='ciou',
bbox_format='xywh',
eps=1e-7,
reduction='mean',
loss_weight=loss_bbox_weight * (3 / num_det_layers),
return_iou=True),
loss_obj=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=loss_obj_weight *
((img_scale[0] / 640)**2 * 3 / num_det_layers)),
prior_match_thr=prior_match_thr,
obj_level_weights=obj_level_weights),
test_cfg=model_test_cfg)

albu_train_transforms = [
dict(type='Blur', p=0.01),
dict(type='MedianBlur', p=0.01),
dict(type='ToGray', p=0.01),
dict(type='CLAHE', p=0.01)
]

pre_transform = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='LoadAnnotations', with_bbox=True)
]

train_pipeline = [
*pre_transform,
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
dict(
type='mmdet.Albu',
transforms=albu_train_transforms,
bbox_params=dict(
type='BboxParams',
format='pascal_voc',
label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
keymap={
'img': 'image',
'gt_bboxes': 'bboxes'
}),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]

train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=train_ann_file,
data_prefix=dict(img=train_data_prefix),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline))

test_pipeline = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]

val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(img=val_data_prefix),
ann_file=val_ann_file,
pipeline=test_pipeline,
batch_shapes_cfg=batch_shapes_cfg))

test_dataloader = val_dataloader

param_scheduler = None
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=base_lr,
momentum=0.937,
weight_decay=weight_decay,
nesterov=True,
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOv5OptimizerConstructor')

default_hooks = dict(
param_scheduler=dict(
type='YOLOv5ParamSchedulerHook',
scheduler_type='linear',
lr_factor=lr_factor,
max_epochs=max_epochs),
checkpoint=dict(
type='CheckpointHook',
interval=save_checkpoint_intervals,
save_best='auto',
max_keep_ckpts=max_keep_ckpts))

custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
update_buffers=True,
strict_load=False,
priority=49)
]

val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10),
ann_file=data_root + val_ann_file,
metric='bbox')
test_evaluator = val_evaluator

train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
6 changes: 3 additions & 3 deletions mmyolo/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_backbone import BaseBackbone
from .csp_darknet import (YOLOv4CSPDarknet, YOLOv5CSPDarknet, YOLOv8CSPDarknet,
YOLOXCSPDarknet)
from .csp_darknet import (YOLOv4CSPDarknet, YOLOv4TinyBackbone,
YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet)
from .csp_resnet import PPYOLOECSPResNet
from .cspnext import CSPNeXt
from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
Expand All @@ -10,5 +10,5 @@
__all__ = [
'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
'YOLOv8CSPDarknet', 'YOLOv4CSPDarknet'
'YOLOv8CSPDarknet', 'YOLOv4CSPDarknet', 'YOLOv4TinyBackbone'
]
Loading

0 comments on commit c532847

Please sign in to comment.