Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Download model only when needed #200

Merged
merged 10 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ dependencies = [
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@main#egg=pkg&subdirectory=client",
"requests>=2.20.0,<3.0.0",
"opencv-python==4.5.5.64",
"tqdm>=4.62.0",
"tqdm>=4.62.0",
"huggingface_hub==0.23.1",
]

[project.optional-dependencies]
Expand Down
69 changes: 62 additions & 7 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
# Copyright (C) 2022-2024, Pyronear.
# Copyright (C) 2023-2024, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import json
import os
from typing import Optional, Tuple
from urllib.request import urlretrieve

import cv2 # type: ignore[import-untyped]
import numpy as np
import onnxruntime
from huggingface_hub import HfApi # type: ignore[import-untyped]
from PIL import Image

from .utils import DownloadProgressBar, nms, xywh2xyxy

__all__ = ["Classifier"]

# Direct URL the ONNX weights are fetched from (passed to urlretrieve in download_model)
MODEL_URL = "https://huggingface.co/pyronear/yolov8s/resolve/main/model.onnx"
# Hugging Face repository whose file metadata is queried for the reference SHA256
MODEL_ID = "pyronear/yolov8s"
# File name looked up among the repository files when extracting the SHA256
MODEL_NAME = "model.onnx"
# JSON file storing the SHA256 recorded when the model was last downloaded
METADATA_PATH = "data/model_metadata.json"


# Utility function to save metadata
def save_metadata(metadata_path: str, metadata: dict) -> None:
    """Serialize ``metadata`` as JSON into ``metadata_path``.

    Args:
        metadata_path: destination file; missing parent folders are created
        metadata: JSON-serializable content to store
    """
    # The metadata file may live in a folder that has not been created yet
    # (only the model's own folder is guaranteed to exist after a download).
    parent_dir = os.path.dirname(metadata_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f)


class Classifier:
Expand All @@ -34,16 +45,60 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", base_img_size:
if model_path is None:
model_path = "data/model.onnx"

if not os.path.isfile(model_path):
os.makedirs(os.path.split(model_path)[0], exist_ok=True)
print(f"Downloading model from {MODEL_URL} ...")
with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=model_path) as t:
urlretrieve(MODEL_URL, model_path, reporthook=t.update_to)
print("Model downloaded!")
# Get the expected SHA256 from Hugging Face
api = HfApi()
model_info = api.model_info(MODEL_ID, files_metadata=True)
expected_sha256 = self.get_sha(model_info.siblings)

if not expected_sha256:
raise ValueError("SHA256 hash for the model file not found in the Hugging Face model metadata.")

# Check if the model file exists
if os.path.isfile(model_path):
# Load existing metadata
metadata = self.load_metadata(METADATA_PATH)
if metadata and metadata.get("sha256") == expected_sha256:
print("Model already exists and the SHA256 hash matches. No download needed.")
else:
print("Model exists but the SHA256 hash does not match or the file doesn't exist.")
os.remove(model_path)
self.download_model(model_path, expected_sha256)
else:
self.download_model(model_path, expected_sha256)

self.ort_session = onnxruntime.InferenceSession(model_path)
self.base_img_size = base_img_size

def get_sha(self, siblings) -> Optional[str]:
    """Extract the SHA256 of the model file from the repository files metadata.

    Args:
        siblings: iterable of file entries exposing ``rfilename`` and ``lfs``

    Returns:
        the ``sha256`` of the entry named after ``MODEL_NAME``, or ``None`` when no
        such entry exists or when it carries no LFS information
    """
    for file in siblings:
        if file.rfilename != os.path.basename(MODEL_NAME):
            continue
        # An entry without LFS information has no hash to offer
        lfs_info = getattr(file, "lfs", None)
        return None if lfs_info is None else lfs_info.sha256
    # No matching file: report "not found" instead of failing on an unbound local
    return None

def download_model(self, model_path, expected_sha256):
    """Download the model to ``model_path`` and record its expected SHA256.

    Args:
        model_path: destination path of the model file
        expected_sha256: hash written to the metadata file after the download
    """
    # A bare file name has no folder component, and os.makedirs("") raises,
    # so only create the folder when there is one.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    # Download the model
    print(f"Downloading model from {MODEL_URL} ...")
    with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=model_path) as t:
        urlretrieve(MODEL_URL, model_path, reporthook=t.update_to)
    print("Model downloaded!")

    # Save the metadata
    # NOTE: the stored hash is the expected one; the downloaded file is not hashed here.
    metadata = {"sha256": expected_sha256}
    save_metadata(METADATA_PATH, metadata)
    print("Metadata saved!")

# Utility function to load metadata
def load_metadata(self, metadata_path) -> Optional[dict]:
    """Read the JSON metadata stored at ``metadata_path``.

    Args:
        metadata_path: path of the metadata file

    Returns:
        the decoded JSON content, or ``None`` when the file does not exist or
        cannot be decoded
    """
    if not os.path.exists(metadata_path):
        return None

    try:
        with open(metadata_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        # A truncated or corrupted file is treated like a missing one, so the
        # caller falls back to downloading the model and rewriting the metadata.
        return None

def preprocess_image(self, pil_img: Image.Image, new_img_size: list) -> Tuple[np.ndarray, Tuple[int, int]]:
"""Preprocess an image for inference

Expand Down
119 changes: 99 additions & 20 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,105 @@
import os
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from pyroengine.vision import Classifier

# Metadata file removed by the tests during cleanup
METADATA_PATH = "data/model_metadata.json"
# Model path that the custom os.path.isfile replacements compare against
model_path = "data/model.onnx"
# Hash used as the stored / expected SHA256 in the tests below
sha = "12b9b5728dfa2e60502dcde2914bfdc4e9378caa57611c567a44cdd6228838c2"


def custom_isfile_false(path):
    """Replacement for ``os.path.isfile`` reporting the model file as absent.

    Every path other than ``model_path`` is reported as an existing file.
    """
    return path != model_path


def custom_isfile_true(path):
    """Replacement for ``os.path.isfile`` reporting every path as an existing file.

    The model path needs no special case here: it is reported as present
    exactly like any other path.
    """
    return True


# Test the case where the model file does not exist yet and has to be downloaded
def test_classifier(mock_wildfire_image):
# Builds a Classifier, then checks preprocessing output, inference output shape,
# confidence range and the effect of an all-ones / all-zeros mask.
# NOTE(review): this span comes from a diff view whose +/- markers and indentation
# were stripped. The statements up to the first `assert out.shape == (0, 5)` look
# like the pre-change body and the `with patch(...)` section like the post-change
# body - confirm against the repository before editing.
# Instantiate the ONNX model
model = Classifier()
# Check preprocessing
out = model.preprocess_image(mock_wildfire_image, (1024, 576))
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 576, 1024)
# Check inference
out = model(mock_wildfire_image)
assert out.shape == (1, 5)
conf = np.max(out[:, 4])
assert conf >= 0 and conf <= 1

# Test mask
mask = np.ones((1024, 576))
out = model(mock_wildfire_image, mask)
assert out.shape == (1, 5)

mask = np.zeros((1024, 1024))
out = model(mock_wildfire_image, mask)
assert out.shape == (0, 5)
print("test_classifier")
# os.path.isfile reports the model file as absent, so the constructor takes the download path
with patch("os.path.isfile", side_effect=custom_isfile_false):
# Instantiate the ONNX model
model = Classifier()
# Check preprocessing
out = model.preprocess_image(mock_wildfire_image, (1024, 576))
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 576, 1024)
# Check inference
out = model(mock_wildfire_image)
assert out.shape == (1, 5)
conf = np.max(out[:, 4])
assert conf >= 0 and conf <= 1

# Test mask
mask = np.ones((1024, 576))
out = model(mock_wildfire_image, mask)
assert out.shape == (1, 5)

mask = np.zeros((1024, 1024))
out = model(mock_wildfire_image, mask)
assert out.shape == (0, 5)
# Cleanup: remove the downloaded model and its metadata file
os.remove(model_path)
os.remove(METADATA_PATH)


# Test that the model is not downloaded again when the stored SHA256 matches
def test_no_download():
    """Check that a present model whose stored hash matches is not downloaded again."""
    print("test_no_download")
    data = {"sha256": sha}
    # The expected hash is pinned (as test_sha_inequality does) so the outcome does
    # not depend on the hash currently published for the model on the Hub.
    with patch("os.path.isfile", side_effect=custom_isfile_true), \
            patch("pyroengine.vision.Classifier.load_metadata", return_value=data), \
            patch("pyroengine.vision.Classifier.get_sha", return_value=sha), \
            patch("onnxruntime.InferenceSession", return_value=None):
        Classifier()
    # os.path.isfile is the real function again here: no model file must have been written
    assert os.path.isfile(model_path) is False


# Test that the model is downloaded again when the stored SHA256 differs from the expected one
@patch("pyroengine.vision.urlretrieve")
@patch("pyroengine.vision.DownloadProgressBar")
def test_sha_inequality(mock_download_progress, mock_urlretrieve):
    """Check that a stored hash differing from the expected one triggers a new download."""
    print("test_sha_inequality")
    stale_metadata = {"sha256": "falsesha"}

    def fake_urlretrieve(url, filename, reporthook=None):
        # Stand-in for the real download: drop a small file at the target path
        with open(filename, "w") as f:
            f.write("fake model content")

    mock_urlretrieve.side_effect = fake_urlretrieve
    # The progress bar is used as a context manager, so mock what __enter__ yields
    mock_download_progress.return_value.__enter__.return_value = MagicMock()

    with patch("os.path.isfile", side_effect=custom_isfile_true), \
            patch("pyroengine.vision.Classifier.load_metadata", return_value=stale_metadata), \
            patch("pyroengine.vision.Classifier.get_sha", return_value=sha), \
            patch("onnxruntime.InferenceSession", return_value=None), \
            patch("os.remove", return_value=True):
        model = Classifier()

    # All patches are lifted here: the fake download must have produced a real file
    assert os.path.isfile(model_path) is True
    assert model.load_metadata("non_existent_metadata.json") is None
    os.remove(model_path)
    os.remove(METADATA_PATH)


# Test for raising ValueError if expected_sha256 is not found
def test_raise_value_error_if_no_sha256():
    """Check that an empty hash returned by ``get_sha`` makes the constructor raise ValueError."""
    print("test_raise_value_error_if_no_sha256")
    expected_message = "SHA256 hash for the model file not found in the Hugging Face model metadata."
    with patch("pyroengine.vision.Classifier.get_sha", return_value=""), \
            pytest.raises(ValueError, match=expected_message):
        Classifier(model_path="non_existent_model.onnx")
Loading