# Patch summary (descriptive preamble; everything from the first "diff --git" header onward is unchanged).
#
# pyproject.toml
#   * Adds "tqdm>=4.62.0" to the `dependencies` list.
#
# pyroengine/utils.py
#   * Imports `tqdm` and adds "DownloadProgressBar" to `__all__`.
#   * letterbox(): default `color` changes from (114, 114, 114) to (0, 0, 0), in both the
#     signature and the docstring.
#   * Adds `class DownloadProgressBar(tqdm)` with `update_to(b=1, bsize=1, tsize=None)`:
#     when `tsize` is not None it is stored as `self.total`; the bar is then advanced by
#     `b * bsize - self.n`, i.e. moved to the absolute position `b * bsize`.
#
# pyroengine/vision.py
#   * Imports `DownloadProgressBar` from `.utils`.
#   * MODEL_URL moves from the pyro-vision GitHub release asset (yolov8s_v001.onnx) to
#     https://huggingface.co/pyronear/yolov8s/resolve/main/yolov8s.onnx.
#   * Classifier.__init__: default `img_size` changes from (384, 640) to (1024, 1024).
#   * The model download now runs inside a `DownloadProgressBar` context and passes
#     `t.update_to` as the `reporthook` of `urlretrieve`; "Model downloaded!" is printed afterwards.
#
# tests/test_vision.py
#   * Expected preprocessed shape becomes (1, 3, 1024, 1024).
#   * The all-ones mask becomes (1024, 640) and the all-zeros mask becomes (1024, 1024).
#
# NOTE(review): the two test masks now have different shapes, (1024, 640) vs (1024, 1024);
#   confirm that the asymmetry is intended.
# NOTE(review): the unchanged letterbox docstring context says "auto ... Defaults to True"
#   while the signature shows `auto: bool = False`; the docstring looks stale, but fixing it
#   is outside the lines this patch touches.
# NOTE(review): in this copy the original line breaks and indentation are collapsed onto
#   three physical lines; they must be restored before the patch can be applied.
diff --git a/pyproject.toml b/pyproject.toml index 9510d216..53f69c24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "pyroclient @ git+https://github.com/pyronear/pyro-api.git@main#egg=pkg&subdirectory=client", "requests>=2.20.0,<3.0.0", "opencv-python==4.5.5.64", + "tqdm>=4.62.0", ] [project.optional-dependencies] diff --git a/pyroengine/utils.py b/pyroengine/utils.py index 17dd0575..7c37026f 100644 --- a/pyroengine/utils.py +++ b/pyroengine/utils.py @@ -6,8 +6,9 @@ import cv2 # type: ignore[import-untyped] import numpy as np +from tqdm import tqdm -__all__ = ["letterbox", "nms", "xywh2xyxy"] +__all__ = ["letterbox", "nms", "xywh2xyxy", "DownloadProgressBar"] def xywh2xyxy(x: np.ndarray): @@ -20,14 +21,14 @@ def xywh2xyxy(x: np.ndarray): def letterbox( - im: np.ndarray, new_shape: tuple = (640, 640), color: tuple = (114, 114, 114), auto: bool = False, stride: int = 32 + im: np.ndarray, new_shape: tuple = (640, 640), color: tuple = (0, 0, 0), auto: bool = False, stride: int = 32 ): """Letterbox image transform for yolo models Args: im (np.ndarray): Input image new_shape (tuple, optional): Image size. Defaults to (640, 640). color (tuple, optional): Pixel fill value for the area outside the transformed image. - Defaults to (114, 114, 114). + Defaults to (0, 0, 0). auto (bool, optional): auto padding. Defaults to True. stride (int, optional): padding stride. Defaults to 32. 
Returns: @@ -109,3 +110,10 @@ def nms(boxes: np.ndarray, overlapThresh: int = 0): indices = indices[indices != i] return boxes[indices] + + +class DownloadProgressBar(tqdm): + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) diff --git a/pyroengine/vision.py b/pyroengine/vision.py index a9ea22dd..7d7365e9 100644 --- a/pyroengine/vision.py +++ b/pyroengine/vision.py @@ -11,11 +11,11 @@ import onnxruntime from PIL import Image -from .utils import letterbox, nms, xywh2xyxy +from .utils import DownloadProgressBar, letterbox, nms, xywh2xyxy __all__ = ["Classifier"] -MODEL_URL = "https://github.com/pyronear/pyro-vision/releases/download/v0.2.0/yolov8s_v001.onnx" +MODEL_URL = "https://huggingface.co/pyronear/yolov8s/resolve/main/yolov8s.onnx" class Classifier: @@ -29,14 +29,16 @@ class Classifier: model_path: model path """ - def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tuple = (384, 640)) -> None: + def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tuple = (1024, 1024)) -> None: if model_path is None: model_path = "data/model.onnx" if not os.path.isfile(model_path): os.makedirs(os.path.split(model_path)[0], exist_ok=True) print(f"Downloading model from {MODEL_URL} ...") - urlretrieve(MODEL_URL, model_path) + with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=model_path) as t: + urlretrieve(MODEL_URL, model_path, reporthook=t.update_to) + print("Model downloaded!") self.ort_session = onnxruntime.InferenceSession(model_path) self.img_size = img_size diff --git a/tests/test_vision.py b/tests/test_vision.py index dada6dbe..a974e5c1 100644 --- a/tests/test_vision.py +++ b/tests/test_vision.py @@ -9,7 +9,7 @@ def test_classifier(mock_wildfire_image): # Check preprocessing out, pad = model.preprocess_image(mock_wildfire_image) assert isinstance(out, np.ndarray) and out.dtype == np.float32 - assert out.shape == (1, 3, 384, 640) + 
assert out.shape == (1, 3, 1024, 1024) assert isinstance(pad, tuple) # Check inference out = model(mock_wildfire_image) @@ -18,10 +18,10 @@ def test_classifier(mock_wildfire_image): assert conf >= 0 and conf <= 1 # Test mask - mask = np.ones((384, 640)) + mask = np.ones((1024, 640)) out = model(mock_wildfire_image, mask) assert out.shape == (1, 5) - mask = np.zeros((384, 640)) + mask = np.zeros((1024, 1024)) out = model(mock_wildfire_image, mask) assert out.shape == (0, 5)