From 203b1d041b682acbe9da5c95b8020e85cff0888c Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 31 Jul 2023 15:02:01 +0200 Subject: [PATCH] Occlusion mask (#165) * add occlusion mask * trest mask * missing args * use cam_key * update test * add a test * clip values --- pyroengine/engine.py | 12 +++++++++++- pyroengine/vision.py | 23 ++++++++++++++++++++--- tests/test_vision.py | 9 +++++++++ 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/pyroengine/engine.py b/pyroengine/engine.py index 86af6326..32c424c1 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -15,6 +15,7 @@ from pathlib import Path from typing import Any, Dict, Optional, Tuple +import cv2 import numpy as np from PIL import Image from pyroclient import client @@ -150,6 +151,15 @@ def __init__( "ongoing": False, } + self.occlusion_masks = {"-1": None} + if isinstance(cam_creds, dict): + for cam_id in cam_creds: + mask_file = cache_folder + "/occlusion_masks/" + cam_id + ".jpg" + if os.path.isfile(mask_file): + self.occlusion_masks[cam_id] = cv2.imread(mask_file, 0) + else: + self.occlusion_masks[cam_id] = None + # Restore pending alerts cache self._alerts: deque = deque([], cache_size) self._cache = Path(cache_folder) # with Docker, the path has to be a bind volume @@ -279,7 +289,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: if is_day_time(self._cache, frame, self.day_time_strategy): # Inference with ONNX - preds = self.model(frame.convert("RGB")) + preds = self.model(frame.convert("RGB"), self.occlusion_masks[cam_key]) conf = self._update_states(frame_resize, preds, cam_key) # Log analysis result diff --git a/pyroengine/vision.py b/pyroengine/vision.py index 85801ec0..fa133595 100644 --- a/pyroengine/vision.py +++ b/pyroengine/vision.py @@ -39,12 +39,12 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tupl self.ort_session = onnxruntime.InferenceSession(model_path) self.img_size = img_size - def preprocess_image(self, pil_img: Image.Image) -> np.ndarray: + def preprocess_image(self, pil_img: Image.Image, mask: np.array = None) -> np.ndarray: """Preprocess an image for inference Args: pil_img: a valid pillow image - img_size: image size + mask: occlusion mask to drop prediction in an area Returns: the resized and normalized image of shape (1, C, H, W) @@ -57,7 +57,7 @@ def preprocess_image(self, pil_img: Image.Image) -> np.ndarray: return np_img - def __call__(self, pil_img: Image.Image) -> np.ndarray: + def __call__(self, pil_img: Image.Image, occlusion_mask: np.array = None) -> np.ndarray: np_img = self.preprocess_image(pil_img) # ONNX inference @@ -77,4 +77,21 @@ def __call__(self, pil_img: Image.Image) -> np.ndarray: else: y = np.zeros((0, 5)) # normalize output + # Remove prediction in occlusion mask + if occlusion_mask is not None: + hm, wm = occlusion_mask.shape + keep = [] + for p in y.copy(): + p[:4:2] *= wm + p[1:4:2] *= hm + p[:4:2] = np.clip(p[:4:2], 0, wm) + p[:4:2] = np.clip(p[:4:2], 0, hm) + x0, y0, x1, y1 = p.astype("int")[:4] + if np.sum(occlusion_mask[y0:y1, x0:x1]) > 0: + keep.append(True) + else: + keep.append(False) + + y = y[keep] + return y diff --git a/tests/test_vision.py b/tests/test_vision.py index a0733f49..5c895de9 100644 --- a/tests/test_vision.py +++ b/tests/test_vision.py @@ -15,3 +15,12 @@ def test_classifier(mock_wildfire_image): assert out.shape == (1, 5) conf = np.max(out[:, 4]) assert conf >= 0 and conf <= 1 + + # Test mask + mask = np.ones((384, 640)) + out = model(mock_wildfire_image, mask) + assert out.shape == (1, 5) + + mask = np.zeros((384, 640)) + out = model(mock_wildfire_image, mask) + assert out.shape == (0, 5)