Merge pull request #20535 from SamFC10:onnx-q
author Jebastin Nadar <njebastin10@gmail.com>
Mon, 4 Oct 2021 18:07:38 +0000 (23:37 +0530)
committer GitHub <noreply@github.com>
Mon, 4 Oct 2021 18:07:38 +0000 (18:07 +0000)
dnn: int8 quantized layers support in onnx importer

* added support for quantized layers in the onnx importer

* added more cases to the eltwise node handling, plus additional checks

* added tests for quantized nodes

* relaxed thresholds for failing tests, addressed review comments

* refactored based on review comments

* added handling for previously unsupported cases and a pre-quantized resnet50 test

* relaxed thresholds to account for the int8 resize layer

modules/dnn/include/opencv2/dnn/all_layers.hpp
modules/dnn/src/dnn.cpp
modules/dnn/src/init.cpp
modules/dnn/src/int8layers/quantization_utils.cpp [moved from modules/dnn/src/int8layers/quantize_dequantize_layer.cpp with 73% similarity]
modules/dnn/src/layers/resize_layer.cpp
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_int8_layers.cpp
modules/dnn/test/test_onnx_importer.cpp
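
With this change a model that already contains quantized operators (QuantizeLinear/DequantizeLinear, QLinearConv, QLinearMatMul, ...) can be read and run directly. A minimal usage sketch, assuming a pre-quantized model such as the resnet50_int8 model used in the new test (file names are illustrative):

    #include <opencv2/dnn.hpp>
    #include <opencv2/imgcodecs.hpp>

    int main()
    {
        cv::dnn::Net net = cv::dnn::readNetFromONNX("resnet50_int8.onnx");  // pre-quantized ONNX model
        cv::Mat img = cv::imread("input.jpg");
        cv::Mat blob = cv::dnn::blobFromImage(img, 1.0, cv::Size(224, 224), cv::Scalar(), true, false);
        net.setInput(blob);
        cv::Mat prob = net.forward();  // int8 inference; output is float if the model ends with DequantizeLinear
        return 0;
    }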

index cfe6595..9fde7ad 100644 (file)
@@ -387,6 +387,13 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<DequantizeLayer> create(const LayerParams &params);
     };
 
+    class CV_EXPORTS RequantizeLayer : public Layer
+    {
+    public:
+        float scale, shift;
+        static Ptr<RequantizeLayer> create(const LayerParams &params);
+    };
+
     class CV_EXPORTS ConcatLayer : public Layer
     {
     public:
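
The new RequantizeLayer follows the same creation pattern as the Quantize/Dequantize layers. A minimal sketch of constructing one directly, with illustrative parameter values (the scale/shift semantics follow the importer code further down: scale = s_in/s_out, shift = zp_out - scale*zp_in):

    cv::dnn::LayerParams rp;
    rp.name = "requant";
    rp.type = "Requantize";
    rp.set("scale", 0.5f);
    rp.set("shift", 3.0f);
    cv::Ptr<cv::dnn::RequantizeLayer> layer = cv::dnn::RequantizeLayer::create(rp);
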
index 4e38b03..24daeb2 100644 (file)
@@ -4055,6 +4055,9 @@ int Net::addLayer(const String &name, const String &type, const int &dtype, Laye
     if (params.get<bool>("has_dynamic_shapes", false))
         impl->hasDynamicShapes = true;
 
+    if (dtype == CV_8S)
+        impl->netWasQuantized = true;
+
     return id;
 }
 
@@ -4389,7 +4392,7 @@ Net Net::quantize(InputArrayOfArrays calibData, int inputsDtype, int outputsDtyp
         // Layers with multiple outputs. Number of outputs is equal to number of inputs
         if (ld.type == "Blank" || ld.type == "Dropout" || ld.type == "Identity" || ld.type == "Silence" ||
             ld.type == "Flatten" || ld.type == "Padding" || ld.type == "Permute" || ld.type == "Reshape" ||
-            ld.type == "ReLU6" || ld.type == "Reorg" || ld.type == "ShuffleChannel" ||
+            ld.type == "ReLU6" || ld.type == "Reorg" || ld.type == "ShuffleChannel" || ld.type == "Resize" ||
            (ld.type == "ReLU" && !ld.params.get<float>("negative_slope", 0.f)) /* ReLU with negative slope 0 */)
         {
             for (int i = 0; i < ld.outputBlobs.size(); i++)
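
Adding Resize to this pass-through list lets Net::quantize() propagate the input's scale and zero point through resize layers, so the quantized resize output reuses its input's quantization parameters. A minimal post-training quantization sketch, assuming an FP32 model and a few representative calibration blobs:

    cv::dnn::Net net = cv::dnn::readNet("model.onnx");
    std::vector<cv::Mat> calibData;  // fill with blobFromImage() outputs for a few representative images
    cv::dnn::Net qnet = net.quantize(calibData, CV_8S, CV_8S);  // int8 inputs and outputs
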
index 9d8a378..123cb17 100644 (file)
@@ -144,6 +144,7 @@ void initializeLayerFactory()
 
     CV_DNN_REGISTER_LAYER_CLASS(Quantize,         QuantizeLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Dequantize,       DequantizeLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(Requantize,       RequantizeLayer);
     CV_DNN_REGISTER_LAYER_CLASS(ConvolutionInt8,  ConvolutionLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(InnerProductInt8, InnerProductLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(PoolingInt8,      PoolingLayerInt8);
@@ -173,6 +174,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(SilenceInt8,      BlankLayer);
     CV_DNN_REGISTER_LAYER_CLASS(ConstInt8,        ConstLayer);
     CV_DNN_REGISTER_LAYER_CLASS(ReshapeInt8,      ReshapeLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(ResizeInt8,       ResizeLayer);
     CV_DNN_REGISTER_LAYER_CLASS(SplitInt8,        SplitLayer);
     CV_DNN_REGISTER_LAYER_CLASS(SliceInt8,        SliceLayer);
     CV_DNN_REGISTER_LAYER_CLASS(CropInt8,         CropLayer);
@@ -10,6 +10,7 @@ namespace cv
 namespace dnn
 {
 
+// Quantize FP32/FP16 Inputs to INT8
 class QuantizeLayerImpl CV_FINAL : public QuantizeLayer
 {
 public:
@@ -77,6 +78,7 @@ public:
     }
 };
 
+// Dequantize INT8 Inputs to FP32/FP16
 class DequantizeLayerImpl CV_FINAL : public DequantizeLayer
 {
 public:
@@ -143,6 +145,52 @@ public:
     }
 };
 
+// Rescale/Requantize INT8 Inputs from (scale1, zeropoint1) to (scale2, zeropoint2)
+class RequantizeLayerImpl CV_FINAL : public RequantizeLayer
+{
+public:
+    RequantizeLayerImpl(const LayerParams& params)
+    {
+        scale = params.get<float>("scale", 1.f);
+        shift = params.get<float>("shift", 0.f);
+        setParamsFrom(params);
+    }
+
+    virtual bool supportBackend(int backendId) CV_OVERRIDE
+    {
+        return backendId == DNN_BACKEND_OPENCV;
+    }
+
+    bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                         const int requiredOutputs,
+                         std::vector<MatShape> &outputs,
+                         std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert(inputs.size() == 1);
+        Layer::getMemoryShapes(inputs, requiredOutputs, outputs, internals);
+        return false;
+    }
+
+    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
+    {
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        inputs[0].convertTo(outputs[0], CV_8S, scale, shift);
+    }
+};
+
 Ptr<QuantizeLayer> QuantizeLayer::create(const LayerParams& params)
 {
     return Ptr<QuantizeLayer>(new QuantizeLayerImpl(params));
@@ -153,5 +201,10 @@ Ptr<DequantizeLayer> DequantizeLayer::create(const LayerParams& params)
     return Ptr<DequantizeLayer>(new DequantizeLayerImpl(params));
 }
 
+Ptr<RequantizeLayer> RequantizeLayer::create(const LayerParams& params)
+{
+    return Ptr<RequantizeLayer>(new RequantizeLayerImpl(params));
+}
+
 }
 }
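
The scale and shift consumed by RequantizeLayer come straight from the two (scale, zero point) pairs: since x = s1*(q1 - zp1) = s2*(q2 - zp2), the int8-to-int8 mapping is a single affine transform. A sketch of the derivation used by the importer (see parseQEltwise/parseQConcat below):

    // q2 = zp2 + (s1 / s2) * (q1 - zp1)  =>  q2 = scale * q1 + shift
    float s1 = 0.05f, s2 = 0.1f;   // illustrative input/output scales
    int   zp1 = -3,   zp2 = 5;     // illustrative zero points
    float scale = s1 / s2;
    float shift = zp2 - scale * zp1;
    // forward() then applies: inputs[0].convertTo(outputs[0], CV_8S, scale, shift);
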
index f24c823..47ff071 100644 (file)
@@ -49,6 +49,8 @@ public:
 
         alignCorners = params.get<bool>("align_corners", false);
         halfPixelCenters = params.get<bool>("half_pixel_centers", false);
+        if (interpolation == "opencv_linear")
+            halfPixelCenters = true;
     }
 
     bool getMemoryShapes(const std::vector<MatShape> &inputs,
@@ -131,8 +133,11 @@ public:
 
         Mat& inp = inputs[0];
         Mat& out = outputs[0];
-        if ((interpolation == "nearest" && !alignCorners && !halfPixelCenters) || interpolation == "opencv_linear" || (interpolation == "bilinear" && halfPixelCenters))
+        int depth = inp.depth();
+        if ((interpolation == "nearest" && !alignCorners && !halfPixelCenters) || (interpolation == "opencv_linear" && depth != CV_8S) ||
+            (interpolation == "bilinear" && halfPixelCenters && depth != CV_8S))
         {
+            // cv::resize (INTER_LINEAR) does not support INT8 inputs; INT8 resizing is handled by the manual loops below
             InterpolationFlags mode = interpolation == "nearest" ? INTER_NEAREST : INTER_LINEAR;
             for (size_t n = 0; n < inputs[0].size[0]; ++n)
             {
@@ -164,34 +169,66 @@ public:
                 widthOffset = 0.5f * scaleWidth;
             }
 
-            for (int y = 0; y < outHeight; ++y)
+            if (depth == CV_8S)
             {
-                float input_y = y * scaleHeight + heightOffset;
-                int y0 = halfPixelCenters ? std::floor(input_y) : lroundf(input_y);
-                y0 = std::min(y0, inpHeight - 1);
+                for (int y = 0; y < outHeight; ++y)
+                {
+                    float input_y = y * scaleHeight + heightOffset;
+                    int y0 = halfPixelCenters ? std::floor(input_y) : lroundf(input_y);
+                    y0 = std::min(y0, inpHeight - 1);
+
+                    const int8_t* inpData_row = inpPlanes.ptr<int8_t>(y0);
+
+                    for (int x = 0; x < outWidth; ++x)
+                    {
+                        float input_x = x * scaleWidth + widthOffset;
+                        int x0 = halfPixelCenters ? std::floor(input_x) : lroundf(input_x);
+                        x0 = std::min(x0, inpWidth - 1);
+
+                        int8_t* outData = outPlanes.ptr<int8_t>(y, x);
+                        const int8_t* inpData_row_c = inpData_row;
 
-                const float* inpData_row = inpPlanes.ptr<float>(y0);
+                        for (int c = 0; c < numPlanes; ++c)
+                        {
+                            *outData = inpData_row_c[x0];
 
-                for (int x = 0; x < outWidth; ++x)
+                            inpData_row_c += inpSpatialSize;
+                            outData += outSpatialSize;
+                        }
+                    }
+                }
+            }
+            else
+            {
+                for (int y = 0; y < outHeight; ++y)
                 {
-                    float input_x = x * scaleWidth + widthOffset;
-                    int x0 = halfPixelCenters ? std::floor(input_x) : lroundf(input_x);
-                    x0 = std::min(x0, inpWidth - 1);
+                    float input_y = y * scaleHeight + heightOffset;
+                    int y0 = halfPixelCenters ? std::floor(input_y) : lroundf(input_y);
+                    y0 = std::min(y0, inpHeight - 1);
 
-                    float* outData = outPlanes.ptr<float>(y, x);
-                    const float* inpData_row_c = inpData_row;
+                    const float* inpData_row = inpPlanes.ptr<float>(y0);
 
-                    for (int c = 0; c < numPlanes; ++c)
+                    for (int x = 0; x < outWidth; ++x)
                     {
-                        *outData = inpData_row_c[x0];
+                        float input_x = x * scaleWidth + widthOffset;
+                        int x0 = halfPixelCenters ? std::floor(input_x) : lroundf(input_x);
+                        x0 = std::min(x0, inpWidth - 1);
+
+                        float* outData = outPlanes.ptr<float>(y, x);
+                        const float* inpData_row_c = inpData_row;
 
-                        inpData_row_c += inpSpatialSize;
-                        outData += outSpatialSize;
+                        for (int c = 0; c < numPlanes; ++c)
+                        {
+                            *outData = inpData_row_c[x0];
+
+                            inpData_row_c += inpSpatialSize;
+                            outData += outSpatialSize;
+                        }
                     }
                 }
             }
         }
-        else if (interpolation == "bilinear")
+        else if (interpolation == "bilinear" || interpolation == "opencv_linear")
         {
             const int inpHeight = inp.size[2];
             const int inpWidth = inp.size[3];
@@ -202,31 +239,65 @@ public:
 
             Mat inpPlanes = inp.reshape(1, numPlanes * inpHeight);
             Mat outPlanes = out.reshape(1, numPlanes * outHeight);
-            for (int y = 0; y < outHeight; ++y)
+            if (depth == CV_8S)
+            {
+                for (int y = 0; y < outHeight; ++y)
+                {
+                    float input_y = halfPixelCenters ? std::max((y + 0.5f) * scaleHeight - 0.5f, 0.0f) : y * scaleHeight;
+                    int y0 = static_cast<int>(input_y);
+                    const int8_t* inpData_row0 = inpPlanes.ptr<int8_t>(y0);
+                    const int8_t* inpData_row1 = inpPlanes.ptr<int8_t>(std::min(y0 + 1, inpHeight - 1));
+                    for (int x = 0; x < outWidth; ++x)
+                    {
+                        float input_x = halfPixelCenters ? std::max((x + 0.5f) * scaleWidth - 0.5f, 0.0f) : x * scaleWidth;
+                        int x0 = static_cast<int>(input_x);
+                        int x1 = std::min(x0 + 1, inpWidth - 1);
+
+                        int8_t* outData = outPlanes.ptr<int8_t>(y, x);
+                        const int8_t* inpData_row0_c = inpData_row0;
+                        const int8_t* inpData_row1_c = inpData_row1;
+                        for (int c = 0; c < numPlanes; ++c)
+                        {
+                            *outData = static_cast<int8_t>(inpData_row0_c[x0] +
+                                (input_y - y0) * (inpData_row1_c[x0] - inpData_row0_c[x0]) +
+                                (input_x - x0) * (inpData_row0_c[x1] - inpData_row0_c[x0] +
+                                (input_y - y0) * (inpData_row1_c[x1] - inpData_row0_c[x1] - inpData_row1_c[x0] + inpData_row0_c[x0])));
+
+                            inpData_row0_c += inpSpatialSize;
+                            inpData_row1_c += inpSpatialSize;
+                            outData += outSpatialSize;
+                        }
+                    }
+                }
+            }
+            else
             {
-                float input_y = y * scaleHeight;
-                int y0 = static_cast<int>(input_y);
-                const float* inpData_row0 = inpPlanes.ptr<float>(y0);
-                const float* inpData_row1 = inpPlanes.ptr<float>(std::min(y0 + 1, inpHeight - 1));
-                for (int x = 0; x < outWidth; ++x)
+                for (int y = 0; y < outHeight; ++y)
                 {
-                    float input_x = x * scaleWidth;
-                    int x0 = static_cast<int>(input_x);
-                    int x1 = std::min(x0 + 1, inpWidth - 1);
-
-                    float* outData = outPlanes.ptr<float>(y, x);
-                    const float* inpData_row0_c = inpData_row0;
-                    const float* inpData_row1_c = inpData_row1;
-                    for (int c = 0; c < numPlanes; ++c)
+                    float input_y = y * scaleHeight;
+                    int y0 = static_cast<int>(input_y);
+                    const float* inpData_row0 = inpPlanes.ptr<float>(y0);
+                    const float* inpData_row1 = inpPlanes.ptr<float>(std::min(y0 + 1, inpHeight - 1));
+                    for (int x = 0; x < outWidth; ++x)
                     {
-                        *outData = inpData_row0_c[x0] +
-                            (input_y - y0) * (inpData_row1_c[x0] - inpData_row0_c[x0]) +
-                            (input_x - x0) * (inpData_row0_c[x1] - inpData_row0_c[x0] +
-                            (input_y - y0) * (inpData_row1_c[x1] - inpData_row0_c[x1] - inpData_row1_c[x0] + inpData_row0_c[x0]));
-
-                        inpData_row0_c += inpSpatialSize;
-                        inpData_row1_c += inpSpatialSize;
-                        outData += outSpatialSize;
+                        float input_x = x * scaleWidth;
+                        int x0 = static_cast<int>(input_x);
+                        int x1 = std::min(x0 + 1, inpWidth - 1);
+
+                        float* outData = outPlanes.ptr<float>(y, x);
+                        const float* inpData_row0_c = inpData_row0;
+                        const float* inpData_row1_c = inpData_row1;
+                        for (int c = 0; c < numPlanes; ++c)
+                        {
+                            *outData = inpData_row0_c[x0] +
+                                (input_y - y0) * (inpData_row1_c[x0] - inpData_row0_c[x0]) +
+                                (input_x - x0) * (inpData_row0_c[x1] - inpData_row0_c[x0] +
+                                (input_y - y0) * (inpData_row1_c[x1] - inpData_row0_c[x1] - inpData_row1_c[x0] + inpData_row0_c[x0]));
+
+                            inpData_row0_c += inpSpatialSize;
+                            inpData_row1_c += inpSpatialSize;
+                            outData += outSpatialSize;
+                        }
                     }
                 }
             }
@@ -363,6 +434,11 @@ public:
     }
 #endif
 
+    virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
+                             const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
+    {
+        return true;
+    }
 
 protected:
     int outWidth, outHeight;
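
For readability, the nested bilinear expression used in both branches above is the standard separable interpolation; an equivalent form, sketched with the patch's variable names (same arithmetic):

    // dx, dy are the fractional offsets inside the 2x2 neighbourhood
    float dy = input_y - y0, dx = input_x - x0;
    float top    = inpData_row0_c[x0] + dx * (inpData_row0_c[x1] - inpData_row0_c[x0]);
    float bottom = inpData_row1_c[x0] + dx * (inpData_row1_c[x1] - inpData_row1_c[x0]);
    *outData = static_cast<int8_t>(top + dy * (bottom - top));  // int8 branch narrows by truncation
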
index 7caa5ed..c4ebcd1 100644 (file)
@@ -594,7 +594,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
 Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
 {
     if (tensor_proto.raw_data().empty() && tensor_proto.float_data().empty() &&
-        tensor_proto.double_data().empty() && tensor_proto.int64_data().empty())
+        tensor_proto.double_data().empty() && tensor_proto.int64_data().empty() &&
+        tensor_proto.int32_data().empty())
         return Mat();
 
     opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type();
@@ -663,6 +664,24 @@ Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
             convertInt64ToInt32(src, dst, blob.total());
         }
     }
+    else if (datatype == opencv_onnx::TensorProto_DataType_INT8 ||
+             datatype == opencv_onnx::TensorProto_DataType_UINT8)
+    {
+        // TODO: Add support for uint8 weights and activations. For now, convert uint8 tensors to int8.
+        int offset = datatype == opencv_onnx::TensorProto_DataType_INT8 ? 0 : -128;
+        int depth = datatype == opencv_onnx::TensorProto_DataType_INT8 ? CV_8S : CV_8U;
+
+        if (!tensor_proto.int32_data().empty())
+        {
+            const ::google::protobuf::RepeatedField<int32_t> field = tensor_proto.int32_data();
+            Mat(sizes, CV_32SC1, (void*)field.data()).convertTo(blob, CV_8S, 1.0, offset);
+        }
+        else
+        {
+            char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
+            Mat(sizes, depth, val).convertTo(blob, CV_8S, 1.0, offset);
+        }
+    }
     else
     {
         std::string errorMsg = "Unsupported data type: " +
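
The uint8 branch above shifts values into the int8 range with a -128 offset; because uint8 zero-point tensors pass through the same conversion, the represented real values are unchanged. In sketch form:

    // x = scale * (q_u8 - zp_u8) = scale * ((q_u8 - 128) - (zp_u8 - 128)) = scale * (q_s8 - zp_s8)
    uint8_t q_u8 = 200, zp_u8 = 128;                         // illustrative quantized value and zero point
    int8_t  q_s8  = static_cast<int8_t>(int(q_u8)  - 128);   // what convertTo(..., CV_8S, 1.0, -128) produces
    int8_t  zp_s8 = static_cast<int8_t>(int(zp_u8) - 128);
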
index 9ef947c..7c230f2 100644 (file)
@@ -63,6 +63,8 @@ class ONNXImporter
     void addConstant(const std::string& name, const Mat& blob);
     void addLayer(LayerParams& layerParams,
                   const opencv_onnx::NodeProto& node_proto);
+    void handleQuantizedNode(LayerParams& layerParams,
+                             const opencv_onnx::NodeProto& node_proto);
 
     void expandMid(const std::string& prefix, opencv_onnx::NodeProto& node_proto,
                    const std::string& input, size_t n);
@@ -142,6 +144,14 @@ private:
     void parseSoftMax              (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseDetectionOutput      (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseCumSum               (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseQuantDequant         (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseQConv                (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseQMatMul              (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseQEltwise             (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseQLeakyRelu           (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseQSigmoid             (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseQAvgPool             (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseQConcat              (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
 
     void parseCustomLayer          (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
 };
@@ -242,7 +252,7 @@ void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
     CV_Assert((bool)layer);
 
     std::vector<MatShape> inpShapes(inputs.size());
-    int ddepth = CV_32F;
+    int ddepth = params.get<int>("depth", CV_32F);
     for (size_t i = 0; i < inputs.size(); ++i)
     {
         inpShapes[i] = shape(inputs[i]);
@@ -458,7 +468,8 @@ Mat ONNXImporter::getBlob(const std::string& input_name)
 void ONNXImporter::addLayer(LayerParams& layerParams,
                             const opencv_onnx::NodeProto& node_proto)
 {
-    int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
+    int depth = layerParams.get<int>("depth", CV_32F);
+    int id = dstNet.addLayer(layerParams.name, layerParams.type, depth, layerParams);
     for (int i = 0; i < node_proto.output_size(); ++i)
     {
         layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
@@ -525,6 +536,51 @@ void ONNXImporter::addConstant(const std::string& name, const Mat& blob)
     outShapes.insert(std::make_pair(name, shape(blob)));
 }
 
+void ONNXImporter::handleQuantizedNode(LayerParams& layerParams,
+                                       const opencv_onnx::NodeProto& node_proto)
+{
+    // Quantized nodes have output names ending with 'quantized'
+    std::string outName = node_proto.output(0);
+    int len = outName.length();
+    if (len <= 9)
+        return;
+
+    if (outName.substr(len - 9) == "quantized")
+    {
+        outName = outName.substr(0, len - 9);
+        Mat scale, zeropoint;
+
+        if (constBlobs.find(outName + "scale") != constBlobs.end() &&
+            constBlobs.find(outName + "zero_point") != constBlobs.end())
+        {
+            scale = getBlob(outName + "scale");
+            zeropoint = getBlob(outName + "zero_point");
+        }
+        else
+        {
+            std::string inpName = node_proto.input(0);
+            inpName = inpName.substr(0, inpName.length() - 9);
+            scale = getBlob(inpName + "scale");
+            zeropoint = getBlob(inpName + "zero_point");
+
+            for (int i = 0; i < node_proto.output_size(); i++)
+            {
+                std::string out = node_proto.output(i);
+                out = out.substr(0, out.length() - 9);
+                addConstant(out + "scale", scale);
+                addConstant(out + "zero_point", zeropoint);
+            }
+        }
+
+        if (scale.total() != 1 || zeropoint.total() != 1)
+            CV_Error(Error::StsNotImplemented, "Per-channel scales/zeropoints are not supported");
+
+        layerParams.set("depth", CV_8S);
+        layerParams.set("scales", DictValue::arrayReal(scale.ptr<float>(), 1));
+        layerParams.set("zeropoints", DictValue::arrayInt(zeropoint.ptr<int8_t>(), 1));
+    }
+}
+
 void ONNXImporter::populateNet()
 {
     CV_Assert(model_proto.has_graph());
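
handleQuantizedNode keys off a naming pattern: an output such as "conv1_output_quantized" implies per-tensor initializers "conv1_output_scale" and "conv1_output_zero_point" in the same graph. A sketch of the string handling with an illustrative name:

    std::string outName = "conv1_output_quantized";               // illustrative node output
    std::string base = outName.substr(0, outName.length() - 9);   // strip trailing "quantized"
    std::string scaleName = base + "scale";                       // -> "conv1_output_scale"
    std::string zpName    = base + "zero_point";                  // -> "conv1_output_zero_point"
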
@@ -623,6 +679,8 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
         layerParams.type = layer_type;
         layerParams.set("has_dynamic_shapes", hasDynamicShapes);
 
+        handleQuantizedNode(layerParams, node_proto);
+
         DispatchMap::const_iterator iter = dispatch.find(layer_type);
         if (iter != dispatch.end())
         {
@@ -684,7 +742,8 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
 
 void ONNXImporter::parseMaxPool(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
 {
-    layerParams.type = "Pooling";
+    int depth = layerParams.get<int>("depth", CV_32F);
+    layerParams.type = (depth == CV_8S) ? "PoolingInt8" : "Pooling";
     layerParams.set("pool", "MAX");
     layerParams.set("ceil_mode", layerParams.has("pad_mode"));
     addLayer(layerParams, node_proto);
@@ -988,7 +1047,8 @@ void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeP
     {
         layerParams.set("num_split", node_proto.output_size());
     }
-    layerParams.type = "Slice";
+    int depth = layerParams.get<int>("depth", CV_32F);
+    layerParams.type = (depth == CV_8S) ? "SliceInt8" : "Slice";
     addLayer(layerParams, node_proto);
 }
 
@@ -1743,7 +1803,8 @@ void ONNXImporter::parseConvTranspose(LayerParams& layerParams, const opencv_onn
 
 void ONNXImporter::parseTranspose(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
 {
-    layerParams.type = "Permute";
+    int depth = layerParams.get<int>("depth", CV_32F);
+    layerParams.type = (depth == CV_8S) ? "PermuteInt8" : "Permute";
     replaceLayerParam(layerParams, "perm", "order");
 
     CV_Assert(node_proto.input_size() == 1);
@@ -1807,6 +1868,8 @@ void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::Nod
         addConstant(layerParams.name, out);
         return;
     }
+    int depth = layerParams.get<int>("depth", CV_32F);
+    layerParams.type += (depth == CV_8S) ? "Int8" : "";
     addLayer(layerParams, node_proto);
 }
 
@@ -1862,12 +1925,14 @@ void ONNXImporter::parseUnsqueeze(LayerParams& layerParams, const opencv_onnx::N
     if (axes.size() != 1)
         CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");
 
+    int depth = layerParams.get<int>("depth", CV_32F);
+
     MatShape inpShape = outShapes[node_proto.input(0)];
     int axis = axes.getIntValue(0);
     CV_Assert(0 <= axis && axis <= inpShape.size());
     std::vector<int> outShape = inpShape;
     outShape.insert(outShape.begin() + axis, 1);
-    layerParams.type = "Reshape";
+    layerParams.type = (depth == CV_8S) ? "ReshapeInt8" : "Reshape";
     layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
     if (hasDynamicShapes)
     {
@@ -2004,6 +2069,8 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node
 void ONNXImporter::parseReshape(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
 {
     CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape"));
+    int depth = layerParams.get<int>("depth", CV_32F);
+    layerParams.type += (depth == CV_8S) ? "Int8" : "";
 
     if (node_proto.input_size() == 2) {
         Mat blob = getBlob(node_proto, 1);
@@ -2038,7 +2105,8 @@ void ONNXImporter::parseReshape(LayerParams& layerParams, const opencv_onnx::Nod
 
 void ONNXImporter::parsePad(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
 {
-    layerParams.type = "Padding";
+    int depth = layerParams.get<int>("depth", CV_32F);
+    layerParams.type = (depth == CV_8S) ? "PaddingInt8" : "Padding";
     replaceLayerParam(layerParams, "mode", "type");
     if (node_proto.input_size() == 3 || node_proto.input_size() == 2)
     {
@@ -2051,7 +2119,8 @@ void ONNXImporter::parsePad(LayerParams& layerParams, const opencv_onnx::NodePro
         if (node_proto.input_size() == 3)
         {
             Mat value = getBlob(node_proto, 2);
-            layerParams.set("value", value.ptr<float>()[0]);
+            float padValue = (depth == CV_8S) ? (float)value.ptr<int8_t>()[0] : value.ptr<float>()[0];
+            layerParams.set("value", padValue);
         }
     }
     addLayer(layerParams, node_proto);
@@ -2270,6 +2339,9 @@ void ONNXImporter::parseResize(LayerParams& layerParams, const opencv_onnx::Node
     for (int i = 1; i < node_proto.input_size(); i++)
         CV_Assert(layer_id.find(node_proto.input(i)) == layer_id.end());
 
+    int depth = layerParams.get<int>("depth", CV_32F);
+    layerParams.type += (depth == CV_8S) ? "Int8" : "";
+
     if (layerParams.has("coordinate_transformation_mode"))
     {
         String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
@@ -2419,6 +2491,396 @@ void ONNXImporter::parseCustomLayer(LayerParams& layerParams, const opencv_onnx:
     addLayer(layerParams, node_proto);
 }
 
+void ONNXImporter::parseQuantDequant(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    CV_Assert(node_proto.input_size() == 3);
+    layerParams.type = (node_proto.op_type() == "QuantizeLinear") ? "Quantize" : "Dequantize";
+
+    if (node_proto.op_type() == "DequantizeLinear")
+    {
+        Mat scale = getBlob(node_proto, 1);
+        Mat zeropoint = getBlob(node_proto, 2);
+
+        layerParams.set("scales", DictValue::arrayReal(scale.ptr<float>(), 1));
+        layerParams.set("zeropoints", DictValue::arrayInt(zeropoint.ptr<int8_t>(), 1));
+    }
+    addLayer(layerParams, node_proto);
+}
+
+void ONNXImporter::parseQConv(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    int ninputs = node_proto.input_size();
+    CV_Assert(ninputs == 8 || ninputs == 9);
+
+    Mat inp_sc = getBlob(node_proto, 1);
+    Mat inp_zp = getBlob(node_proto, 2);
+
+    Mat weights = getBlob(node_proto, 3);
+    int outCn = weights.size[0];
+    Mat w_scale = getBlob(node_proto, 4);
+    CV_Assert(w_scale.total() == 1 || w_scale.total() == outCn);
+    Mat wt_sc = (w_scale.total() == outCn) ? w_scale : Mat(1, outCn, CV_32F, Scalar(w_scale.at<float>(0)));
+
+    Mat out_sc = getBlob(node_proto, 6);
+    Mat bias = (ninputs == 9) ? getBlob(node_proto, 8) : Mat::zeros(1, outCn, CV_32S);
+
+    Mat weights_2d = weights.reshape(1, outCn);
+    Mat biasFused(1, outCn, CV_32S);
+    Mat outputMultiplier(1, outCn, CV_32F);
+    for (int i = 0; i < outCn; i++)
+    {
+        biasFused.at<int>(i) = bias.at<int>(i) - inp_zp.at<int8_t>(0)*(cv::sum(weights_2d.row(i))[0]);
+        outputMultiplier.at<float>(i) = (inp_sc.at<float>(0) * wt_sc.at<float>(i)) / out_sc.at<float>(0);
+    }
+
+    layerParams.type = "ConvolutionInt8";
+    layerParams.set("num_output", outCn);
+    layerParams.set("input_zeropoint", inp_zp.at<int8_t>(0));
+    layerParams.blobs.push_back(weights);
+    layerParams.blobs.push_back(biasFused);
+    layerParams.blobs.push_back(outputMultiplier);
+    addLayer(layerParams, node_proto);
+}
+
+void ONNXImporter::parseQMatMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    int ninputs = node_proto.input_size();
+    CV_Assert(ninputs == 8);
+
+    if (constBlobs.find(node_proto.input(3)) == constBlobs.end())
+        CV_Error(Error::StsNotImplemented, "Variable weights are not supported");
+
+    int firstInpDims = outShapes[node_proto.input(0)].size();
+
+    Mat inp_sc = getBlob(node_proto, 1);
+    Mat inp_zp = getBlob(node_proto, 2);
+
+    Mat weights = getBlob(node_proto, 3).t();
+    int outCn = weights.size[0];
+    int secondInpDims = weights.dims;
+
+    Mat w_scale = getBlob(node_proto, 4);
+    CV_Assert(w_scale.total() == 1 || w_scale.total() == outCn);
+    Mat wt_sc = (w_scale.total() == outCn) ? w_scale : Mat(1, outCn, CV_32F, Scalar(w_scale.at<float>(0)));
+    Mat out_sc = getBlob(node_proto, 6);
+
+    Mat bias(1, outCn, CV_32S);
+    Mat outputMultiplier(1, outCn, CV_32F);
+    for (int i = 0; i < outCn; i++)
+    {
+        bias.at<int>(i) = -inp_zp.at<int8_t>(0)*(cv::sum(weights.row(i))[0]);
+        outputMultiplier.at<float>(i) = (inp_sc.at<float>(0) * wt_sc.at<float>(i)) / out_sc.at<float>(0);
+    }
+
+    layerParams.type = "InnerProductInt8";
+    layerParams.set("num_output", outCn);
+    layerParams.set("axis", firstInpDims - secondInpDims + 1);
+    layerParams.blobs.push_back(weights);
+    layerParams.blobs.push_back(bias);
+    layerParams.blobs.push_back(outputMultiplier);
+    addLayer(layerParams, node_proto);
+}
+
+void ONNXImporter::parseQEltwise(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
+{
+    opencv_onnx::NodeProto node_proto = node_proto_;
+    CV_Assert(node_proto.input_size() == 8);
+    std::string op = (node_proto.op_type() == "QLinearAdd") ? "sum" : "prod";
+    int constId = -1;
+    for (int i = 0; i < 4; i += 3)
+    {
+        if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
+            constId = i;
+    }
+
+    Mat inp_0_sc = getBlob(node_proto, 1);
+    Mat inp_0_zp = getBlob(node_proto, 2);
+
+    Mat inp_1_sc = getBlob(node_proto, 4);
+    Mat inp_1_zp = getBlob(node_proto, 5);
+
+    // Set 2nd input as the const input
+    if (constId == 0)
+    {
+        cv::swap(inp_0_sc, inp_1_sc);
+        cv::swap(inp_0_zp, inp_1_zp);
+    }
+
+    float out_sc = getBlob(node_proto, 6).at<float>(0);
+    int8_t out_zp = getBlob(node_proto, 7).at<int8_t>(0);
+
+    std::vector<float> inp_scales = {inp_0_sc.at<float>(0), inp_1_sc.at<float>(0)};
+    std::vector<int8_t> inp_zps = {inp_0_zp.at<int8_t>(0), inp_1_zp.at<int8_t>(0)};
+
+    std::vector<float> coeffs;
+    float offset;
+    if (op == "sum")
+    {
+        coeffs = {inp_scales[0]/out_sc, inp_scales[1]/out_sc};
+        offset = out_zp - coeffs[0]*inp_zps[0] - coeffs[1]*inp_zps[1];
+    }
+    else
+    {
+        coeffs = {inp_scales[0]/out_sc, inp_scales[1]};
+        offset = out_zp;
+    }
+
+    if (constId != -1)
+    {
+        Mat blob = getBlob(node_proto, constId);
+        if (blob.total() == 1)
+        {
+            float val = inp_scales[1] * (blob.at<int8_t>(0) - inp_zps[1]);
+            float scale = inp_scales[0] / out_sc;
+            if (op == "prod")
+                scale *= val;
+
+            float shift = out_zp - scale*inp_zps[0];
+            if (op == "sum")
+                shift += (val/out_sc);
+
+            LayerParams rescaleParams;
+            rescaleParams.name = layerParams.name;
+            rescaleParams.type = "Requantize";
+            rescaleParams.set("depth", CV_8S);
+            rescaleParams.set("scale", scale);
+            rescaleParams.set("shift", shift);
+            addLayer(rescaleParams, node_proto);
+            return;
+        }
+        else
+        {
+            MatShape inpShape = outShapes[node_proto.input(3 - constId)];
+            if (blob.dims == 2)
+                blob = blob.t();
+
+            if (shape(blob) == inpShape)
+            {
+                LayerParams constParams;
+                constParams.name = layerParams.name + "/const";
+                constParams.type = "ConstInt8";
+                constParams.set("depth", CV_8S);
+                constParams.set("scales", DictValue::arrayReal(inp_1_sc.ptr<float>(), 1));
+                constParams.set("zeropoints", DictValue::arrayInt(inp_1_zp.ptr<int8_t>(), 1));
+                constParams.blobs.push_back(blob);
+
+                int id = dstNet.addLayer(constParams.name, constParams.type, CV_8S, constParams);
+                layer_id.insert(std::make_pair(constParams.name, LayerInfo(id, 0)));
+                outShapes[constParams.name] = shape(blob);
+                node_proto.set_input(constId, constParams.name);
+
+                layerParams.type = "EltwiseInt8";
+                layerParams.set("operation", op);
+                layerParams.set("coeff", DictValue::arrayReal(coeffs.data(), coeffs.size()));
+                layerParams.set("offset", offset);
+            }
+            else
+            {
+                layerParams.type = "ScaleInt8";
+                layerParams.set("bias_term", op == "sum");
+                int axis = 1;
+                for (int i = 0; i < graph_proto.initializer_size(); i++)
+                {
+                    opencv_onnx::TensorProto tensor_proto = graph_proto.initializer(i);
+                    if (tensor_proto.name() == node_proto.input(constId))
+                    {
+                        axis = inpShape.size() - tensor_proto.dims_size();
+                        break;
+                    }
+                }
+                layerParams.set("axis", axis);
+                blob = blob.reshape(1, 1);
+                Mat blob_dequantized;
+                blob.convertTo(blob_dequantized, CV_32F, inp_scales[1], -(inp_scales[1] * inp_zps[1]));
+                layerParams.blobs.push_back(blob_dequantized);
+                layerParams.set("input_scales", DictValue::arrayReal(inp_scales.data(), inp_scales.size()));
+            }
+        }
+    }
+    else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(3)])
+    {
+        layerParams.type = "EltwiseInt8";
+        layerParams.set("operation", op);
+        layerParams.set("coeff", DictValue::arrayReal(coeffs.data(), coeffs.size()));
+        layerParams.set("offset", offset);
+    }
+    else
+    {
+        layerParams.type = "ScaleInt8";
+        layerParams.set("bias_term", op == "sum");
+        layerParams.set("input_scales", DictValue::arrayReal(inp_scales.data(), inp_scales.size()));
+    }
+
+    layerParams.set("input_zeropoints", DictValue::arrayInt(inp_zps.data(), inp_zps.size()));
+    addLayer(layerParams, node_proto);
+}
+
+void ONNXImporter::parseQLeakyRelu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    CV_Assert(node_proto.input_size() == 5);
+
+    float slope = layerParams.get<float>("alpha");
+    float inp_sc = getBlob(node_proto, 1).at<float>(0);
+    int8_t inp_zp = getBlob(node_proto, 2).at<int8_t>(0);
+    float out_sc = getBlob(node_proto, 3).at<float>(0);
+    int8_t out_zp = getBlob(node_proto, 4).at<int8_t>(0);
+
+    Mat lookUpTable(1, 256, CV_8S);
+    int8_t* table = lookUpTable.ptr<int8_t>();
+    for (int i = -128; i < 128; i++)
+    {
+        float x = inp_sc*(i - inp_zp);
+        float y = x >= 0.f ? x : slope*x;
+        int quantized = out_zp + cvRound(y/out_sc);
+        table[i+128] = saturate_cast<int8_t>(quantized);
+    }
+
+    layerParams.type = "ReLUInt8";
+    layerParams.blobs.push_back(lookUpTable);
+    addLayer(layerParams, node_proto);
+}
+
+void ONNXImporter::parseQSigmoid(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    CV_Assert(node_proto.input_size() == 5);
+
+    float inp_sc = getBlob(node_proto, 1).at<float>(0);
+    int8_t inp_zp = getBlob(node_proto, 2).at<int8_t>(0);
+    float out_sc = getBlob(node_proto, 3).at<float>(0);
+    int8_t out_zp = getBlob(node_proto, 4).at<int8_t>(0);
+
+    Mat lookUpTable(1, 256, CV_8S);
+    int8_t* table = lookUpTable.ptr<int8_t>();
+    for (int i = -128; i < 128; i++)
+    {
+        float x = inp_sc*(i - inp_zp);
+        float y = 1.f/(1.f + std::exp(-x));
+        int quantized = out_zp + cvRound(y/out_sc);
+        table[i+128] = saturate_cast<int8_t>(quantized);
+    }
+
+    layerParams.type = "SigmoidInt8";
+    layerParams.blobs.push_back(lookUpTable);
+    addLayer(layerParams, node_proto);
+}
+
+void ONNXImporter::parseQAvgPool(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    CV_Assert(node_proto.input_size() == 5);
+    float inp_sc = getBlob(node_proto, 1).at<float>(0);
+    int8_t inp_zp = getBlob(node_proto, 2).at<int8_t>(0);
+    float out_sc = getBlob(node_proto, 3).at<float>(0);
+
+    layerParams.type = "PoolingInt8";
+    layerParams.set("pool", "ave");
+    layerParams.set("global_pooling", node_proto.op_type() == "QLinearGlobalAveragePool");
+    layerParams.set("multiplier", inp_sc/out_sc);
+    layerParams.set("input_zeropoint", inp_zp);
+    addLayer(layerParams, node_proto);
+}
+
+void ONNXImporter::parseQConcat(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
+{
+    opencv_onnx::NodeProto node_proto = node_proto_;
+    layerParams.type = "ConcatInt8";
+    int num_inputs = node_proto.input_size();
+
+    float out_scale = getBlob(node_proto, 0).at<float>(0);
+    int out_zp = getBlob(node_proto, 1).at<int8_t>(0);
+
+    for (int i = 2; i < num_inputs; i += 3)
+    {
+        float inp_scale = getBlob(node_proto, i + 1).at<float>(0);
+        int inp_zp = getBlob(node_proto, i + 2).at<int8_t>(0);
+
+        if (inp_scale != out_scale || inp_zp != out_zp)
+        {
+            float scale = inp_scale/out_scale;
+            float shift = out_zp - scale*inp_zp;
+
+            if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
+            {
+                Mat blob = getBlob(node_proto, i);
+                Mat blob_rescaled;
+                blob.convertTo(blob_rescaled, CV_8S, scale, shift);
+                constBlobs[node_proto.input(i)] = blob_rescaled;
+            }
+            else
+            {
+                LayerParams rescaleParams;
+                rescaleParams.name = node_proto.input(i) + "/rescale";
+                rescaleParams.type = "Requantize";
+                rescaleParams.set("depth", CV_8S);
+                rescaleParams.set("scale", scale);
+                rescaleParams.set("shift", shift);
+
+                opencv_onnx::NodeProto proto;
+                proto.add_input(node_proto.input(i));
+                proto.add_output(rescaleParams.name);
+                addLayer(rescaleParams, proto);
+                node_proto.set_input(i, rescaleParams.name);
+            }
+        }
+    }
+
+    bool hasVariableInps = false;
+    for (int i = 2; i < num_inputs; i += 3)
+    {
+        if (layer_id.find(node_proto.input(i)) != layer_id.end())
+        {
+            hasVariableInps = true;
+            break;
+        }
+    }
+
+    if (!hasVariableInps)
+    {
+        std::vector<Mat> inputs, concatenated;
+        MatShape inputShape;
+        for (size_t i = 2; i < num_inputs; i += 3)
+        {
+            Mat blob = getBlob(node_proto, i);
+            if (blob.size.dims() > inputShape.size())
+            {
+                inputShape = shape(blob);
+            }
+            inputs.push_back(blob);
+        }
+
+        int axis = layerParams.get<int>("axis", 1);
+        for (size_t i = 0; i < inputs.size(); ++i)
+        {
+            MatShape targetShape = inputShape;
+            targetShape[axis] = shape(inputs[i])[axis];
+            CV_CheckEQ(total(targetShape), total(shape(inputs[i])), "");
+            inputs[i] = inputs[i].reshape(0, targetShape);
+        }
+        runLayer(layerParams, inputs, concatenated);
+        CV_Assert(concatenated.size() == 1);
+        addConstant(layerParams.name, concatenated[0]);
+        return;
+    }
+    else
+    {
+        for (int i = 2; i < num_inputs; i += 3)
+        {
+            if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
+            {
+                LayerParams constParams;
+                constParams.name = node_proto.input(i);
+                constParams.type = "ConstInt8";
+                constParams.blobs.push_back(getBlob(node_proto, i));
+                constParams.set("depth", CV_8S);
+
+                opencv_onnx::NodeProto proto;
+                proto.add_output(constParams.name);
+                addLayer(constParams, proto);
+            }
+        }
+    }
+    addLayer(layerParams, node_proto);
+}
+
 const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
 {
     DispatchMap dispatch;
@@ -2468,6 +2930,14 @@ const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
     dispatch["SoftMax"] = dispatch["LogSoftmax"] = &ONNXImporter::parseSoftMax;
     dispatch["DetectionOutput"] = &ONNXImporter::parseDetectionOutput;
     dispatch["CumSum"] = &ONNXImporter::parseCumSum;
+    dispatch["QuantizeLinear"] = dispatch["DequantizeLinear"] = &ONNXImporter::parseQuantDequant;
+    dispatch["QLinearConv"] = &ONNXImporter::parseQConv;
+    dispatch["QLinearMatMul"] = &ONNXImporter::parseQMatMul;
+    dispatch["QLinearAdd"] = dispatch["QLinearMul"] = &ONNXImporter::parseQEltwise;
+    dispatch["QLinearLeakyRelu"] = &ONNXImporter::parseQLeakyRelu;
+    dispatch["QLinearSigmoid"] = &ONNXImporter::parseQSigmoid;
+    dispatch["QLinearAveragePool"] = dispatch["QLinearGlobalAveragePool"] = &ONNXImporter::parseQAvgPool;
+    dispatch["QLinearConcat"] = &ONNXImporter::parseQConcat;
 
     return dispatch;
 }
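
For parseQConv, the fused bias and per-channel multiplier follow from expanding the quantized convolution (the weight zero point is taken as zero, as the code assumes), so the int32 accumulator only needs a constant correction and one rescale per output channel. A sketch of the fold:

    // acc_i = Σ_k q_w[i][k] * q_x[k]                      (raw int32 accumulator of the int8 convolution)
    // real output: y_i = s_x * s_w[i] * (acc_i - zp_x * Σ_k q_w[i][k] + B_i),  B_i = int32 bias (scale s_x*s_w[i])
    // quantized:   q_y = zp_y + y_i / s_y = zp_y + outputMultiplier[i] * (acc_i + biasFused[i])
    //   biasFused[i]        = B_i - zp_x * Σ_k q_w[i][k]   -> bias - inp_zp * sum(weights_2d.row(i))
    //   outputMultiplier[i] = s_x * s_w[i] / s_y           -> (inp_sc * wt_sc[i]) / out_sc
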
index 1fcb1d0..5e6c05c 100644 (file)
@@ -583,7 +583,7 @@ TEST_P(Test_Int8_nets, ResNet50)
     Mat blob = blobFromImage(inp, 1.0, Size(224, 224), Scalar(), false);
     Mat ref = blobFromNPY(_tf("resnet50_prob.npy"));
 
-    float l1 = 3e-4, lInf = 0.035;
+    float l1 = 3e-4, lInf = 0.04;
     testClassificationNet(net, blob, ref, l1, lInf);
 }
 
@@ -714,7 +714,7 @@ TEST_P(Test_Int8_nets, MobileNet_v1_SSD_PPN)
     Mat blob = blobFromImage(inp, 1.0, Size(300, 300), Scalar(), true, false);
     Mat ref = blobFromNPY(_tf("tensorflow/ssd_mobilenet_v1_ppn_coco.detection_out.npy"));
 
-    float confThreshold = 0.51, scoreDiff = 0.04, iouDiff = 0.06;
+    float confThreshold = 0.51, scoreDiff = 0.05, iouDiff = 0.06;
     testDetectionNet(net, blob, ref, confThreshold, scoreDiff, iouDiff);
 }
 
@@ -815,7 +815,7 @@ TEST_P(Test_Int8_nets, FasterRCNN_resnet50)
     Mat blob = blobFromImage(inp, 1.0, Size(800, 600), Scalar(), true, false);
     Mat ref = blobFromNPY(_tf("tensorflow/faster_rcnn_resnet50_coco_2018_01_28.detection_out.npy"));
 
-    float confThreshold = 0.5, scoreDiff = 0.025, iouDiff = 0.15;
+    float confThreshold = 0.5, scoreDiff = 0.05, iouDiff = 0.15;
     testDetectionNet(net, blob, ref, confThreshold, scoreDiff, iouDiff);
 }
 
@@ -1127,7 +1127,7 @@ TEST_P(Test_Int8_nets, YOLOv4)
 
     std::string config_file = "yolov4.cfg";
     std::string weights_file = "yolov4.weights";
-    double scoreDiff = 0.1, iouDiff = 0.17;
+    double scoreDiff = 0.15, iouDiff = 0.2;
     {
         SCOPED_TRACE("batch size 1");
         testDarknetModel(config_file, weights_file, ref.rowRange(0, N0), scoreDiff, iouDiff);
index 4334da2..5d324b8 100644 (file)
@@ -991,6 +991,112 @@ TEST_P(Test_ONNX_layers, ConvResizePool1d)
     testONNXModels("conv_resize_pool_1d");
 }
 
+TEST_P(Test_ONNX_layers, Quantized_Convolution)
+{
+    testONNXModels("quantized_conv_uint8_weights", npy, 0.004, 0.02);
+    testONNXModels("quantized_conv_int8_weights", npy, 0.03, 0.5);
+    testONNXModels("quantized_conv_per_channel_weights", npy, 0.06, 0.4);
+}
+
+TEST_P(Test_ONNX_layers, Quantized_MatMul)
+{
+    testONNXModels("quantized_matmul_uint8_weights", npy, 0.005, 0.007);
+    testONNXModels("quantized_matmul_int8_weights", npy, 0.06, 0.2);
+    testONNXModels("quantized_matmul_per_channel_weights", npy, 0.06, 0.22);
+}
+
+TEST_P(Test_ONNX_layers, Quantized_MatMul_Variable_Weights)
+{
+    // Unsupported
+    EXPECT_THROW(
+    {
+        testONNXModels("quantized_matmul_variable_inputs");
+    }, cv::Exception);
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Eltwise)
+{
+    testONNXModels("quantized_eltwise");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Eltwise_Scalar)
+{
+    testONNXModels("quantized_eltwise_scalar");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Eltwise_Broadcast)
+{
+    testONNXModels("quantized_eltwise_broadcast");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_LeakyReLU)
+{
+    testONNXModels("quantized_leaky_relu");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Sigmoid)
+{
+    testONNXModels("quantized_sigmoid");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_MaxPool)
+{
+    testONNXModels("quantized_maxpool");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_AvgPool)
+{
+    testONNXModels("quantized_avgpool");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Split)
+{
+    testONNXModels("quantized_split");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Pad)
+{
+    testONNXModels("quantized_padding");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Reshape)
+{
+    testONNXModels("quantized_reshape");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Transpose)
+{
+    testONNXModels("quantized_transpose");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Squeeze)
+{
+    testONNXModels("quantized_squeeze");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Unsqueeze)
+{
+    testONNXModels("quantized_unsqueeze");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Resize)
+{
+    testONNXModels("quantized_resize_nearest");
+    testONNXModels("quantized_resize_bilinear", npy, 2e-4, 0.003);
+    testONNXModels("quantized_resize_bilinear_align", npy, 3e-4, 0.003);
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Concat)
+{
+    testONNXModels("quantized_concat");
+    testONNXModels("quantized_concat_const_blob");
+}
+
+TEST_P(Test_ONNX_layers, Quantized_Constant)
+{
+    testONNXModels("quantized_constant", npy, 0.002, 0.008);
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
 
 class Test_ONNX_nets : public Test_ONNX_layers
@@ -1127,6 +1233,11 @@ TEST_P(Test_ONNX_nets, ResNet50v1)
     testONNXModels("resnet50v1", pb, default_l1, default_lInf, true, target != DNN_TARGET_MYRIAD);
 }
 
+TEST_P(Test_ONNX_nets, ResNet50_Int8)
+{
+    testONNXModels("resnet50_int8", pb, default_l1, default_lInf, true);
+}
+
 TEST_P(Test_ONNX_nets, ResNet101_DUC_HDC)
 {
     applyTestTag(CV_TEST_TAG_VERYLONG);