Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonlessons committed May 24, 2024
2 parents 4bc9edc + 733e773 commit f303345
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 362 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
## [1.2.5] - 2024-05-04
### Added
- Added exception in `mltu.dataProvider.DataProvider` to raise ValueError when dataset is not iterable
- Added custom training code for YoloV8 object detector: `Tutorials\11_Yolov8\train_yolov8.py`
- Added custom trained inference code for YoloV8 object detector:`Tutorials\11_Yolov8\test_yolov8.py`

### Changed
- Fixed `RandomElasticTransform` in `mltu.augmentors` to handle elastic transformation not to exceed image boundaries
- Modified `YoloPreprocessor` in `mltu.torch.yolo.preprocessors` to output dictionary with np.arrays istead of lists


## [1.2.4] - 2024-03-21
### Added
- Added `RandomElasticTransform` to `mltu.augmentors` to work with `Image` objects
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ Each tutorial has its own requirements.txt file for a specific mltu version. As
8. [Handwriting words recognition with PyTorch](https://pylessons.com/handwriting-recognition-pytorch), code in ```Tutorials\08_handwriting_recognition_torch``` folder;
9. [Transformer training with TensorFlow for Translation task](https://pylessons.com/transformers-training), code in ```Tutorials\09_translation_transformer``` folder;
10. [Speech Recognition in Python | finetune wav2vec2 model for a custom ASR model](https://youtu.be/h6ooEGzjkj0), code in ```Tutorials\10_wav2vec2_torch``` folder;
11. [YOLOv8: Real-Time Object Detection Simplified](https://youtu.be/vegL__weCxY), code in ```Tutorials\11_Yolov8``` folder;
11. [YOLOv8: Real-Time Object Detection Simplified](https://youtu.be/vegL__weCxY), code in ```Tutorials\11_Yolov8``` folder;
12. [YOLOv8: Customizing Object Detector training](https://youtu.be/ysYiV1CbCyY), code in ```Tutorials\11_Yolov8\train_yolov8.py``` folder;
176 changes: 174 additions & 2 deletions Tutorials/11_Yolov8/README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Run Ultralytics YOLOv8 pretrained model

YouTube tutorial link: [YOLOv8: Real-Time Object Detection Simplified](https://youtu.be/vegL__weCxY)
YouTube tutorial link:
- [YOLOv8: Real-Time Object Detection Simplified](https://youtu.be/vegL__weCxY);
- [YOLOv8: Customizing Object Detector training](https://youtu.be/ysYiV1CbCyY);

First, I recommend you to install the required packages in a virtual environment:
```bash
mltu==1.2.3
mltu==1.2.5
ultralytics==8.1.28
torch==2.0.0
torchvision==0.15.1
Expand Down Expand Up @@ -134,5 +136,175 @@ while True:
break

cap.release()
cv2.destroyAllWindows()
```

## Customize YoloV8 Object Detector training:
```python
import os
import time
import torch
from mltu.preprocessors import ImageReader
from mltu.annotations.images import CVImage
from mltu.transformers import ImageResizer, ImageShowCV2, ImageNormalizer
from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen, \
RandomMirror, RandomFlip, RandomGaussianBlur, RandomSaltAndPepper, RandomDropBlock, RandomMosaic, RandomElasticTransform
from mltu.torch.model import Model
from mltu.torch.dataProvider import DataProvider
from mltu.torch.yolo.annotation import VOCAnnotationReader
from mltu.torch.yolo.preprocessors import YoloPreprocessor
from mltu.torch.yolo.loss import v8DetectionLoss
from mltu.torch.yolo.metrics import YoloMetrics
from mltu.torch.yolo.optimizer import build_optimizer, AccumulativeOptimizer
from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx, WarmupCosineDecay

from ultralytics.nn.tasks import DetectionModel
from ultralytics.engine.model import Model as BaseModel

# https://www.kaggle.com/datasets/andrewmvd/car-plate-detection
annotations_path = "Datasets/car-plate-detection/annotations"

# Create a dataset from the annotations, the dataset is a list of lists where each list contains the [image path, annotation path]
dataset = [[None, os.path.join(annotations_path, f)] for f in os.listdir(annotations_path)]

# Make sure torch can see GPU device, it is not recommended to train with CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

img_size = 416
labels = {0: "licence"}

# Create a data provider for the dataset
data_provider = DataProvider(
dataset=dataset,
skip_validation=True,
batch_size=16,
data_preprocessors=[
VOCAnnotationReader(labels=labels),
ImageReader(CVImage),
],
transformers=[
# ImageShowCV2(),
ImageResizer(img_size, img_size),
ImageNormalizer(transpose_axis=True),
],
batch_postprocessors=[
YoloPreprocessor(device, img_size)
],
numpy=False,
)

# split the dataset into train and test
train_data_provider, val_data_provider = data_provider.split(0.9, shuffle=False)

# Attaach augmentation to the train data provider
train_data_provider.augmentors = [
RandomBrightness(),
RandomErodeDilate(),
RandomSharpen(),
RandomMirror(),
RandomFlip(),
RandomElasticTransform(),
RandomGaussianBlur(),
RandomSaltAndPepper(),
RandomRotate(angle=10),
RandomDropBlock(),
RandomMosaic(),
]

base_model = BaseModel("yolov8n.pt")
# Create a YOLO model
model = DetectionModel('yolov8n.yaml', nc=len(labels))

# Load the weight from base model
try: model.load_state_dict(base_model.model.state_dict(), strict=False)
except: pass

model.to(device)

for k, v in model.named_parameters():
if any(x in k for x in [".dfl"]):
print("freezing", k)
v.requires_grad = False
elif not v.requires_grad:
v.requires_grad = True

lr = 1e-3
optimizer = build_optimizer(model.model, name="AdamW", lr=lr, weight_decay=0.0, momentum=0.937, decay=0.0005)
optimizer = AccumulativeOptimizer(optimizer, 16, 64)

# create model object that will handle training and testing of the network
model = Model(
model,
optimizer,
v8DetectionLoss(model),
metrics=[YoloMetrics(nc=len(labels))],
log_errors=False,
output_path=f"Models/11_Yolov8/{int(time.time())}",
clip_grad_norm=10.0,
ema=True,
)

modelCheckpoint = ModelCheckpoint(monitor="val_fitness", mode="max", save_best_only=True, verbose=True)
tensorBoard = TensorBoard()
earlyStopping = EarlyStopping(monitor="val_fitness", mode="max", patience=31, verbose=True)
model2onnx = Model2onnx(input_shape=(1, 3, img_size, img_size), verbose=True, opset_version=14,
dynamic_axes = {"input": {0: "batch_size", 2: "height", 3: "width"},
"output": {0: "batch_size", 2: "anchors"}},
metadata={"classes": labels})
warmupCosineDecayBias = WarmupCosineDecay(lr_after_warmup=lr, final_lr=lr, initial_lr=0.1,
warmup_steps=len(train_data_provider), warmup_epochs=10, ignore_param_groups=[1, 2]) # lr0
warmupCosineDecay = WarmupCosineDecay(lr_after_warmup=lr, final_lr=lr/10, initial_lr=1e-7,
warmup_steps=len(train_data_provider), warmup_epochs=10, decay_epochs=190, ignore_param_groups=[0]) # lr1 and lr2

# Train the model
history = model.fit(
train_data_provider,
test_dataProvider=val_data_provider,
epochs=200,
callbacks=[
modelCheckpoint,
tensorBoard,
earlyStopping,
model2onnx,
warmupCosineDecayBias,
warmupCosineDecay
]
)
```

## Test Custom trained YoloV8 Object Detector:
```python
import os
import cv2
from mltu.annotations.detections import Detections
from mltu.torch.yolo.detectors.onnx_detector import Detector as OnnxDetector

# https://www.kaggle.com/datasets/andrewmvd/car-plate-detection
images_path = "Datasets/car-plate-detection/images"

input_width, input_height = 416, 416
confidence_threshold = 0.5
iou_threshold = 0.5

detector = OnnxDetector("Models/11_Yolov8/1714135287/model.onnx", input_width, input_height, confidence_threshold, iou_threshold, force_cpu=False)

for image_path in os.listdir(images_path):

frame = cv2.imread(os.path.join(images_path, image_path))

# Perform Yolo object detection
detections: Detections = detector(frame)

# Apply the detections to the frame
frame = detections.applyToFrame(frame)

# Print the FPS
print(detector.fps)

# Display the output image
cv2.imshow("Object Detection", frame)
if cv2.waitKey(0) & 0xFF == ord('q'):
break

cv2.destroyAllWindows()
```
2 changes: 1 addition & 1 deletion Tutorials/11_Yolov8/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mltu==1.2.3
mltu==1.2.5
ultralytics==8.1.28
torch==2.0.0
torchvision==0.15.1
Expand Down
3 changes: 2 additions & 1 deletion Tutorials/11_Yolov8/run_pretrained.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import cv2
from ultralytics.engine.model import Model as BaseModel
from mltu.annotations.detections import Detections
from mltu.torch.yolo.detectors.torch_detector import Detector as TorchDetector
from mltu.torch.yolo.detectors.onnx_detector import Detector as OnnxDetector

Expand All @@ -18,7 +19,7 @@
break

# Perform Yolo object detection
detections = detector(frame)
detections: Detections = detector(frame)

# Apply the detections to the frame
frame = detections.applyToFrame(frame)
Expand Down
33 changes: 33 additions & 0 deletions Tutorials/11_Yolov8/test_yolov8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import cv2
from mltu.annotations.detections import Detections
from mltu.torch.yolo.detectors.onnx_detector import Detector as OnnxDetector

# https://www.kaggle.com/datasets/andrewmvd/car-plate-detection
images_path = "Datasets/car-plate-detection/images"

input_width, input_height = 416, 416
confidence_threshold = 0.5
iou_threshold = 0.5

detector = OnnxDetector("Models/11_Yolov8/1714135287/model.onnx", input_width, input_height, confidence_threshold, iou_threshold, force_cpu=False)

for image_path in os.listdir(images_path):

frame = cv2.imread(os.path.join(images_path, image_path))

# Perform Yolo object detection
detections: Detections = detector(frame)

# Apply the detections to the frame
frame = detections.applyToFrame(frame)

# Print the FPS
print(detector.fps)

# Display the output image
cv2.imshow("Object Detection", frame)
if cv2.waitKey(0) & 0xFF == ord('q'):
break

cv2.destroyAllWindows()
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from mltu.torch.yolo.loss import v8DetectionLoss
from mltu.torch.yolo.metrics import YoloMetrics
from mltu.torch.yolo.optimizer import build_optimizer, AccumulativeOptimizer
from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard, Model2onnx, WarmupCosineDecay
from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx, WarmupCosineDecay

from ultralytics.nn.tasks import DetectionModel
from ultralytics.engine.model import Model as BaseModel


# https://www.kaggle.com/datasets/andrewmvd/car-plate-detection
annotations_path = "Datasets/car-plate-detection/annotations"

# Create a dataset from the annotations, the dataset is a list of lists where each list contains the [image path, annotation path]
Expand Down Expand Up @@ -72,6 +72,7 @@
# Create a YOLO model
model = DetectionModel('yolov8n.yaml', nc=len(labels))

# Load the weight from base model
try: model.load_state_dict(base_model.model.state_dict(), strict=False)
except: pass

Expand All @@ -95,7 +96,7 @@
v8DetectionLoss(model),
metrics=[YoloMetrics(nc=len(labels))],
log_errors=False,
output_path=f"Models/detector/{int(time.time())}",
output_path=f"Models/11_Yolov8/{int(time.time())}",
clip_grad_norm=10.0,
ema=True,
)
Expand Down
2 changes: 1 addition & 1 deletion mltu/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.2.4"
__version__ = "1.2.5"

from .annotations.images import Image
from .annotations.images import CVImage
Expand Down
2 changes: 2 additions & 0 deletions mltu/augmentors.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,8 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
detections = []
for detection in annotation:
x_min, y_min, x_max, y_max = detection.xyxy_abs
x_max = min(x_max, dx.shape[1] - 1)
y_max = min(y_max, dy.shape[0] - 1)
new_x_min = min(max(0, x_min + dx[y_min, x_min]), image.width - 1)
new_y_min = min(max(0, y_min + dy[y_min, x_min]), image.height - 1)
new_x_max = min(max(0, x_max + dx[y_max, x_max]), image.width - 1)
Expand Down
4 changes: 4 additions & 0 deletions mltu/dataProvider.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def __init__(
else:
self.logger.info("Skipping Dataset validation...")

# Check if dataset has length
if not len(dataset):
raise ValueError("Dataset must be iterable")

if limit:
self.logger.info(f"Limiting dataset to {limit} samples.")
self._dataset = self._dataset[:limit]
Expand Down
11 changes: 6 additions & 5 deletions mltu/torch/yolo/preprocessors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import torch
import typing
import numpy as np

class YoloPreprocessor:
def __init__(self, device, imgsz=640):
def __init__(self, device: torch.device, imgsz: int=640):
self.device = device
self.imgsz = imgsz

def __call__(self, images, annotations):
def __call__(self, images, annotations) -> typing.Tuple[np.ndarray, dict]:
batch = {
"ori_shape": [],
"resized_shape": [],
Expand All @@ -23,8 +24,8 @@ def __call__(self, images, annotations):
batch["bboxes"].append(detection.xywh)
batch["batch_idx"].append(i)

batch["cls"] = torch.tensor(batch["cls"]).to(self.device)
batch["bboxes"] = torch.tensor(batch["bboxes"]).to(self.device)
batch["batch_idx"] = torch.tensor(batch["batch_idx"]).to(self.device)
batch["cls"] = torch.tensor(np.array(batch["cls"])).to(self.device)
batch["bboxes"] = torch.tensor(np.array(batch["bboxes"])).to(self.device)
batch["batch_idx"] = torch.tensor(np.array(batch["batch_idx"])).to(self.device)

return np.array(images), batch
Loading

0 comments on commit f303345

Please sign in to comment.