-
Notifications
You must be signed in to change notification settings - Fork 0
/
PicoDet.py
181 lines (167 loc) · 8.97 KB
/
PicoDet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import cv2
import numpy as np
import argparse
import onnxruntime as ort
import math
import time
from functools import wraps
# Decorator that records how long a function call takes.
def timeit(func):
    """Wrap *func* so each call prints its elapsed wall-clock time.

    The wrapper prints ``used: <seconds>`` after the wrapped call returns and
    then hands back the wrapped function's return value unchanged.  If the
    wrapped function raises, the exception propagates and nothing is printed.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # time.clock() was removed in Python 3.8; perf_counter() is the
        # documented replacement for measuring short durations.
        start = time.perf_counter()
        ret = func(*args, **kwargs)
        end = time.perf_counter()
        print('used:', end - start)
        return ret
    return wrapper
class PicoDet():
    """PicoDet object detector backed by an ONNX Runtime inference session.

    The class loads the model, pre-computes one grid of anchor centres per
    output level, and offers :meth:`detect`, which runs the full pipeline
    (resize -> normalise -> inference -> box decoding -> NMS) and draws the
    detections on the input image.
    """

    def __init__(self, model_pb_path, label_path, prob_threshold=0.4, iou_threshold=0.3):
        """Load the model and the class names and build the anchor grids.

        Args:
            model_pb_path: Path of the ONNX model handed to onnxruntime.
            label_path: Text file with one class name per line.
            prob_threshold: Score threshold passed to ``cv2.dnn.NMSBoxes``.
            iou_threshold: NMS overlap threshold passed to ``cv2.dnn.NMSBoxes``.
        """
        # Use a context manager so the label file handle is always closed.
        with open(label_path, 'r') as label_file:
            self.classes = [line.strip() for line in label_file]
        self.num_classes = len(self.classes)
        self.prob_threshold = prob_threshold
        self.iou_threshold = iou_threshold
        # Per-channel mean/std on the 0-255 scale, broadcast over an HxWx3 image.
        self.mean = np.array([103.53, 116.28, 123.675], dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([57.375, 57.12, 58.395], dtype=np.float32).reshape(1, 1, 3)
        so = ort.SessionOptions()
        so.log_severity_level = 3
        self.net = ort.InferenceSession(model_pb_path, so)
        # Axes 2 and 3 of the first input are used as (height, width).
        model_input = self.net.get_inputs()[0]
        self.input_shape = (model_input.shape[2], model_input.shape[3])
        # The first half of the outputs is treated as class scores and the
        # second half as box regressions (see post_process).
        model_outputs = self.net.get_outputs()
        self.num_outs = len(model_outputs) // 2
        # Each box side is predicted as a distribution over reg_max + 1 bins.
        self.reg_max = model_outputs[self.num_outs].shape[-1] // 4 - 1
        self.project = np.arange(self.reg_max + 1)
        # One stride per output level: 8, 16, 32, ...
        self.strides = [int(8 * (2**i)) for i in range(self.num_outs)]
        self.mlvl_anchors = [
            self._make_grid(
                (math.ceil(self.input_shape[0] / stride), math.ceil(self.input_shape[1] / stride)),
                stride,
            )
            for stride in self.strides
        ]

    def _make_grid(self, featmap_size, stride):
        """Return the (cx, cy) centre of every cell of a feature map.

        Args:
            featmap_size: ``(feat_h, feat_w)`` of the feature map.
            stride: Size of one feature-map cell in input pixels.

        Returns:
            Array of shape ``(feat_h * feat_w, 2)`` in row-major cell order.
        """
        feat_h, feat_w = featmap_size
        shift_x = np.arange(0, feat_w) * stride
        shift_y = np.arange(0, feat_h) * stride
        xv, yv = np.meshgrid(shift_x, shift_y)
        cx = xv.flatten() + 0.5 * (stride - 1)
        cy = yv.flatten() + 0.5 * (stride - 1)
        return np.stack((cx, cy), axis=-1)

    def softmax(self, x, axis=1):
        """Return the softmax of *x* along *axis* (use ``axis=0`` for a column vector)."""
        # Subtracting the per-slice maximum leaves the result mathematically
        # unchanged but keeps np.exp from overflowing on large logits.
        x_exp = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return x_exp / np.sum(x_exp, axis=axis, keepdims=True)

    def _normalize(self, img):
        """Convert *img* to float32 and standardise it with ``self.mean`` / ``self.std``.

        C++ counterpart: https://blog.csdn.net/wuqingshan2010/article/details/107727909
        """
        img = img.astype(np.float32)
        # Algebraically the same as (img - mean) / std; written on the 0-1 scale.
        img = (img / 255.0 - self.mean / 255.0) / (self.std / 255.0)
        return img

    def resize_image(self, srcimg, keep_ratio=True):
        """Resize *srcimg* to the network input size, letterboxing when asked.

        With ``keep_ratio`` and a non-square image, the image is scaled to fit
        and padded with zeros on both sides of the shorter dimension;
        otherwise it is stretched to the input size.
        NOTE(review): the letterbox branch decides the padded side from the
        image aspect ratio alone, which presumably assumes a square network
        input - confirm for non-square models.

        Returns:
            ``(img, newh, neww, top, left)``: the resized image, the size of
            the un-padded content, and the top/left padding offsets.
        """
        in_h, in_w = self.input_shape
        top, left, newh, neww = 0, 0, in_h, in_w
        if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
            hw_scale = srcimg.shape[0] / srcimg.shape[1]
            if hw_scale > 1:
                # Taller than wide: fill the height, pad left and right.
                newh, neww = in_h, int(in_w / hw_scale)
                img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
                left = int((in_w - neww) * 0.5)
                img = cv2.copyMakeBorder(img, 0, 0, left, in_w - neww - left, cv2.BORDER_CONSTANT,
                                         value=0)  # add border
            else:
                # Wider than tall: fill the width, pad top and bottom.
                newh, neww = int(in_h * hw_scale), in_w
                img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
                top = int((in_h - newh) * 0.5)
                img = cv2.copyMakeBorder(img, top, in_h - newh - top, 0, 0, cv2.BORDER_CONSTANT, value=0)
        else:
            # cv2.resize takes dsize as (width, height); input_shape is (height, width).
            img = cv2.resize(srcimg, (in_w, in_h), interpolation=cv2.INTER_AREA)
        return img, newh, neww, top, left

    def post_process(self, preds):
        """Split raw network outputs into scores and regressions, then decode them.

        Returns:
            ``(boxes, confidences, class_ids)`` with the boxes cast to int32.
        """
        cls_scores, bbox_preds = preds[:self.num_outs], preds[self.num_outs:]
        det_bboxes, det_conf, det_classid = self.get_bboxes_single(cls_scores, bbox_preds, 1, rescale=False)
        return det_bboxes.astype(np.int32), det_conf, det_classid

    def get_bboxes_single(self, cls_scores, bbox_preds, scale_factor, rescale=False):
        """Decode per-level predictions into boxes and run NMS over all levels.

        Args:
            cls_scores: Per-level class score arrays.
            bbox_preds: Per-level box regression arrays.
            scale_factor: Divisor applied to the boxes when *rescale* is true.
            rescale: Whether to divide the boxes by *scale_factor*.

        Returns:
            ``(boxes, confidences, class_ids)`` of the detections kept by NMS.
            Boxes are ``(x1, y1, x2, y2)`` in network-input pixels.  When
            nothing is kept the arrays are empty with shapes ``(0, 4)``,
            ``(0,)`` and ``(0,)``.
        """
        nms_pre = 1000  # keep at most this many top-scoring candidates per level
        mlvl_bboxes = []
        mlvl_scores = []
        for stride, cls_score, bbox_pred, anchors in zip(self.strides, cls_scores, bbox_preds, self.mlvl_anchors):
            # Drop the leading batch axis when present.
            if cls_score.ndim == 3:
                cls_score = cls_score.squeeze(axis=0)
            if bbox_pred.ndim == 3:
                bbox_pred = bbox_pred.squeeze(axis=0)
            # Each side distance is the expectation of a softmax distribution
            # over the bins 0..reg_max, scaled by the level's stride.
            bbox_pred = self.softmax(bbox_pred.reshape(-1, self.reg_max + 1), axis=1)
            bbox_pred = np.dot(bbox_pred, self.project).reshape(-1, 4)
            bbox_pred *= stride
            if nms_pre > 0 and cls_score.shape[0] > nms_pre:
                max_scores = cls_score.max(axis=1)
                topk_inds = max_scores.argsort()[::-1][0:nms_pre]
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                cls_score = cls_score[topk_inds, :]
            bboxes = self.distance2bbox(anchors, bbox_pred, max_shape=self.input_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(cls_score)
        mlvl_bboxes = np.concatenate(mlvl_bboxes, axis=0)
        if rescale:
            mlvl_bboxes /= scale_factor
        mlvl_scores = np.concatenate(mlvl_scores, axis=0)
        # cv2.dnn.NMSBoxes expects (x, y, w, h) boxes.
        bboxes_wh = mlvl_bboxes.copy()
        bboxes_wh[:, 2:4] = bboxes_wh[:, 2:4] - bboxes_wh[:, 0:2]
        classIds = np.argmax(mlvl_scores, axis=1)
        confidences = np.max(mlvl_scores, axis=1)  # best class score per candidate
        indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.prob_threshold, self.iou_threshold)
        # Depending on the OpenCV version the result is an empty tuple, an
        # (N, 1) array or a flat array; normalise it to a flat array.
        indices = np.asarray(indices).reshape(-1)
        if indices.size == 0:
            print('nothing detect')
            # Keep the column layout so callers can still index by column.
            return (np.empty((0, 4), dtype=mlvl_bboxes.dtype),
                    np.empty((0,), dtype=confidences.dtype),
                    np.empty((0,), dtype=classIds.dtype))
        indices = indices.astype(np.intp)
        return mlvl_bboxes[indices], confidences[indices], classIds[indices]

    def distance2bbox(self, points, distance, max_shape=None):
        """Turn centre points plus (left, top, right, bottom) distances into boxes.

        Args:
            points: Array whose columns 0 and 1 are x and y.
            distance: Array whose columns 0-3 are the left, top, right and
                bottom distances from the point.
            max_shape: Optional ``(height, width)``; coordinates are clipped
                to ``[0, width]`` / ``[0, height]`` when given.

        Returns:
            Array of ``(x1, y1, x2, y2)`` rows.
        """
        x1 = points[:, 0] - distance[:, 0]
        y1 = points[:, 1] - distance[:, 1]
        x2 = points[:, 0] + distance[:, 2]
        y2 = points[:, 1] + distance[:, 3]
        if max_shape is not None:
            x1 = np.clip(x1, 0, max_shape[1])
            y1 = np.clip(y1, 0, max_shape[0])
            x2 = np.clip(x2, 0, max_shape[1])
            y2 = np.clip(y2, 0, max_shape[0])
        return np.stack([x1, y1, x2, y2], axis=-1)

    @timeit
    def detect(self, srcimg):
        """Run detection on *srcimg* and draw the results onto it.

        The image is modified in place (red boxes, green ``class: score``
        labels) and also returned.
        """
        img, newh, neww, top, left = self.resize_image(srcimg)
        img = self._normalize(img)
        # HWC -> CHW, then add the batch axis.
        blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        outs = self.net.run(None, {self.net.get_inputs()[0].name: blob})
        det_bboxes, det_conf, det_classid = self.post_process(outs)
        # Undo the letterbox: remove the padding offset, then scale back to
        # the source resolution and clamp to the image bounds.
        ratioh, ratiow = srcimg.shape[0] / newh, srcimg.shape[1] / neww
        for i in range(det_bboxes.shape[0]):
            xmin = max(int((det_bboxes[i, 0] - left) * ratiow), 0)
            ymin = max(int((det_bboxes[i, 1] - top) * ratioh), 0)
            xmax = min(int((det_bboxes[i, 2] - left) * ratiow), srcimg.shape[1])
            ymax = min(int((det_bboxes[i, 3] - top) * ratioh), srcimg.shape[0])
            cv2.rectangle(srcimg, (xmin, ymin), (xmax, ymax), (0, 0, 255), thickness=1)
            cv2.putText(srcimg, self.classes[det_classid[i]]+': '+str(round(det_conf[i], 3)), (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), thickness=1)
        return srcimg
if __name__ == '__main__':
    # Command-line demo: detect objects in one image and show the result.
    # Reference implementation from: https://github.com/hpc203
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgpath', type=str, default='models/person.jpg', help="image path")
    parser.add_argument('--modelpath', type=str, default='models/picodet_s_416_coco.onnx', choices=["models/picodet_m_320_coco.onnx", "models/picodet_m_416_coco.onnx", "models/picodet_s_320_coco.onnx", "models/picodet_s_416_coco.onnx"], help="onnx filepath")
    parser.add_argument('--classfile', type=str, default='models/models.names', help="classname filepath")
    parser.add_argument('--confThreshold', default=0.5, type=float, help='class confidence')
    parser.add_argument('--nmsThreshold', default=0.6, type=float, help='nms iou thresh')
    args = parser.parse_args()

    srcimg = cv2.imread(args.imgpath)
    if srcimg is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # rather than raising, so fail here with a clear message.
        parser.error('cannot read image: ' + args.imgpath)

    net = PicoDet(args.modelpath, args.classfile, prob_threshold=args.confThreshold, iou_threshold=args.nmsThreshold)
    srcimg = net.detect(srcimg)

    winName = 'Deep learning object detection in onnxruntime'
    cv2.namedWindow(winName, cv2.WINDOW_NORMAL)
    cv2.imshow(winName, srcimg)
    cv2.waitKey(0)
    cv2.destroyAllWindows()