Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MateoLostanlen committed Jul 19, 2023
1 parent faf9c81 commit 83c5285
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
except ConnectionError:
stream.seek(0) # "Rewind" the stream to the beginning so we can read its content

return conf
return float(conf)

def _upload_frame(self, cam_id: str, media_data: bytes) -> Response:
"""Save frame"""
Expand Down
6 changes: 4 additions & 2 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from pathlib import Path

import numpy as np
from dotenv import load_dotenv

from pyroengine.engine import Engine
Expand Down Expand Up @@ -37,16 +38,17 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):
engine._dump_cache()

# Cache dump loading
engine = Engine(cache_folder=folder + "model.onnx")
engine = Engine(cache_folder=folder)
assert len(engine._alerts) == 1
engine.clear_cache()

# inference
engine = Engine(alert_relaxation=3, cache_folder=folder + "model.onnx")
engine = Engine(alert_relaxation=3, cache_folder=folder)
out = engine.predict(mock_forest_image)
assert isinstance(out, float) and 0 <= out <= 1
assert engine._states["-1"]["consec"] == 0
out = engine.predict(mock_wildfire_image)

assert isinstance(out, float) and 0 <= out <= 1
assert engine._states["-1"]["consec"] == 1
# Alert relaxation
Expand Down
4 changes: 3 additions & 1 deletion tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ def test_classifier(mock_wildfire_image):
assert out.shape == (1, 3, 384, 640)
# Check inference
out = model(mock_wildfire_image)
assert out >= 0 and out <= 1
assert out.shape == (1, 5)
conf = np.max(out[:, 4])
assert conf >= 0 and conf <= 1

0 comments on commit 83c5285

Please sign in to comment.