Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Occlusion mask #165

Merged
merged 8 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import cv2
import numpy as np
from PIL import Image
from pyroclient import client
Expand Down Expand Up @@ -150,6 +151,15 @@ def __init__(
"ongoing": False,
}

self.occlusion_masks = {"-1": None}
if isinstance(cam_creds, dict):
for cam_id in cam_creds:
mask_file = cache_folder + "/occlusion_masks/" + cam_id + ".jpg"
if os.path.isfile(mask_file):
self.occlusion_masks[cam_id] = cv2.imread(mask_file, 0)
else:
self.occlusion_masks[cam_id] = None

# Restore pending alerts cache
self._alerts: deque = deque([], cache_size)
self._cache = Path(cache_folder) # with Docker, the path has to be a bind volume
Expand Down Expand Up @@ -279,7 +289,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:

if is_day_time(self._cache, frame, self.day_time_strategy):
# Inference with ONNX
preds = self.model(frame.convert("RGB"))
preds = self.model(frame.convert("RGB"), self.occlusion_masks[cam_key])
conf = self._update_states(frame_resize, preds, cam_key)

# Log analysis result
Expand Down
23 changes: 20 additions & 3 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tupl
self.ort_session = onnxruntime.InferenceSession(model_path)
self.img_size = img_size

def preprocess_image(self, pil_img: Image.Image) -> np.ndarray:
def preprocess_image(self, pil_img: Image.Image, mask: np.array = None) -> np.ndarray:
"""Preprocess an image for inference

Args:
pil_img: a valid pillow image
img_size: image size
mask: occlusion mask to drop prediction in an area

Returns:
the resized and normalized image of shape (1, C, H, W)
Expand All @@ -57,7 +57,7 @@ def preprocess_image(self, pil_img: Image.Image) -> np.ndarray:

return np_img

def __call__(self, pil_img: Image.Image) -> np.ndarray:
def __call__(self, pil_img: Image.Image, occlusion_mask: np.array = None) -> np.ndarray:
np_img = self.preprocess_image(pil_img)

# ONNX inference
Expand All @@ -77,4 +77,21 @@ def __call__(self, pil_img: Image.Image) -> np.ndarray:
else:
y = np.zeros((0, 5)) # normalize output

# Remove prediction in occlusion mask
if occlusion_mask is not None:
hm, wm = occlusion_mask.shape
keep = []
for p in y.copy():
p[:4:2] *= wm
p[1:4:2] *= hm
p[:4:2] = np.clip(p[:4:2], 0, wm)
p[:4:2] = np.clip(p[:4:2], 0, hm)
x0, y0, x1, y1 = p.astype("int")[:4]
if np.sum(occlusion_mask[y0:y1, x0:x1]) > 0:
keep.append(True)
else:
keep.append(False)

y = y[keep]

return y
9 changes: 9 additions & 0 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ def test_classifier(mock_wildfire_image):
assert out.shape == (1, 5)
conf = np.max(out[:, 4])
assert conf >= 0 and conf <= 1

# Test mask
mask = np.ones((384, 640))
out = model(mock_wildfire_image, mask)
assert out.shape == (1, 5)

mask = np.zeros((384, 640))
out = model(mock_wildfire_image, mask)
assert out.shape == (0, 5)
Loading