Skip to content

Commit

Permalink
Occlusion mask (#165)
Browse files Browse the repository at this point in the history
* add occlusion mask

* trest mask

* missing args

* use cam_key

* update test

* add a test

* clip values
  • Loading branch information
MateoLostanlen authored Jul 31, 2023
1 parent 6501eec commit 203b1d0
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 4 deletions.
12 changes: 11 additions & 1 deletion pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import cv2
import numpy as np
from PIL import Image
from pyroclient import client
Expand Down Expand Up @@ -150,6 +151,15 @@ def __init__(
"ongoing": False,
}

self.occlusion_masks = {"-1": None}
if isinstance(cam_creds, dict):
for cam_id in cam_creds:
mask_file = cache_folder + "/occlusion_masks/" + cam_id + ".jpg"
if os.path.isfile(mask_file):
self.occlusion_masks[cam_id] = cv2.imread(mask_file, 0)
else:
self.occlusion_masks[cam_id] = None

# Restore pending alerts cache
self._alerts: deque = deque([], cache_size)
self._cache = Path(cache_folder) # with Docker, the path has to be a bind volume
Expand Down Expand Up @@ -279,7 +289,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:

if is_day_time(self._cache, frame, self.day_time_strategy):
# Inference with ONNX
preds = self.model(frame.convert("RGB"))
preds = self.model(frame.convert("RGB"), self.occlusion_masks[cam_key])
conf = self._update_states(frame_resize, preds, cam_key)

# Log analysis result
Expand Down
23 changes: 20 additions & 3 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tupl
self.ort_session = onnxruntime.InferenceSession(model_path)
self.img_size = img_size

def preprocess_image(self, pil_img: Image.Image) -> np.ndarray:
def preprocess_image(self, pil_img: Image.Image, mask: np.array = None) -> np.ndarray:
"""Preprocess an image for inference
Args:
pil_img: a valid pillow image
img_size: image size
mask: occlusion mask to drop prediction in an area
Returns:
the resized and normalized image of shape (1, C, H, W)
Expand All @@ -57,7 +57,7 @@ def preprocess_image(self, pil_img: Image.Image) -> np.ndarray:

return np_img

def __call__(self, pil_img: Image.Image) -> np.ndarray:
def __call__(self, pil_img: Image.Image, occlusion_mask: np.array = None) -> np.ndarray:
np_img = self.preprocess_image(pil_img)

# ONNX inference
Expand All @@ -77,4 +77,21 @@ def __call__(self, pil_img: Image.Image) -> np.ndarray:
else:
y = np.zeros((0, 5)) # normalize output

# Remove prediction in occlusion mask
if occlusion_mask is not None:
hm, wm = occlusion_mask.shape
keep = []
for p in y.copy():
p[:4:2] *= wm
p[1:4:2] *= hm
p[:4:2] = np.clip(p[:4:2], 0, wm)
p[:4:2] = np.clip(p[:4:2], 0, hm)
x0, y0, x1, y1 = p.astype("int")[:4]
if np.sum(occlusion_mask[y0:y1, x0:x1]) > 0:
keep.append(True)
else:
keep.append(False)

y = y[keep]

return y
9 changes: 9 additions & 0 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ def test_classifier(mock_wildfire_image):
assert out.shape == (1, 5)
conf = np.max(out[:, 4])
assert conf >= 0 and conf <= 1

# Test mask
mask = np.ones((384, 640))
out = model(mock_wildfire_image, mask)
assert out.shape == (1, 5)

mask = np.zeros((384, 640))
out = model(mock_wildfire_image, mask)
assert out.shape == (0, 5)

0 comments on commit 203b1d0

Please sign in to comment.