diff --git a/deploy/Deepstream/README.md b/deploy/Deepstream/README.md new file mode 100644 index 00000000..ad10e061 --- /dev/null +++ b/deploy/Deepstream/README.md @@ -0,0 +1,63 @@ +# YOLOv6 in Deepstream + +## Dependencies +- Deepstream6.0 +- TensorRT-8.0 + +## Step 1: Get the TensorRT +Please Follow the file [TensorRT README](../TensorRT/README.md) to get TensorRT engine file `yolov6n.trt`. + +```shell +python ./deploy/ONNX/export_onnx.py \ + --weights yolov6n.pt \ + --img 640 \ + --batch 1 +``` + +```shell +python3 onnx_to_tensorrt.py --fp16 --int8 -v \ + --max_calibration_size=${MAX_CALIBRATION_SIZE} \ + --calibration-data=${CALIBRATION_DATA} \ + --calibration-cache=${CACHE_FILENAME} \ + --preprocess_func=${PREPROCESS_FUNC} \ + --explicit-batch \ + --onnx ${ONNX_MODEL} -o ${OUTPUT} +``` +### Example +```shell +python3 onnx_to_tensorrt.py --fp16 --onnx ${ONNX_MODEL} -o ${OUTPUT} +``` + +## Step 2: Build the So file +Execute the following Command, get the libnvdsinfer_custom_impl_Yolov6.so file. +```shell +cd nvdsparsebbox_YoloV6 +export CUDA_VER=11.4 # for dGPU +make +``` + +## Step 3: Run the Demo +```shell +mv yolov6n.trt Deepstream # move yolov6n.trt file to Deepstream folder +deepstream-app -c ds_app_config_yoloV6.txt +``` + +## Additional Experimental Results + +**The TensorRT performance of Yolov6n was tested on COCO2017 val datasets, as shown below.** + +--------------------- +|Yolov6N(640) FP16 | AP | area | maxDet | +|-----------------|------------------------|-------------|-----------------------| +|Average Precision| (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.357 | +|Average Precision| (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.512 | +|Average Precision| (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.383 | +|Average Precision| (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.170 | +|Average Precision| (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.398 | +|Average Precision| (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.511 | +|Average Recall | (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.299 | +|Average Recall | (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.490 | +|Average Recall | (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.538 | +|Average Recall | (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.296 | +|Average Recall | (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.605 | +|Average Recall | (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.729 | \ No newline at end of file diff --git a/deploy/Deepstream/config_infer_prinary_yoloV6.txt b/deploy/Deepstream/config_infer_prinary_yoloV6.txt new file mode 100644 index 00000000..66b303a3 --- /dev/null +++ b/deploy/Deepstream/config_infer_prinary_yoloV6.txt @@ -0,0 +1,94 @@ +################################################################################ +# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +################################################################################ + +# Following properties are mandatory when engine files are not specified: +# int8-calib-file(Only in INT8), model-file-format +# Caffemodel mandatory properties: model-file, proto-file, output-blob-names +# UFF: uff-file, input-dims, uff-input-blob-name, output-blob-names +# ONNX: onnx-file +# +# Mandatory properties for detectors: +# num-detected-classes +# +# Optional properties for detectors: +# cluster-mode(Default=Group Rectangles), interval(Primary mode only, Default=0) +# custom-lib-path +# parse-bbox-func-name +# +# Mandatory properties for classifiers: +# classifier-threshold, is-classifier +# +# Optional properties for classifiers: +# classifier-async-mode(Secondary mode only, Default=false) +# +# Optional properties in secondary mode: +# operate-on-gie-id(Default=0), operate-on-class-ids(Defaults to all classes), +# input-object-min-width, input-object-min-height, input-object-max-width, +# input-object-max-height +# +# Following properties are always recommended: +# batch-size(Default=1) +# +# Other optional properties: +# net-scale-factor(Default=1), network-mode(Default=0 i.e FP32), +# model-color-format(Default=0 i.e. RGB) model-engine-file, labelfile-path, +# mean-file, gie-unique-id(Default=0), offsets, process-mode (Default=1 i.e. primary), +# custom-lib-path, network-mode(Default=0 i.e FP32) +# +# The values in the config file are overridden by values set through GObject +# properties. + +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +maintain-aspect-ratio=1 +symmetric-padding=1 +scaling-filter=1 +scaling-compute-hw=0 +#0=RGB, 1=BGR +model-color-format=0 + +model-engine-file=yolov6n.trt +force-implicit-batch-dim=1 +batch-size=1 + +labelfile-path=labels.txt +# int8-calib-file=yolov3-calibration.table.trt7.0 +## 0=FP32, 1=INT8, 2=FP16 mode +network-mode=2 +num-detected-classes=80 +gie-unique-id=1 +# Integer 0: Detector 1: Classifier +network-type=0 +# is-classifier=0 +## 1=DBSCAN, 2=NMS, 3= DBSCAN+NMS Hybrid, 4 = None(No clustering) +cluster-mode=4 +# lib path +parse-bbox-func-name=NvDsInferParseCustomYoloV6 +custom-lib-path=nvdsparsebbox_YoloV6/libnvdsinfer_custom_impl_Yolov6.so +# engine-create-func-name=NvDsInferYoloCudaEngineGet +#scaling-filter=0 +#scaling-compute-hw=0 + +[class-attrs-all] +nms-iou-threshold=0.3 +pre-cluster-threshold=0.7 \ No newline at end of file diff --git a/deploy/Deepstream/ds_app_config_yoloV6.txt b/deploy/Deepstream/ds_app_config_yoloV6.txt new file mode 100644 index 00000000..42a8cf44 --- /dev/null +++ b/deploy/Deepstream/ds_app_config_yoloV6.txt @@ -0,0 +1,136 @@ +################################################################################ +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +################################################################################ + +[application] +enable-perf-measurement=1 +perf-measurement-interval-sec=5 +#gie-kitti-output-dir=streamscl + +[tiled-display] +enable=1 +rows=1 +columns=1 +width=1920 +height=1080 +gpu-id=0 +#(0): nvbuf-mem-default - Default memory allocated, specific to particular platform +#(1): nvbuf-mem-cuda-pinned - Allocate Pinned/Host cuda memory, applicable for Tesla +#(2): nvbuf-mem-cuda-device - Allocate Device cuda memory, applicable for Tesla +#(3): nvbuf-mem-cuda-unified - Allocate Unified cuda memory, applicable for Tesla +#(4): nvbuf-mem-surface-array - Allocate Surface Array memory, applicable for Jetson +nvbuf-memory-type=0 + +[source0] +enable=1 +#Type - 1=CameraV4L2 2=URI 3=MultiURI +type=3 +uri=file://./sample_1080p_h264.mp4 +num-sources=1 +gpu-id=0 +# (0): memtype_device - Memory type Device +# (1): memtype_pinned - Memory type Host Pinned +# (2): memtype_unified - Memory type Unified +cudadec-memtype=0 + +[sink0] +enable=1 +#Type - 1=FakeSink 2=EglSink 3=File +type=3 +sync=0 +source-id=0 +gpu-id=0 +nvbuf-memory-type=0 + +container=1 +codec=1 +bitrate=4000000 +iframeinterval=30 +output-file=output_yolov6.mp4 + +[osd] +enable=1 +gpu-id=0 +border-width=1 +text-size=15 +text-color=1;1;1;1; +text-bg-color=0.3;0.3;0.3;1 +font=Serif +show-clock=0 +clock-x-offset=800 +clock-y-offset=820 +clock-text-size=12 +clock-color=1;0;0;0 +nvbuf-memory-type=0 + +[streammux] +gpu-id=0 +##Boolean property to inform muxer that sources are live +live-source=0 +batch-size=1 +##time out in usec, to wait after the first buffer is available +##to push the batch even if the complete batch is not formed +batched-push-timeout=40000 +## Set muxer output width and height +width=1920 +height=1080 +##Enable to maintain aspect ratio wrt source, and allow black borders, works +##along with width, height properties +enable-padding=0 +nvbuf-memory-type=0 + +# config-file property is mandatory for any gie section. +# Other properties are optional and if set will override the properties set in +# the infer config file. +[primary-gie] +enable=1 +gpu-id=0 +#model-engine-file=model_b1_gpu0_int8.engine +labelfile-path=labels.txt +batch-size=1 +#Required by the app for OSD, not a plugin property +bbox-border-color0=1;0;0;1 +bbox-border-color1=0;1;1;1 +bbox-border-color2=0;0;1;1 +bbox-border-color3=0;1;0;1 +interval=2 +gie-unique-id=1 +nvbuf-memory-type=0 +config-file=config_infer_prinary_yoloV6.txt + +[tracker] +enable=1 +# For NvDCF and DeepSORT tracker, tracker-width and tracker-height must be a multiple of 32, respectively +tracker-width=640 +tracker-height=384 +ll-lib-file=/opt/nvidia/deepstream/deepstream-6.0/lib/libnvds_nvmultiobjecttracker.so +# ll-config-file required to set different tracker types +# ll-config-file=/opt/nvidia/deepstream/deepstream-6.0//samples/configs/deepstream-app/config_tracker_IOU.yml +ll-config-file=/opt/nvidia/deepstream/deepstream-6.0/samples/configs/deepstream-app/config_tracker_NvDCF_perf.yml +# ll-config-file=/opt/nvidia/deepstream/deepstream-6.0//samples/configs/deepstream-app/config_tracker_NvDCF_accuracy.yml +# ll-config-file=/opt/nvidia/deepstream/deepstream-6.0//samples/configs/deepstream-app/config_tracker_DeepSORT.yml +gpu-id=0 +enable-batch-process=1 +enable-past-frame=1 +display-tracking-id=1 + +[tests] +file-loop=0 \ No newline at end of file diff --git a/deploy/Deepstream/labels.txt b/deploy/Deepstream/labels.txt new file mode 100644 index 00000000..16315f2b --- /dev/null +++ b/deploy/Deepstream/labels.txt @@ -0,0 +1,80 @@ +person +bicycle +car +motorbike +aeroplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +sofa +pottedplant +bed +diningtable +toilet +tvmonitor +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush \ No newline at end of file diff --git a/deploy/Deepstream/nvdsparsebbox_YoloV6/Makefile b/deploy/Deepstream/nvdsparsebbox_YoloV6/Makefile new file mode 100644 index 00000000..b5c83fe5 --- /dev/null +++ b/deploy/Deepstream/nvdsparsebbox_YoloV6/Makefile @@ -0,0 +1,56 @@ +################################################################################ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +################################################################################ + +CUDA_VER?=11.4 +ifeq ($(CUDA_VER),) + $(error "CUDA_VER is not set") +endif +CC:= g++ +NVCC:=/usr/local/cuda-$(CUDA_VER)/bin/nvcc + +CFLAGS:= -Wall -std=c++11 -shared -fPIC -Wno-error=deprecated-declarations +CFLAGS+= -I/opt/nvidia/deepstream/deepstream-6.0/sources/includes -I/usr/local/cuda-$(CUDA_VER)/include + +LIBS:= -L/usr/local/TensorRT-8.0.1.6/lib -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib64 -lcudart -lcublas -lstdc++fs +LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group + +INCS:= $(wildcard *.h) +SRCFILES:= parsebbox_YoloV6v2.cpp trt_utils.cpp + +TARGET_LIB:= libnvdsinfer_custom_impl_Yolov6.so + +TARGET_OBJS:= $(SRCFILES:.cpp=.o) +TARGET_OBJS:= $(TARGET_OBJS:.cu=.o) + +all: $(TARGET_LIB) + +%.o: %.cpp $(INCS) Makefile + $(CC) -c -o $@ $(CFLAGS) $< + +%.o: %.cu $(INCS) Makefile + $(NVCC) -c -o $@ --compiler-options '-fPIC' $< + +$(TARGET_LIB) : $(TARGET_OBJS) + $(CC) -o $@ $(TARGET_OBJS) $(LFLAGS) + +clean: + rm -rf $(TARGET_LIB) \ No newline at end of file diff --git a/deploy/Deepstream/nvdsparsebbox_YoloV6/libnvdsinfer_custom_impl_Yolov6.so b/deploy/Deepstream/nvdsparsebbox_YoloV6/libnvdsinfer_custom_impl_Yolov6.so new file mode 100755 index 00000000..0d29767f Binary files /dev/null and b/deploy/Deepstream/nvdsparsebbox_YoloV6/libnvdsinfer_custom_impl_Yolov6.so differ diff --git a/deploy/Deepstream/nvdsparsebbox_YoloV6/parsebbox_YoloV6v2.cpp b/deploy/Deepstream/nvdsparsebbox_YoloV6/parsebbox_YoloV6v2.cpp new file mode 100644 index 00000000..c5af0a7f --- /dev/null +++ b/deploy/Deepstream/nvdsparsebbox_YoloV6/parsebbox_YoloV6v2.cpp @@ -0,0 +1,327 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "nvdsinfer_custom_impl.h" +#include "trt_utils.h" + +static const int NUM_CLASSES_YOLO = 80; +static const float NMS_THRESHOLD = 0.65; +static const float CLS_THRESHOLD = 0.5; + +extern "C" bool NvDsInferParseCustomYoloV6( + std::vector const& outputLayersInfo, + NvDsInferNetworkInfo const& networkInfo, + NvDsInferParseDetectionParams const& detectionParams, + std::vector& objectList); + +/* This is a sample bounding box parsing function for the sample YoloV3 detector model */ +static NvDsInferParseObjectInfo convertBBox(const float& bx, const float& by, const float& bw, + const float& bh, const uint& netW, const uint& netH) +{ + NvDsInferParseObjectInfo b; + // Restore coordinates to network input resolution + float xCenter = bx ; + float yCenter = by ; + float x0 = xCenter - bw / 2; + float y0 = yCenter - bh / 2; + float x1 = x0 + bw; + float y1 = y0 + bh; + + x0 = clamp(x0, 0, netW); + y0 = clamp(y0, 0, netH); + x1 = clamp(x1, 0, netW); + y1 = clamp(y1, 0, netH); + + b.left = x0; + b.width = clamp(x1 - x0, 0, netW); + b.top = y0; + b.height = clamp(y1 - y0, 0, netH); + + return b; +} + +static void addBBoxProposal(const float bx, const float by, const float bw, const float bh, + const uint& netW, const uint& netH, const int maxIndex, + const float maxProb, std::vector& binfo) +{ + NvDsInferParseObjectInfo bbi = convertBBox(bx, by, bw, bh, netW, netH); + if (bbi.width < 1 || bbi.height < 1) return; + + bbi.detectionConfidence = maxProb; + bbi.classId = maxIndex; + binfo.push_back(bbi); +} + +/* +static inline std::vector +SortLayers(const std::vector & outputLayersInfo) +{ + std::vector outLayers; + for (auto const &layer : outputLayersInfo) + { + outLayers.push_back (&layer); + } + std::sort(outLayers.begin(), outLayers.end(), + [](const NvDsInferLayerInfo* a, const NvDsInferLayerInfo* b) { + return a->inferDims.d[1] < b->inferDims.d[1]; + }); + return outLayers; +} +*/ +/* +static std::vector +decodeYoloV6Tensor( + const float* detections, + const uint gridSizeW, const uint gridSizeH, const uint stride, const uint numBBoxes, + const uint numOutputClasses, const uint& netW, + const uint& netH) +{ + std::vector binfo; + for (uint y = 0; y < gridSizeH; ++y) { + for (uint x = 0; x < gridSizeW; ++x) { + for (uint b = 0; b < numBBoxes; ++b) + { + const int numGridCells = gridSizeH * gridSizeW; + const int bbindex = y * gridSizeW + x; + const float bx + = x + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)]; + const float by + = y + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 1)]; + const float bw + = stride * exp(detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 2)]); + const float bh + = stride * exp(detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 3)]); + + const float objectness + = detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 4)]; + + float maxProb = 0.0f; + int maxIndex = -1; + + for (uint i = 0; i < numOutputClasses; ++i) + { + float prob + = (detections[bbindex + + numGridCells * (b * (5 + numOutputClasses) + (5 + i))]); + + if (prob > maxProb) + { + maxProb = prob; + maxIndex = i; + } + } + maxProb = objectness * maxProb; + + addBBoxProposal(bx, by, bw, bh, stride, netW, netH, maxIndex, maxProb, binfo); + } + } + } + return binfo; +} +*/ + +static std::vector +nonMaximumSuppression(const float nmsThresh, std::vector binfo) +{ + auto overlap1D = [](float x1min, float x1max, float x2min, float x2max) -> float { + if (x1min > x2min) + { + std::swap(x1min, x2min); + std::swap(x1max, x2max); + } + return x1max < x2min ? 0 : std::min(x1max, x2max) - x2min; + }; + + auto computeIoU + = [&overlap1D](NvDsInferParseObjectInfo& bbox1, NvDsInferParseObjectInfo& bbox2) -> float { + float overlapX + = overlap1D(bbox1.left, bbox1.left + bbox1.width, bbox2.left, bbox2.left + bbox2.width); + float overlapY + = overlap1D(bbox1.top, bbox1.top + bbox1.height, bbox2.top, bbox2.top + bbox2.height); + float area1 = (bbox1.width) * (bbox1.height); + float area2 = (bbox2.width) * (bbox2.height); + float overlap2D = overlapX * overlapY; + float u = area1 + area2 - overlap2D; + return u == 0 ? 0 : overlap2D / u; + }; + + std::stable_sort(binfo.begin(), binfo.end(), + [](const NvDsInferParseObjectInfo& b1, const NvDsInferParseObjectInfo& b2) { + return b1.detectionConfidence > b2.detectionConfidence; + }); + + std::vector out; + for (auto i : binfo) + { + bool keep = true; + for (auto j : out) + { + if (keep) + { + float overlap = computeIoU(i, j); + keep = overlap <= nmsThresh; + } + else + break; + } + if (keep) out.push_back(i); + } + return out; +} + + +static std::vector +parseYoloV6BBox(const NvDsInferLayerInfo& feat, const uint numOutputClasses, const uint& netW, + const uint& netH) +{ + std::vector> binfo; + binfo.resize(NUM_CLASSES_YOLO); + + const float* detections = (const float*)feat.buffer; + auto numBBoxes = feat.inferDims.d[0]; + const int numBBoxCells = feat.inferDims.d[1]; + + for (uint b = 0; b < numBBoxes; ++b) + { + const float bx + = detections[b * numBBoxCells + 0]; + const float by + = detections[b * numBBoxCells + 1]; + const float bw + = detections[b * numBBoxCells + 2]; + const float bh + = detections[b * numBBoxCells + 3]; + + const float objectness + = detections[b * numBBoxCells + 4]; + + float maxProb = 0.0f; + int maxIndex = -1; + + for (uint i = 0; i < numOutputClasses; ++i) + { + float prob + = (detections[b * numBBoxCells + (5 + i)]); + + if (prob > maxProb) + { + maxProb = prob; + maxIndex = i; + } + } + maxProb = objectness * maxProb; + if(maxProb > CLS_THRESHOLD) + { + // std::vector bboxInfo; + addBBoxProposal(bx, by, bw, bh, netW, netH, maxIndex, maxProb, binfo[maxIndex]); + // binfo[maxIndex].push_back(bboxInfo); + } + } + // NMS + std::vector objects = {}; + for(int cls_id = 0; cls_id < NUM_CLASSES_YOLO; ++cls_id) + { + std::vector outObjs = nonMaximumSuppression(NMS_THRESHOLD, binfo[cls_id]); + objects.insert(objects.end(), outObjs.begin(), outObjs.end()); + } + + return objects; +} + +static bool NvDsInferParseYoloV6( + std::vector const& outputLayersInfo, + NvDsInferNetworkInfo const& networkInfo, + NvDsInferParseDetectionParams const& detectionParams, + std::vector& objectList) +{ + // const uint kNUM_BBOXES = 1; + // const uint kLAYER_NUM = 1; + + // const std::vector sortedLayers = + // SortLayers (outputLayersInfo); + + // if (sortedLayers.size() != kLAYER_NUM) { + // std::cerr << "ERROR: yoloV6 output layer.size: " << sortedLayers.size() + // << " does not match mask.size: " << kLAYER_NUM << std::endl; + // return false; + // } + + if (NUM_CLASSES_YOLO != detectionParams.numClassesConfigured) + { + std::cerr << "WARNING: Num classes mismatch. Configured:" + << detectionParams.numClassesConfigured + << ", detected by network: " << NUM_CLASSES_YOLO << std::endl; + } + + + + // for (uint idx = 0; idx < sortedLayers.size(); ++idx) + // { + // dimensions: batch, 8400, 85 + // bbox order: center_x, center_y, width, height + const NvDsInferLayerInfo &layer = outputLayersInfo[0]; + + assert(layer.inferDims.numDims == 2); + + + + /* + const uint gridSizeH = layer.inferDims.d[1]; + const uint gridSizeW = layer.inferDims.d[2]; + const uint stride = DIVUP(networkInfo.width, gridSizeW); + assert(stride == DIVUP(networkInfo.height, gridSizeH)); + + std::vector outObjs = + decodeYoloV6Tensor((const float*)(layer.buffer), gridSizeW, gridSizeH, stride, kNUM_BBOXES, + NUM_CLASSES_YOLO, networkInfo.width, networkInfo.height); + */ + // objects.insert(objects.end(), outObjs.begin(), outObjs.end()); + // } + + std::vector objects = parseYoloV6BBox( + layer, NUM_CLASSES_YOLO, networkInfo.width, networkInfo.height ); + + objectList = objects; // 赋值运算符被调用 + + return true; +} + +/* C-linkage to prevent name-mangling */ +extern "C" bool NvDsInferParseCustomYoloV6( + std::vector const& outputLayersInfo, + NvDsInferNetworkInfo const& networkInfo, + NvDsInferParseDetectionParams const& detectionParams, + std::vector& objectList) +{ + return NvDsInferParseYoloV6 ( + outputLayersInfo, networkInfo, detectionParams, objectList); +} + +/* Check that the custom function has been defined correctly */ +CHECK_CUSTOM_PARSE_FUNC_PROTOTYPE(NvDsInferParseCustomYoloV6); \ No newline at end of file diff --git a/deploy/Deepstream/nvdsparsebbox_YoloV6/parsebbox_YoloV6v2.o b/deploy/Deepstream/nvdsparsebbox_YoloV6/parsebbox_YoloV6v2.o new file mode 100644 index 00000000..88bdbcad Binary files /dev/null and b/deploy/Deepstream/nvdsparsebbox_YoloV6/parsebbox_YoloV6v2.o differ diff --git a/deploy/Deepstream/nvdsparsebbox_YoloV6/trt_utils.cpp b/deploy/Deepstream/nvdsparsebbox_YoloV6/trt_utils.cpp new file mode 100644 index 00000000..6c356cae --- /dev/null +++ b/deploy/Deepstream/nvdsparsebbox_YoloV6/trt_utils.cpp @@ -0,0 +1,447 @@ +/* + * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include "trt_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "NvInferPlugin.h" + +static void leftTrim(std::string& s) +{ + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !isspace(ch); })); +} + +static void rightTrim(std::string& s) +{ + s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !isspace(ch); }).base(), s.end()); +} + +std::string trim(std::string s) +{ + leftTrim(s); + rightTrim(s); + return s; +} + +float clamp(const float val, const float minVal, const float maxVal) +{ + assert(minVal <= maxVal); + return std::min(maxVal, std::max(minVal, val)); +} + +bool fileExists(const std::string fileName, bool verbose) +{ + if (!std::experimental::filesystem::exists(std::experimental::filesystem::path(fileName))) + { + if (verbose) std::cout << "File does not exist : " << fileName << std::endl; + return false; + } + return true; +} + +std::vector loadWeights(const std::string weightsFilePath, const std::string& networkType) +{ + assert(fileExists(weightsFilePath)); + std::cout << "Loading pre-trained weights..." << std::endl; + std::ifstream file(weightsFilePath, std::ios_base::binary); + assert(file.good()); + std::string line; + + if (networkType == "yolov2") + { + // Remove 4 int32 bytes of data from the stream belonging to the header + file.ignore(4 * 4); + } + else if ((networkType == "yolov3") || (networkType == "yolov3-tiny") + || (networkType == "yolov2-tiny")) + { + // Remove 5 int32 bytes of data from the stream belonging to the header + file.ignore(4 * 5); + } + else + { + std::cout << "Invalid network type" << std::endl; + assert(0); + } + + std::vector weights; + char floatWeight[4]; + while (!file.eof()) + { + file.read(floatWeight, 4); + assert(file.gcount() == 4); + weights.push_back(*reinterpret_cast(floatWeight)); + if (file.peek() == std::istream::traits_type::eof()) break; + } + std::cout << "Loading weights of " << networkType << " complete!" + << std::endl; + std::cout << "Total Number of weights read : " << weights.size() << std::endl; + return weights; +} + +std::string dimsToString(const nvinfer1::Dims d) +{ + std::stringstream s; + assert(d.nbDims >= 1); + for (int i = 0; i < d.nbDims - 1; ++i) + { + s << std::setw(4) << d.d[i] << " x"; + } + s << std::setw(4) << d.d[d.nbDims - 1]; + + return s.str(); +} + +int getNumChannels(nvinfer1::ITensor* t) +{ + nvinfer1::Dims d = t->getDimensions(); + assert(d.nbDims == 3); + + return d.d[0]; +} + +uint64_t get3DTensorVolume(nvinfer1::Dims inputDims) +{ + assert(inputDims.nbDims == 3); + return inputDims.d[0] * inputDims.d[1] * inputDims.d[2]; +} + +nvinfer1::ILayer* netAddMaxpool(int layerIdx, std::map& block, + nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) +{ + assert(block.at("type") == "maxpool"); + assert(block.find("size") != block.end()); + assert(block.find("stride") != block.end()); + + int size = std::stoi(block.at("size")); + int stride = std::stoi(block.at("stride")); + + nvinfer1::IPoolingLayer* pool + = network->addPooling(*input, nvinfer1::PoolingType::kMAX, nvinfer1::DimsHW{size, size}); + assert(pool); + std::string maxpoolLayerName = "maxpool_" + std::to_string(layerIdx); + pool->setStride(nvinfer1::DimsHW{stride, stride}); + pool->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + pool->setName(maxpoolLayerName.c_str()); + + return pool; +} + +nvinfer1::ILayer* netAddConvLinear(int layerIdx, std::map& block, + std::vector& weights, + std::vector& trtWeights, int& weightPtr, + int& inputChannels, nvinfer1::ITensor* input, + nvinfer1::INetworkDefinition* network) +{ + assert(block.at("type") == "convolutional"); + assert(block.find("batch_normalize") == block.end()); + assert(block.at("activation") == "linear"); + assert(block.find("filters") != block.end()); + assert(block.find("pad") != block.end()); + assert(block.find("size") != block.end()); + assert(block.find("stride") != block.end()); + + int filters = std::stoi(block.at("filters")); + int padding = std::stoi(block.at("pad")); + int kernelSize = std::stoi(block.at("size")); + int stride = std::stoi(block.at("stride")); + int pad; + if (padding) + pad = (kernelSize - 1) / 2; + else + pad = 0; + // load the convolution layer bias + nvinfer1::Weights convBias{nvinfer1::DataType::kFLOAT, nullptr, filters}; + float* val = new float[filters]; + for (int i = 0; i < filters; ++i) + { + val[i] = weights[weightPtr]; + weightPtr++; + } + convBias.values = val; + trtWeights.push_back(convBias); + // load the convolutional layer weights + int size = filters * inputChannels * kernelSize * kernelSize; + nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, size}; + val = new float[size]; + for (int i = 0; i < size; ++i) + { + val[i] = weights[weightPtr]; + weightPtr++; + } + convWt.values = val; + trtWeights.push_back(convWt); + nvinfer1::IConvolutionLayer* conv = network->addConvolution( + *input, filters, nvinfer1::DimsHW{kernelSize, kernelSize}, convWt, convBias); + assert(conv != nullptr); + std::string convLayerName = "conv_" + std::to_string(layerIdx); + conv->setName(convLayerName.c_str()); + conv->setStride(nvinfer1::DimsHW{stride, stride}); + conv->setPadding(nvinfer1::DimsHW{pad, pad}); + + return conv; +} + +nvinfer1::ILayer* netAddConvBNLeaky(int layerIdx, std::map& block, + std::vector& weights, + std::vector& trtWeights, int& weightPtr, + int& inputChannels, nvinfer1::ITensor* input, + nvinfer1::INetworkDefinition* network) +{ + assert(block.at("type") == "convolutional"); + assert(block.find("batch_normalize") != block.end()); + assert(block.at("batch_normalize") == "1"); + assert(block.at("activation") == "leaky"); + assert(block.find("filters") != block.end()); + assert(block.find("pad") != block.end()); + assert(block.find("size") != block.end()); + assert(block.find("stride") != block.end()); + + bool batchNormalize, bias; + if (block.find("batch_normalize") != block.end()) + { + batchNormalize = (block.at("batch_normalize") == "1"); + bias = false; + } + else + { + batchNormalize = false; + bias = true; + } + // all conv_bn_leaky layers assume bias is false + assert(batchNormalize == true && bias == false); + UNUSED(batchNormalize); + UNUSED(bias); + + int filters = std::stoi(block.at("filters")); + int padding = std::stoi(block.at("pad")); + int kernelSize = std::stoi(block.at("size")); + int stride = std::stoi(block.at("stride")); + int pad; + if (padding) + pad = (kernelSize - 1) / 2; + else + pad = 0; + + /***** CONVOLUTION LAYER *****/ + /*****************************/ + // batch norm weights are before the conv layer + // load BN biases (bn_biases) + std::vector bnBiases; + for (int i = 0; i < filters; ++i) + { + bnBiases.push_back(weights[weightPtr]); + weightPtr++; + } + // load BN weights + std::vector bnWeights; + for (int i = 0; i < filters; ++i) + { + bnWeights.push_back(weights[weightPtr]); + weightPtr++; + } + // load BN running_mean + std::vector bnRunningMean; + for (int i = 0; i < filters; ++i) + { + bnRunningMean.push_back(weights[weightPtr]); + weightPtr++; + } + // load BN running_var + std::vector bnRunningVar; + for (int i = 0; i < filters; ++i) + { + // 1e-05 for numerical stability + bnRunningVar.push_back(sqrt(weights[weightPtr] + 1.0e-5)); + weightPtr++; + } + // load Conv layer weights (GKCRS) + int size = filters * inputChannels * kernelSize * kernelSize; + nvinfer1::Weights convWt{nvinfer1::DataType::kFLOAT, nullptr, size}; + float* val = new float[size]; + for (int i = 0; i < size; ++i) + { + val[i] = weights[weightPtr]; + weightPtr++; + } + convWt.values = val; + trtWeights.push_back(convWt); + nvinfer1::Weights convBias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + trtWeights.push_back(convBias); + nvinfer1::IConvolutionLayer* conv = network->addConvolution( + *input, filters, nvinfer1::DimsHW{kernelSize, kernelSize}, convWt, convBias); + assert(conv != nullptr); + std::string convLayerName = "conv_" + std::to_string(layerIdx); + conv->setName(convLayerName.c_str()); + conv->setStride(nvinfer1::DimsHW{stride, stride}); + conv->setPadding(nvinfer1::DimsHW{pad, pad}); + + /***** BATCHNORM LAYER *****/ + /***************************/ + size = filters; + // create the weights + nvinfer1::Weights shift{nvinfer1::DataType::kFLOAT, nullptr, size}; + nvinfer1::Weights scale{nvinfer1::DataType::kFLOAT, nullptr, size}; + nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, size}; + float* shiftWt = new float[size]; + for (int i = 0; i < size; ++i) + { + shiftWt[i] + = bnBiases.at(i) - ((bnRunningMean.at(i) * bnWeights.at(i)) / bnRunningVar.at(i)); + } + shift.values = shiftWt; + float* scaleWt = new float[size]; + for (int i = 0; i < size; ++i) + { + scaleWt[i] = bnWeights.at(i) / bnRunningVar[i]; + } + scale.values = scaleWt; + float* powerWt = new float[size]; + for (int i = 0; i < size; ++i) + { + powerWt[i] = 1.0; + } + power.values = powerWt; + trtWeights.push_back(shift); + trtWeights.push_back(scale); + trtWeights.push_back(power); + // Add the batch norm layers + nvinfer1::IScaleLayer* bn = network->addScale( + *conv->getOutput(0), nvinfer1::ScaleMode::kCHANNEL, shift, scale, power); + assert(bn != nullptr); + std::string bnLayerName = "batch_norm_" + std::to_string(layerIdx); + bn->setName(bnLayerName.c_str()); + /***** ACTIVATION LAYER *****/ + /****************************/ + nvinfer1::ITensor* bnOutput = bn->getOutput(0); + nvinfer1::IActivationLayer* leaky = network->addActivation( + *bnOutput, nvinfer1::ActivationType::kLEAKY_RELU); + leaky->setAlpha(0.1); + assert(leaky != nullptr); + std::string leakyLayerName = "leaky_" + std::to_string(layerIdx); + leaky->setName(leakyLayerName.c_str()); + + return leaky; +} + +nvinfer1::ILayer* netAddUpsample(int layerIdx, std::map& block, + std::vector& weights, + std::vector& trtWeights, int& inputChannels, + nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network) +{ + assert(block.at("type") == "upsample"); + nvinfer1::Dims inpDims = input->getDimensions(); + assert(inpDims.nbDims == 3); + assert(inpDims.d[1] == inpDims.d[2]); + int h = inpDims.d[1]; + int w = inpDims.d[2]; + int stride = std::stoi(block.at("stride")); + // add pre multiply matrix as a constant + nvinfer1::Dims preDims{3, + {1, stride * h, w}}; + + int size = stride * h * w; + nvinfer1::Weights preMul{nvinfer1::DataType::kFLOAT, nullptr, size}; + float* preWt = new float[size]; + /* (2*h * w) + [ [1, 0, ..., 0], + [1, 0, ..., 0], + [0, 1, ..., 0], + [0, 1, ..., 0], + ..., + ..., + [0, 0, ..., 1], + [0, 0, ..., 1] ] + */ + for (int i = 0, idx = 0; i < h; ++i) + { + for (int s = 0; s < stride; ++s) + { + for (int j = 0; j < w; ++j, ++idx) + { + preWt[idx] = (i == j) ? 1.0 : 0.0; + } + } + } + preMul.values = preWt; + trtWeights.push_back(preMul); + nvinfer1::IConstantLayer* preM = network->addConstant(preDims, preMul); + assert(preM != nullptr); + std::string preLayerName = "preMul_" + std::to_string(layerIdx); + preM->setName(preLayerName.c_str()); + // add post multiply matrix as a constant + nvinfer1::Dims postDims{3, + {1, h, stride * w}}; + + size = stride * h * w; + nvinfer1::Weights postMul{nvinfer1::DataType::kFLOAT, nullptr, size}; + float* postWt = new float[size]; + /* (h * 2*w) + [ [1, 1, 0, 0, ..., 0, 0], + [0, 0, 1, 1, ..., 0, 0], + ..., + ..., + [0, 0, 0, 0, ..., 1, 1] ] + */ + for (int i = 0, idx = 0; i < h; ++i) + { + for (int j = 0; j < stride * w; ++j, ++idx) + { + postWt[idx] = (j / stride == i) ? 1.0 : 0.0; + } + } + postMul.values = postWt; + trtWeights.push_back(postMul); + nvinfer1::IConstantLayer* post_m = network->addConstant(postDims, postMul); + assert(post_m != nullptr); + std::string postLayerName = "postMul_" + std::to_string(layerIdx); + post_m->setName(postLayerName.c_str()); + // add matrix multiply layers for upsampling + nvinfer1::IMatrixMultiplyLayer* mm1 + = network->addMatrixMultiply(*preM->getOutput(0), nvinfer1::MatrixOperation::kNONE, *input, + nvinfer1::MatrixOperation::kNONE); + assert(mm1 != nullptr); + std::string mm1LayerName = "mm1_" + std::to_string(layerIdx); + mm1->setName(mm1LayerName.c_str()); + nvinfer1::IMatrixMultiplyLayer* mm2 + = network->addMatrixMultiply(*mm1->getOutput(0), nvinfer1::MatrixOperation::kNONE, + *post_m->getOutput(0), nvinfer1::MatrixOperation::kNONE); + assert(mm2 != nullptr); + std::string mm2LayerName = "mm2_" + std::to_string(layerIdx); + mm2->setName(mm2LayerName.c_str()); + return mm2; +} + +void printLayerInfo(std::string layerIndex, std::string layerName, std::string layerInput, + std::string layerOutput, std::string weightPtr) +{ + std::cout << std::setw(6) << std::left << layerIndex << std::setw(15) << std::left << layerName; + std::cout << std::setw(20) << std::left << layerInput << std::setw(20) << std::left + << layerOutput; + std::cout << std::setw(6) << std::left << weightPtr << std::endl; +} \ No newline at end of file diff --git a/deploy/Deepstream/nvdsparsebbox_YoloV6/trt_utils.h b/deploy/Deepstream/nvdsparsebbox_YoloV6/trt_utils.h new file mode 100644 index 00000000..61cd7f87 --- /dev/null +++ b/deploy/Deepstream/nvdsparsebbox_YoloV6/trt_utils.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + + +#ifndef __TRT_UTILS_H__ +#define __TRT_UTILS_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include "NvInfer.h" + +#define UNUSED(expr) (void)(expr) +#define DIVUP(n, d) ((n) + (d)-1) / (d) + +std::string trim(std::string s); +float clamp(const float val, const float minVal, const float maxVal); +bool fileExists(const std::string fileName, bool verbose = true); +std::vector loadWeights(const std::string weightsFilePath, const std::string& networkType); +std::string dimsToString(const nvinfer1::Dims d); +int getNumChannels(nvinfer1::ITensor* t); +uint64_t get3DTensorVolume(nvinfer1::Dims inputDims); + +// Helper functions to create yolo engine +nvinfer1::ILayer* netAddMaxpool(int layerIdx, std::map& block, + nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network); +nvinfer1::ILayer* netAddConvLinear(int layerIdx, std::map& block, + std::vector& weights, + std::vector& trtWeights, int& weightPtr, + int& inputChannels, nvinfer1::ITensor* input, + nvinfer1::INetworkDefinition* network); +nvinfer1::ILayer* netAddConvBNLeaky(int layerIdx, std::map& block, + std::vector& weights, + std::vector& trtWeights, int& weightPtr, + int& inputChannels, nvinfer1::ITensor* input, + nvinfer1::INetworkDefinition* network); +nvinfer1::ILayer* netAddUpsample(int layerIdx, std::map& block, + std::vector& weights, + std::vector& trtWeights, int& inputChannels, + nvinfer1::ITensor* input, nvinfer1::INetworkDefinition* network); +void printLayerInfo(std::string layerIndex, std::string layerName, std::string layerInput, + std::string layerOutput, std::string weightPtr); + +#endif \ No newline at end of file diff --git a/deploy/Deepstream/nvdsparsebbox_YoloV6/trt_utils.o b/deploy/Deepstream/nvdsparsebbox_YoloV6/trt_utils.o new file mode 100644 index 00000000..6c4c811c Binary files /dev/null and b/deploy/Deepstream/nvdsparsebbox_YoloV6/trt_utils.o differ