Skip to content

Commit

Permalink
save all localy (#214)
Browse files Browse the repository at this point in the history
* save all localy

* fix style

* add test
  • Loading branch information
MateoLostanlen authored Jun 26, 2024
1 parent 12a6a20 commit aaf00e6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
12 changes: 10 additions & 2 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class Engine:
frame_saving_period: Send one frame over N to the api for our dataset
cache_size: maximum number of alerts to save in cache
day_time_strategy: strategy to define if it's daytime
save_captured_frames: save all captured frames for debugging
kwargs: keyword args of Classifier
Examples:
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
backup_size: int = 30,
jpeg_quality: int = 80,
day_time_strategy: Optional[str] = None,
save_captured_frames: Optional[bool] = False,
**kwargs: Any,
) -> None:
"""Init engine"""
Expand All @@ -102,6 +104,7 @@ def __init__(
self.jpeg_quality = jpeg_quality
self.cache_backup_period = cache_backup_period
self.day_time_strategy = day_time_strategy
self.save_captured_frames = save_captured_frames

# Local backup
self._backup_size = backup_size
Expand Down Expand Up @@ -282,6 +285,9 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
self._dump_cache()
self.last_cache_dump = ts

if self.save_captured_frames:
self._local_backup(frame_resize, cam_id, is_alert=False)

return float(conf)

def _stage_alert(self, frame: Image.Image, cam_id: str, ts: int, localization: list) -> None:
Expand Down Expand Up @@ -341,14 +347,16 @@ def _process_alerts(self) -> None:
logging.warning(e)
break

def _local_backup(self, img: Image.Image, cam_id: str) -> None:
def _local_backup(self, img: Image.Image, cam_id: Optional[str], is_alert: bool = True) -> None:
"""Save image on device
Args:
img (Image.Image): Image to save
cam_id (str): camera id (ip address)
is_alert (bool): is the frame an alert ?
"""
backup_cache = self._cache.joinpath("backup/alerts/")
folder = "alerts" if is_alert else "save"
backup_cache = self._cache.joinpath(f"backup/{folder}/")
self._clean_local_backup(backup_cache) # Dump old cache
backup_cache = backup_cache.joinpath(f"{time.strftime('%Y%m%d')}/{cam_id}")
backup_cache.mkdir(parents=True, exist_ok=True)
Expand Down
2 changes: 2 additions & 0 deletions src/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def main(args):
cache_size=args.cache_size,
jpeg_quality=args.jpeg_quality,
day_time_strategy=args.day_time_strategy,
save_captured_frames=args.save_captured_frames,
)

sys_controller = SystemController(
Expand Down Expand Up @@ -117,6 +118,7 @@ def main(args):
parser.add_argument("--protocol", type=str, default="https", help="Camera protocol")
# Backup
parser.add_argument("--backup-size", type=int, default=10000, help="Local backup can't be bigger than 10Go")
parser.add_argument("--save_captured_frames", type=bool, default=False, help="Save all captured frames locally")

# Time config
parser.add_argument("--period", type=int, default=30, help="Number of seconds between each camera stream analysis")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):
engine.clear_cache()

# inference
engine = Engine(nb_consecutive_frames=4, cache_folder=folder)
engine = Engine(nb_consecutive_frames=4, cache_folder=folder, save_captured_frames=True)
out = engine.predict(mock_forest_image)
assert isinstance(out, float) and 0 <= out <= 1
assert len(engine._states["-1"]["last_predictions"]) == 1
Expand Down

0 comments on commit aaf00e6

Please sign in to comment.