-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
137 lines (108 loc) · 4.48 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import cv2
import torch
from numpy import random
from models.experimental import attempt_load
from utils.datasets import LoadImages
from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh
from utils.plots import plot_one_box
from apscheduler.schedulers.background import BackgroundScheduler
from scan import update_repair_date
from flask import Flask, jsonify, request, send_file, Response
# --- Web application ---------------------------------------------------------
app = Flask(__name__)

# --- Periodic maintenance ----------------------------------------------------
# update_repair_date (imported from scan) is invoked every 10 seconds by the
# scheduler's background thread; daemon=True means that thread does not keep
# the process alive on shutdown.
scheduler = BackgroundScheduler(daemon=True)
scheduler.add_job(func=update_repair_date, trigger='interval', seconds=10)
scheduler.start()

# --- Detector configuration --------------------------------------------------
weights = 'model.pt'                # checkpoint passed to attempt_load
imgsz = 640                         # requested inference size; adjusted below
conf_thres, iou_thres = 0.25, 0.45  # confidence / IoU thresholds for NMS

# --- Model -------------------------------------------------------------------
device = torch.device('cpu')        # inference always runs on the CPU here
model = attempt_load(weights, map_location=device)  # FP32 model
stride = int(model.stride.max())    # largest value in model.stride
# Presumably rounds imgsz to a size compatible with the stride — see
# utils.general.check_img_size.
imgsz = check_img_size(imgsz, s=stride)
def hex_to_rgb(hex):
    """Convert a hex colour string to an ``[r, g, b]`` list of ints (0-255).

    Accepts ``'rrggbb'`` and also an optional leading ``'#'`` and the CSS-style
    three-digit shorthand ``'rgb'``, in which each digit is doubled. Characters
    beyond the first six digits are ignored.

    Raises:
        ValueError: if any two-digit group is empty or not valid hexadecimal.
    """
    # The parameter name shadows the builtin ``hex`` but is kept so that
    # keyword-argument callers keep working.
    digits = hex[1:] if hex.startswith('#') else hex
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    return [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
def _safe_static_name(filename):
    """Return the final path component of ``filename``, or 'noname.jpg' if none."""
    # The name may come from an HTTP client, so treat '\\' as a separator too.
    name = os.path.basename((filename or '').replace('\\', '/'))
    return name if name not in ('', '.', '..') else 'noname.jpg'


def detect(source='buffer/image.jpg', filename='noname.jpg'):
    """Run the module-level model on ``source`` and save annotated results.

    Uses the module globals ``model``, ``imgsz``, ``stride``, ``device``,
    ``conf_thres`` and ``iou_thres``. For each image/frame yielded by
    ``LoadImages(source)``, boxes that survive NMS are drawn onto the image.
    The image is written to ``buffer/result.jpg`` and to ``static/<name>``.
    The box lines (class, normalised x y w h, confidence) are written to
    ``buffer/label.txt``.

    Args:
        source: Path handed to ``LoadImages``.
        filename: Name for the copy saved under ``static/``. Only its final
            path component is used, because callers pass the client-supplied
            upload name.

    Returns:
        ``(img_path, danger, class_count)``. ``img_path`` is always
        ``'buffer/result.jpg'``. ``danger`` is True if any detection survived
        NMS. ``class_count[k]`` is the number of detections of class ``k``;
        the list has at least 5 entries.
    """
    img_path = 'buffer/result.jpg'
    txt_path = 'buffer/label.txt'
    # Never let the caller-supplied name pick a directory (e.g. '../app.py').
    static_path = f'static/{_safe_static_name(filename)}'

    names = model.module.names if hasattr(model, 'module') else model.names
    # Per-class box colours given as RGB hex. Each triple is reversed to BGR
    # because im0 is saved with cv2.imwrite, which expects BGR channel order.
    # This assumes plot_one_box hands the colour to OpenCV unchanged.
    colors = [hex_to_rgb(h)[::-1]
              for h in ('f38ba8', 'f9e2af', '94e2d5', '74c7ec', 'b4befe')]
    # One counter per model class, and never fewer than the five that the
    # /infer response reports by index.
    class_count = [0] * max(len(names), 5)
    danger = False

    dataset = LoadImages(source, img_size=imgsz, stride=stride)
    for _path, img, im0s, _vid_cap in dataset:
        img = torch.from_numpy(img).to(device).float() / 255.0  # uint8 0-255 -> float 0.0-1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)  # add a batch dimension

        with torch.no_grad():  # inference only; no autograd graph needed
            pred = model(img, augment=False)[0]
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=False)

        for det in pred:  # detections for each image in the batch
            im0 = im0s
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # (w, h, w, h) normalisation gain
            label_lines = []
            if len(det):
                danger = True
                # Rescale boxes from the network input size back to im0's size.
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    k = int(cls)
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalised xywh
                    line = (cls, *xywh, conf)  # label format
                    label_lines.append(('%g ' * len(line)).rstrip() % line + '\n')
                    plot_one_box(xyxy, im0, label=f'{names[k]} {conf:.2f}',
                                 color=colors[k % len(colors)], line_thickness=0)
                    class_count[k] += 1
            # Open once per image: reopening with 'w' for every box would keep
            # only the last box. Also truncates stale labels when none are found.
            with open(txt_path, 'w') as f:
                f.writelines(label_lines)
            cv2.imwrite(img_path, im0)
            cv2.imwrite(static_path, im0)

    return img_path, danger, class_count
@app.route("/infer", methods=["POST"])
def predict():
    """Run detection on an uploaded image and report per-class counts.

    The upload is read from the ``file`` form field and saved to
    ``buffer/image.jpg``. It is then passed to ``detect`` together with the
    client's filename. If anything was detected, the response is a JSON object
    mapping class names to counts. Otherwise it is the JSON literal ``null``
    with status 200.
    """
    # The route is registered for POST only, so no method check is needed.
    upload = request.files["file"]
    upload.save("buffer/image.jpg")
    _, found_any, counts = detect("buffer/image.jpg", upload.filename)

    if not found_any:
        return Response(status=200, mimetype='application/json', response='null')

    # Response keys, in class-index order 0-4.
    class_keys = ("electronic_device", "laptop", "scissors", "knife", "gun")
    return {key: counts[idx] for idx, key in enumerate(class_keys)}
def run():
    """Serve the app with Flask's built-in server on 0.0.0.0:$PORT (default 8000).

    Debug mode stays on by default, as before. It can be switched off by
    setting FLASK_DEBUG to '0', 'false', 'no' or an empty string. The Werkzeug
    reloader is disabled because it re-executes this module in a child process
    while the parent keeps running. That would load the model twice and start a
    second module-level BackgroundScheduler, so ``update_repair_date`` would run
    twice per interval.
    """
    port = int(os.environ.get("PORT", 8000))
    flag = os.environ.get("FLASK_DEBUG")
    debug = True if flag is None else flag.lower() not in ("", "0", "false", "no")
    # NOTE(review): debug mode on 0.0.0.0 exposes the Werkzeug debugger to
    # the network; set FLASK_DEBUG=0 for anything beyond local development.
    app.run(host='0.0.0.0', port=port, debug=debug, use_reloader=False)


if __name__ == "__main__":
    run()