Skip to content

Commit

Permalink
New model (#195)
Browse files Browse the repository at this point in the history
* update default color

* update model

* add downloading bar

* fix test

* add dep

* style
  • Loading branch information
MateoLostanlen authored May 21, 2024
1 parent e01b824 commit 59d2029
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 10 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@main#egg=pkg&subdirectory=client",
"requests>=2.20.0,<3.0.0",
"opencv-python==4.5.5.64",
"tqdm>=4.62.0",
]

[project.optional-dependencies]
Expand Down
14 changes: 11 additions & 3 deletions pyroengine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import cv2 # type: ignore[import-untyped]
import numpy as np
from tqdm import tqdm

__all__ = ["letterbox", "nms", "xywh2xyxy"]
__all__ = ["letterbox", "nms", "xywh2xyxy", "DownloadProgressBar"]


def xywh2xyxy(x: np.ndarray):
Expand All @@ -20,14 +21,14 @@ def xywh2xyxy(x: np.ndarray):


def letterbox(
im: np.ndarray, new_shape: tuple = (640, 640), color: tuple = (114, 114, 114), auto: bool = False, stride: int = 32
im: np.ndarray, new_shape: tuple = (640, 640), color: tuple = (0, 0, 0), auto: bool = False, stride: int = 32
):
"""Letterbox image transform for yolo models
Args:
im (np.ndarray): Input image
new_shape (tuple, optional): Image size. Defaults to (640, 640).
color (tuple, optional): Pixel fill value for the area outside the transformed image.
Defaults to (114, 114, 114).
Defaults to (0, 0, 0).
auto (bool, optional): auto padding. Defaults to True.
stride (int, optional): padding stride. Defaults to 32.
Returns:
Expand Down Expand Up @@ -109,3 +110,10 @@ def nms(boxes: np.ndarray, overlapThresh: int = 0):
indices = indices[indices != i]

return boxes[indices]


class DownloadProgressBar(tqdm):
def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n)
10 changes: 6 additions & 4 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import onnxruntime
from PIL import Image

from .utils import letterbox, nms, xywh2xyxy
from .utils import DownloadProgressBar, letterbox, nms, xywh2xyxy

__all__ = ["Classifier"]

MODEL_URL = "https://github.com/pyronear/pyro-vision/releases/download/v0.2.0/yolov8s_v001.onnx"
MODEL_URL = "https://huggingface.co/pyronear/yolov8s/resolve/main/yolov8s.onnx"


class Classifier:
Expand All @@ -29,14 +29,16 @@ class Classifier:
model_path: model path
"""

def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tuple = (384, 640)) -> None:
def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tuple = (1024, 1024)) -> None:
if model_path is None:
model_path = "data/model.onnx"

if not os.path.isfile(model_path):
os.makedirs(os.path.split(model_path)[0], exist_ok=True)
print(f"Downloading model from {MODEL_URL} ...")
urlretrieve(MODEL_URL, model_path)
with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=model_path) as t:
urlretrieve(MODEL_URL, model_path, reporthook=t.update_to)
print("Model downloaded!")

self.ort_session = onnxruntime.InferenceSession(model_path)
self.img_size = img_size
Expand Down
6 changes: 3 additions & 3 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_classifier(mock_wildfire_image):
# Check preprocessing
out, pad = model.preprocess_image(mock_wildfire_image)
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 384, 640)
assert out.shape == (1, 3, 1024, 1024)
assert isinstance(pad, tuple)
# Check inference
out = model(mock_wildfire_image)
Expand All @@ -18,10 +18,10 @@ def test_classifier(mock_wildfire_image):
assert conf >= 0 and conf <= 1

# Test mask
mask = np.ones((384, 640))
mask = np.ones((1024, 640))
out = model(mock_wildfire_image, mask)
assert out.shape == (1, 5)

mask = np.zeros((384, 640))
mask = np.zeros((1024, 1024))
out = model(mock_wildfire_image, mask)
assert out.shape == (0, 5)

0 comments on commit 59d2029

Please sign in to comment.