diff --git a/pyroengine/engine.py b/pyroengine/engine.py index 7b92359..372e304 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -48,6 +48,7 @@ class Engine: frame_saving_period: Send one frame over N to the api for our dataset cache_size: maximum number of alerts to save in cache day_time_strategy: strategy to define if it's daytime + save_captured_frames: save all captured frames for debugging kwargs: keyword args of Classifier Examples: @@ -76,6 +77,7 @@ def __init__( backup_size: int = 30, jpeg_quality: int = 80, day_time_strategy: Optional[str] = None, + save_captured_frames: Optional[bool] = False, **kwargs: Any, ) -> None: """Init engine""" @@ -102,6 +104,7 @@ def __init__( self.jpeg_quality = jpeg_quality self.cache_backup_period = cache_backup_period self.day_time_strategy = day_time_strategy + self.save_captured_frames = save_captured_frames # Local backup self._backup_size = backup_size @@ -282,6 +285,9 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: self._dump_cache() self.last_cache_dump = ts + if self.save_captured_frames: + self._local_backup(frame_resize, cam_id, is_alert=False) + return float(conf) def _stage_alert(self, frame: Image.Image, cam_id: str, ts: int, localization: list) -> None: @@ -341,14 +347,16 @@ def _process_alerts(self) -> None: logging.warning(e) break - def _local_backup(self, img: Image.Image, cam_id: str) -> None: + def _local_backup(self, img: Image.Image, cam_id: Optional[str], is_alert: bool = True) -> None: """Save image on device Args: img (Image.Image): Image to save cam_id (str): camera id (ip address) + is_alert (bool): is the frame an alert ? """ - backup_cache = self._cache.joinpath("backup/alerts/") + folder = "alerts" if is_alert else "save" + backup_cache = self._cache.joinpath(f"backup/{folder}/") self._clean_local_backup(backup_cache) # Dump old cache backup_cache = backup_cache.joinpath(f"{time.strftime('%Y%m%d')}/{cam_id}") backup_cache.mkdir(parents=True, exist_ok=True) diff --git a/src/run.py b/src/run.py index 5b6a33a..64a1f12 100644 --- a/src/run.py +++ b/src/run.py @@ -72,6 +72,7 @@ def main(args): cache_size=args.cache_size, jpeg_quality=args.jpeg_quality, day_time_strategy=args.day_time_strategy, + save_captured_frames=args.save_captured_frames, ) sys_controller = SystemController( @@ -117,6 +118,7 @@ def main(args): parser.add_argument("--protocol", type=str, default="https", help="Camera protocol") # Backup parser.add_argument("--backup-size", type=int, default=10000, help="Local backup can't be bigger than 10Go") + parser.add_argument("--save_captured_frames", type=bool, default=False, help="Save all captured frames locally") # Time config parser.add_argument("--period", type=int, default=30, help="Number of seconds between each camera stream analysis") diff --git a/tests/test_engine.py b/tests/test_engine.py index 3c44b60..1f2a48e 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -44,7 +44,7 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): engine.clear_cache() # inference - engine = Engine(nb_consecutive_frames=4, cache_folder=folder) + engine = Engine(nb_consecutive_frames=4, cache_folder=folder, save_captured_frames=True) out = engine.predict(mock_forest_image) assert isinstance(out, float) and 0 <= out <= 1 assert len(engine._states["-1"]["last_predictions"]) == 1