Support for some reduce layers for onnx
authorZihao Mu <zihaomu@outlook.com>
Fri, 18 Mar 2022 02:19:13 +0000 (10:19 +0800)
committerZihao Mu <zihaomu@outlook.com>
Fri, 18 Mar 2022 02:19:13 +0000 (10:19 +0800)
modules/dnn/include/opencv2/dnn/all_layers.hpp
modules/dnn/src/init.cpp
modules/dnn/src/int8layers/reduce_layer.cpp [new file with mode: 0644]
modules/dnn/src/layers/reduce_layer.cpp [new file with mode: 0644]
modules/dnn/src/net_quantization.cpp
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp16_denylist.inl.hpp
modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp

index 44b16f7..c8c1475 100644 (file)
@@ -325,6 +325,20 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<PoolingLayerInt8> create(const LayerParams& params);
     };
 
+    class CV_EXPORTS ReduceLayer : public Layer
+    {
+    public:
+        int reduceType;
+        std::vector<size_t> reduceDims;
+        static Ptr<ReduceLayer> create(const LayerParams& params);
+    };
+
+    class CV_EXPORTS ReduceLayerInt8 : public ReduceLayer
+    {
+    public:
+        static Ptr<ReduceLayerInt8> create(const LayerParams& params);
+    };
+
     class CV_EXPORTS SoftmaxLayer : public Layer
     {
     public:
index 55ed1e5..86ceba3 100644 (file)
@@ -92,6 +92,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Pooling,        PoolingLayer);
     CV_DNN_REGISTER_LAYER_CLASS(ROIPooling,     PoolingLayer);
     CV_DNN_REGISTER_LAYER_CLASS(PSROIPooling,   PoolingLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(Reduce,         ReduceLayer);
     CV_DNN_REGISTER_LAYER_CLASS(LRN,            LRNLayer);
     CV_DNN_REGISTER_LAYER_CLASS(InnerProduct,   InnerProductLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Softmax,        SoftmaxLayer);
@@ -175,6 +176,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(ConvolutionInt8,  ConvolutionLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(InnerProductInt8, InnerProductLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(PoolingInt8,      PoolingLayerInt8);
+    CV_DNN_REGISTER_LAYER_CLASS(ReduceInt8,       ReduceLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(EltwiseInt8,      EltwiseLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(BatchNormInt8,    BatchNormLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(ScaleInt8,        ScaleLayerInt8);
diff --git a/modules/dnn/src/int8layers/reduce_layer.cpp b/modules/dnn/src/int8layers/reduce_layer.cpp
new file mode 100644 (file)
index 0000000..935bdc0
--- /dev/null
@@ -0,0 +1,213 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+#include <algorithm>
+#include <stdlib.h>
+#include <numeric>
+
+namespace cv
+{
+namespace dnn
+{
+
+class ReduceLayerInt8Impl CV_FINAL : public ReduceLayerInt8
+{
+public:
+    ReduceLayerInt8Impl(const LayerParams& params)
+    {
+        // Set reduce type
+        CV_Assert(params.has("reduce"));
+        String typeString = toLowerCase(params.get<String>("reduce"));
+        if (typeString == "max")
+            reduceType = MAX;
+        else if (typeString == "min")
+            reduceType = MIN;
+        else
+            CV_Error(Error::StsBadArg, "Unknown reduce type \"" + typeString + "\"");
+
+        // Set deleted dims
+        CV_Assert(params.has("deleted_dims"));
+        DictValue tempDims = params.get("deleted_dims");
+        int i, n = tempDims.size();
+        reduceDims.resize(n);
+        for (i = 0; i < n; i++)
+        {
+            reduceDims[i] = tempDims.get<int>(i);
+        }
+    }
+
+    virtual bool supportBackend(int backendId) CV_OVERRIDE
+    {
+        if (backendId == DNN_BACKEND_OPENCV)
+        {
+            return true;
+        }
+        return false;
+    }
+
+    // reduceType == MIN
+    struct ReduceOpMIN
+    {
+        int8_t apply(const int8_t* first, const int8_t* last)
+        {
+            return std::accumulate(first, last, *first,
+                                   [](int8_t a, int8_t b)
+                                   {
+                                       return std::min(a, b);
+                                   });
+        }
+    };
+
+    // reduceType == MAX
+    struct ReduceOpMAX
+    {
+        int8_t apply(const int8_t* first, const int8_t* last)
+        {
+            return std::accumulate(first, last, *first,
+                                   [](int8_t a, int8_t b)
+                                   {
+                                       return std::max(a, b);
+                                   });
+        }
+    };
+
+    template<typename Func>
+    class ReduceInvoker : public ParallelLoopBody
+    {
+    public:
+        const Mat* src;
+        Mat *dst;
+        std::vector<size_t> reduceDims;
+        int nstripes;
+        int reduceType;
+        Ptr<Func> func;
+
+        ReduceInvoker() : src(0), dst(0), nstripes(0), reduceType(MAX), func(makePtr<Func>()) {}
+
+        static void run(const Mat& src, Mat& dst, std::vector<size_t> reduceDims, int reduceType, int nstripes)
+        {
+            CV_Assert_N(src.isContinuous(), dst.isContinuous(), src.type() == CV_8S, src.type() == dst.type());
+
+            ReduceInvoker<Func> p;
+
+            p.src = &src;
+            p.dst = &dst;
+
+            p.reduceDims = reduceDims;
+            p.nstripes = nstripes;
+            p.reduceType = reduceType;
+
+            parallel_for_(Range(0, nstripes), p, nstripes);
+        }
+
+        void operator()(const Range& r) const CV_OVERRIDE
+        {
+            size_t total = dst->total();
+            size_t stripeSize = (total + nstripes - 1)/nstripes;
+            size_t stripeStart = r.start*stripeSize;
+            size_t stripeEnd = std::min(r.end*stripeSize, total);
+            size_t totalDeleted = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
+
+            int8_t *dstData = (int8_t *)dst->data;
+            int8_t *srcData = (int8_t *)src->data;
+
+            for (size_t ofs = stripeStart; ofs < stripeEnd;)
+            {
+                const int8_t* first = srcData + ofs * totalDeleted;
+                const int8_t* last = srcData + (ofs + 1) * totalDeleted;
+
+                dstData[ofs] = func->apply(first, last);
+                ofs += 1;
+            }
+        }
+    };
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+        CV_Assert(inputs.size() == 1);
+        const int nstripes = getNumThreads();
+
+        switch (reduceType)
+        {
+            case MIN:
+            {
+                ReduceInvoker<ReduceOpMIN>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            case MAX:
+            {
+                ReduceInvoker<ReduceOpMAX>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            default:
+                CV_Error(Error::StsNotImplemented, "Not implemented");
+                break;
+        }
+    }
+
+    bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                         const int requiredOutputs,
+                         std::vector<MatShape> &outputs,
+                         std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert(inputs.size() > 0);
+        CV_Assert(reduceDims.size() != 0 && inputs[0].size() >= reduceDims.size());
+
+        std::vector<int> outShape;
+        if (inputs[0].size() == reduceDims.size())
+            outShape.push_back(1);
+        else
+        {
+            for (int i = 0; i < inputs[0].size() - reduceDims.size(); i++)
+            {
+                outShape.push_back(inputs[0][i]);
+            }
+        }
+        outputs.assign(1, outShape);
+
+        return false;
+    }
+
+    virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
+                             const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
+    {
+        return false;
+    }
+
+    virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
+                           const std::vector<MatShape> &outputs) const CV_OVERRIDE
+    {
+        CV_UNUSED(inputs); // suppress unused variable warning
+        long flops = 0;
+        size_t totalDeleted = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
+        for (int i = 0; i < outputs.size(); i++)
+        {
+            flops += total(outputs[i])*(totalDeleted);
+        }
+        return flops;
+    }
+private:
+    enum Type
+    {
+        MAX,
+        MIN
+    };
+};
+
+Ptr<ReduceLayerInt8> ReduceLayerInt8::create(const LayerParams& params)
+{
+    return Ptr<ReduceLayerInt8>(new ReduceLayerInt8Impl(params));
+}
+
+}
+}
diff --git a/modules/dnn/src/layers/reduce_layer.cpp b/modules/dnn/src/layers/reduce_layer.cpp
new file mode 100644 (file)
index 0000000..62bb65f
--- /dev/null
@@ -0,0 +1,388 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include "../precomp.hpp"
+#include "opencv2/core/hal/intrin.hpp"
+#include "../op_cuda.hpp"
+#include "../op_webnn.hpp"
+
+#include <float.h>
+#include <algorithm>
+#include <numeric>
+using std::max;
+using std::min;
+
+#include <opencv2/core/utils/logger.hpp>
+
+namespace cv
+{
+namespace dnn
+{
+
+class ReduceLayerImpl CV_FINAL : public ReduceLayer
+{
+public:
+    ReduceLayerImpl(const LayerParams& params)
+    {
+        // set reduce type
+        CV_Assert(params.has("reduce"));
+        String typeString = toLowerCase(params.get<String>("reduce"));
+        if (typeString == "max")
+            reduceType= MAX;
+        else if (typeString == "min")
+            reduceType= MIN;
+        else if (typeString == "ave")
+            reduceType= AVE;
+        else if (typeString == "sum")
+            reduceType= SUM;
+        else if (typeString == "sum_square")
+            reduceType= SUM_SQUARE;
+        else if (typeString == "l1")
+            reduceType= L1;
+        else if (typeString == "l2")
+            reduceType= L2;
+        else if (typeString == "log_sum")
+            reduceType= LOG_SUM;
+        else if (typeString == "log_sum_exp")
+            reduceType= LOG_SUM_EXP;
+        else if (typeString == "prod")
+            reduceType= PROD;
+        else
+            CV_Error(Error::StsBadArg, "Unknown reduce type\"" + typeString + "\"");
+
+        // set deleted dims
+        CV_Assert(params.has("deleted_dims"));
+        DictValue tempDims = params.get("deleted_dims");
+        int i, n = tempDims.size();
+        reduceDims.resize(n);
+        for (i = 0; i < n; i++)
+        {
+            reduceDims[i] = tempDims.get<int>(i);
+        }
+    }
+
+    virtual bool supportBackend(int backendId) CV_OVERRIDE
+    {
+        if (backendId == DNN_BACKEND_OPENCV)
+        {
+            return true;
+        }
+        return false;
+    }
+
+    // reduceType == MIN
+    struct ReduceOpMIN
+    {
+        float apply(const float* first, const float* last, const float ikarea = 1.0f)
+        {
+            return std::accumulate(first, last, FLT_MAX,
+                                   [](float a, float b)
+                                   {
+                                       return std::min(a, b);
+                                   });
+        }
+    };
+
+    // reduceType == MAX
+    struct ReduceOpMAX
+    {
+        float apply(const float* first, const float* last, const float ikarea = 1.0f)
+        {
+            return std::accumulate(first, last, -FLT_MAX,
+                                   [](float a, float b)
+                                   {
+                                       return std::max(a, b);
+                                   });
+        }
+    };
+
+    // reduceType == SUM
+    struct ReduceOpSUM
+    {
+        float apply(const float* first, const float* last, const float ikarea = 1.0f)
+        {
+            return std::accumulate(first, last, 0.f);
+        }
+    };
+
+    // reduceType == AVE
+    struct ReduceOpAVE
+    {
+        float apply(const float* first, const float* last, const float ikarea = 1.0f)
+        {
+            float output = std::accumulate(first, last, 0.f);
+            return output * ikarea;
+        }
+    };
+
+    // reduceType == SUM_SQUARE
+    struct ReduceOpSUM_SQUARE
+    {
+        float apply(const float* first, const float* last, const float ikarea = 1.0f)
+        {
+            return std::accumulate(first, last, 0.f,
+                                   [](float a, float b)
+                                   {
+                                       return a + b * b;
+                                   });
+        }
+    };
+
+    // reduceType == L1
+    struct ReduceOpL1
+    {
+        float apply(const float* first, const float* last, const float ikarea = 1.0f)
+        {
+            return std::accumulate(first, last, 0.f,
+                                   [](float a, float b)
+                                   {
+                                       return a + std::abs(b);
+                                   });
+        }
+    };
+
+    // reduceType == L2
+    struct ReduceOpL2
+    {
+        float apply(const float* first, const float* last, const float ikarea = 1.0f)
+        {
+            float output = std::accumulate(first, last, 0.f,
+                                           [](float a, float b)
+                                           {
+                                               return a + b * b;
+                                           });
+            return std::sqrt(output);
+        }
+    };
+
+    // reduceType == PROD
+    struct ReduceOpPROD
+    {
+        float apply(const float* first, const float* last, const float ikarea = 1.0f)
+        {
+            return std::accumulate(first, last, 1.0f, std::multiplies<float>());
+        }
+    };
+
+    // reduceType == LOG_SUM
+    struct ReduceOpLOG_SUM
+    {
+        float apply(const float* first, const float* last, const float ikarea = 1.0f)
+        {
+            float output = std::accumulate(first, last, 0.0f);
+            return std::log(output);
+        }
+    };
+
+    // reduceType == LOG_SUM_EXP
+    struct ReduceOpLOG_SUM_EXP
+    {
+        float apply(const float* first, const float* last, const float ikarea = 1.0f)
+        {
+            float output = std::accumulate(first, last, 0.0f,
+                                           [](float a, float b)
+                                           {
+                                               return a + std::exp(b);
+                                           });
+            return std::log(output);
+        }
+    };
+
+    template<typename Func>
+    class ReduceInvoker : public ParallelLoopBody
+    {
+    public:
+        const Mat* src;
+        Mat *dst;
+        std::vector<size_t> reduceDims;
+        int nstripes;
+        int reduceType;
+        Ptr<Func> func;
+
+        ReduceInvoker() : src(0), dst(0), nstripes(0), reduceType(MAX), func(makePtr<Func>()) {}
+
+        static void run(const Mat& src, Mat& dst, std::vector<size_t> reduceDims, int reduceType, int nstripes)
+        {
+            CV_Assert_N( src.isContinuous(), dst.isContinuous(), src.type() == CV_32F, src.type() == dst.type());
+
+            ReduceInvoker<Func> p;
+
+            p.src = &src;
+            p.dst = &dst;
+
+            p.reduceDims = reduceDims;
+            p.nstripes = nstripes;
+            p.reduceType = reduceType;
+
+            parallel_for_(Range(0, nstripes), p, nstripes);
+        }
+
+        void operator()(const Range& r) const CV_OVERRIDE
+        {
+            size_t total = dst->total();
+            size_t stripeSize = (total + nstripes - 1)/nstripes;
+            size_t stripeStart = r.start*stripeSize;
+            size_t stripeEnd = std::min(r.end*stripeSize, total);
+            size_t stride_w = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
+
+            float *dstData = (float *)dst->data;
+            float *srcData = (float *)src->data;
+
+            for (size_t ofs = stripeStart; ofs < stripeEnd;)
+            {
+                const float* first = srcData + ofs * stride_w;
+                const float* last = srcData + (ofs + 1) * stride_w;
+
+                if (ofs < stripeEnd)
+                {
+                    dstData[ofs] = func->apply(first, last, 1.0 / stride_w);
+                    ofs += 1;
+                }
+            }
+        }
+    };
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        if (inputs_arr.depth() == CV_16S)
+        {
+            forward_fallback(inputs_arr, outputs_arr, internals_arr);
+            return;
+        }
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+        CV_Assert(inputs.size() == 1 || (inputs.size() == 2 && reduceType== SUM));
+        const int nstripes = getNumThreads();
+
+        switch (reduceType)
+        {
+            case MIN:
+            {
+                ReduceInvoker<ReduceOpMIN>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            case MAX:
+            {
+                ReduceInvoker<ReduceOpMAX>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            case AVE:
+            {
+                ReduceInvoker<ReduceOpAVE>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            case SUM:
+            {
+                ReduceInvoker<ReduceOpSUM>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            case L1:
+            {
+                ReduceInvoker<ReduceOpL1>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            case L2:
+            {
+                ReduceInvoker<ReduceOpL2>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            case SUM_SQUARE:
+            {
+                ReduceInvoker<ReduceOpSUM_SQUARE>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            case PROD:
+            {
+                ReduceInvoker<ReduceOpPROD>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            case LOG_SUM:
+            {
+                ReduceInvoker<ReduceOpLOG_SUM>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            case LOG_SUM_EXP:
+            {
+                ReduceInvoker<ReduceOpLOG_SUM_EXP>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
+                break;
+            }
+            default:
+                CV_Error(Error::StsNotImplemented, "Not implemented");
+                break;
+        }
+    }
+
+    bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                         const int requiredOutputs,
+                         std::vector<MatShape> &outputs,
+                         std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert(inputs.size() > 0);
+        CV_Assert(reduceDims.size() != 0 && inputs[0].size() >= reduceDims.size());
+
+        std::vector<int> outShape;
+        if (inputs[0].size() == reduceDims.size())
+            outShape.push_back(1);
+        else
+        {
+            for (int i = 0; i < inputs[0].size() - reduceDims.size(); i++)
+            {
+                outShape.push_back(inputs[0][i]);
+            }
+        }
+        outputs.assign(1, outShape);
+
+        return false;
+    }
+
+    virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
+                             const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
+    {
+        if (reduceType== MAX || reduceType== MIN)
+        {
+            return true;
+        }
+        return false;
+    }
+
+    virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
+                           const std::vector<MatShape> &outputs) const CV_OVERRIDE
+    {
+        CV_UNUSED(inputs); // suppress unused variable warning
+        long flops = 0;
+        size_t stride_w = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
+        for (int i = 0; i < outputs.size(); i++)
+        {
+            flops += total(outputs[i])*(stride_w);
+        }
+        return flops;
+    }
+private:
+    enum ReduceType
+    {
+        MAX,
+        MIN,
+        AVE,
+        SUM,
+        L1,
+        L2,
+        PROD,
+        SUM_SQUARE,
+        LOG_SUM,
+        LOG_SUM_EXP
+    };
+};
+
+Ptr<ReduceLayer> ReduceLayer::create(const LayerParams& params)
+{
+    return Ptr<ReduceLayer>(new ReduceLayerImpl(params));
+}
+
+}
+}
index b8ee2d3..ef1857a 100644 (file)
@@ -133,7 +133,9 @@ Net Net::Impl::quantize(InputArrayOfArrays calibData, int inputsDtype, int outpu
         if (ld.type == "Blank" || ld.type == "Dropout" || ld.type == "Identity" || ld.type == "Silence" ||
             ld.type == "Flatten" || ld.type == "Padding" || ld.type == "Permute" || ld.type == "Reshape" ||
             ld.type == "ReLU6" || ld.type == "Reorg" || ld.type == "ShuffleChannel" || ld.type == "Resize" ||
-           (ld.type == "ReLU" && !ld.params.get<float>("negative_slope", 0.f)) /* ReLU with negative slope 0 */)
+           (ld.type == "ReLU" && !ld.params.get<float>("negative_slope", 0.f)) || /* ReLU with negative slope 0 */
+           (ld.type == "Reduce" && (toLowerCase(ld.params.get<String>("reduce")) == "max" ||
+            toLowerCase(ld.params.get<String>("reduce")) == "min")))
         {
             for (int i = 0; i < ld.outputBlobs.size(); i++)
             {
index 62569d8..5713c02 100644 (file)
@@ -122,6 +122,7 @@ private:
     void parseMaxUnpool            (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseMaxPool              (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseAveragePool          (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseGlobalPool           (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseReduce               (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseSlice                (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseSplit                (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
@@ -1087,7 +1088,7 @@ void ONNXImporter::parseAveragePool(LayerParams& layerParams, const opencv_onnx:
     addLayer(layerParams, node_proto);
 }
 
-void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
+void ONNXImporter::parseGlobalPool(LayerParams &layerParams, const opencv_onnx::NodeProto &node_proto_)
 {
     opencv_onnx::NodeProto node_proto = node_proto_;
     const std::string& layer_type = node_proto.op_type();
@@ -1096,157 +1097,176 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
     CV_Assert(node_proto.input_size() == 1);
     layerParams.type = "Pooling";
     String pool;
-    if (layer_type == "GlobalMaxPool" || layer_type == "ReduceMax")
+    if (layer_type == "GlobalMaxPool")
         pool = "MAX";
-    else if (layer_type == "ReduceSum")
-        pool = "SUM";
-    else
+    else if (layer_type == "GlobalAveragePool")
         pool = "AVE";
+    else
+        CV_Error(Error::StsNotImplemented, "Unsupported Pooling type of " + layer_type + " operation.");
+
+    CV_Assert(!layerParams.has("axes"));
+    layerParams.set("global_pooling", true);
     layerParams.set("pool", pool);
-    layerParams.set("global_pooling", !layerParams.has("axes"));
+    addLayer(layerParams, node_proto);
+}
+
+void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
+{
+    opencv_onnx::NodeProto node_proto = node_proto_;
+    const std::string& layer_type = node_proto.op_type();
+    const std::string output_name = node_proto.output(0);
+    int depth = layerParams.get<int>("depth", CV_32F);
+
+    CV_Assert(node_proto.input_size() <= 2);
+    String reduceType;
+
+    if (layer_type == "ReduceMax")
+        reduceType = "MAX";
+    else if (layer_type == "ReduceMin")
+        reduceType = "MIN";
+    else if (layer_type == "ReduceSum")
+        reduceType = "SUM";
+    else if (layer_type == "ReduceSumSquare")
+        reduceType = "SUM_SQUARE";
+    else if (layer_type == "ReduceProd")
+        reduceType = "PROD";
+    else if (layer_type == "ReduceL1")
+        reduceType = "L1";
+    else if (layer_type == "ReduceL2")
+        reduceType = "L2";
+    else if (layer_type == "ReduceLogSum")
+        reduceType = "LOG_SUM";
+    else if (layer_type == "ReduceLogSumExp")
+        reduceType = "LOG_SUM_EXP";
+    else if (layer_type == "ReduceMean")
+        reduceType = "AVE";
+    else
+        CV_Error(Error::StsNotImplemented, "Unsupported Pooling type of " + layer_type + " operation.");
+
+    // The ReduceInt8 can only support "MAX" and "MIN".
+    if (depth == CV_8S)
+    {
+        CV_Assert(reduceType == "MAX" || reduceType == "MIN");
+    }
+
+    layerParams.type = (depth == CV_8S) ? "ReduceInt8" : "Reduce";
+    layerParams.set("reduce", reduceType);
     bool keepdims = layerParams.get<int>("keepdims", 1) == 1;
-    if (layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
+
+    if (layer_type == "ReduceSum" && node_proto.input_size() == 2)
+    {
+        // TODO support the opset 13 of ReduceSum.
+        //  in opset 13, the ReduceSum has two input, it takes axes as input instead of attribute
+        //  details:https://github.com/onnx/onnx/issues/3420#issuecomment-844295687
+        CV_Error(Error::StsNotImplemented, "Unsupported " + layer_type + " operation of opset 13, please try to "
+                                                                         "re-export the onnx model with opset 11.");
+    }
+
+    MatShape inpShape = outShapes[node_proto.input(0)];
+    std::vector<bool> shouldDelete(inpShape.size(), false);
+
+    if (layerParams.has("axes"))
     {
-        MatShape inpShape = outShapes[node_proto.input(0)];
         DictValue axes = layerParams.get("axes");
-        MatShape targetShape;
-        std::vector<bool> shouldDelete(inpShape.size(), false);
-        for (int i = 0; i < axes.size(); i++) {
+        for (int i = 0; i < axes.size(); i++)
+        {
             int axis = normalize_axis(axes.get<int>(i), inpShape.size());
             shouldDelete[axis] = true;
         }
-        for (int axis = 0; axis < inpShape.size(); ++axis){
-            if (!shouldDelete[axis])
-                targetShape.push_back(inpShape[axis]);
-            else if (keepdims)
-                targetShape.push_back(1);
+    }
+    else
+    {
+        for (int i = 0; i < inpShape.size(); i++)
+        {
+            shouldDelete[i] = true;
         }
+    }
 
-        if (inpShape.size() == 3 && axes.size() <= 2)
+    MatShape targetShape;
+    for (int i = 0; i < inpShape.size(); ++i)
+    {
+        if (!shouldDelete[i])
         {
-            int axis = normalize_axis(axes.get<int>(0), inpShape.size());
-            CV_CheckNE(axis, 0, "");
-
-            LayerParams reshapeLp;
-            reshapeLp.name = layerParams.name + "/reshape";
-            reshapeLp.type = "Reshape";
-            CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
-            reshapeLp.set("axis", 0);
-            reshapeLp.set("num_axes", 1);
-            int newShape[] = {1, -1};
-            reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 2));
+            targetShape.push_back(inpShape[i]);
+        }
+        else if (keepdims)
+        {
+            targetShape.push_back(1);
+        }
+    }
 
-            opencv_onnx::NodeProto proto;
-            proto.add_input(node_proto.input(0));
-            proto.add_output(reshapeLp.name);
-            addLayer(reshapeLp, proto);
+    if (targetShape.empty())
+        targetShape.push_back(1);
 
-            LayerParams avgLp;
-            avgLp.name = layerParams.name + "/avg";
-            avgLp.type = "Pooling";
-            CV_Assert(layer_id.find(avgLp.name) == layer_id.end());
-            avgLp.set("pool", pool);
-            if (axes.size() == 2)
-            {
-                CV_CheckEQ(normalize_axis(axes.get<int>(0), inpShape.size()), 1, "Unsupported mode");
-                CV_CheckEQ(normalize_axis(axes.get<int>(1), inpShape.size()), 2, "Unsupported mode");
-                avgLp.set("global_pooling", true);
-            }
-            else
-            {
-                avgLp.set(axis == 2 ? "global_pooling_w" : "global_pooling_h", true);
-                avgLp.set(axis == 2 ? "kernel_h" : "kernel_w", 1);
-            }
+    // Using PermuteLayer to move the deleted axis to the last.
+    std::vector<int> perm(inpShape.size(), 0);
+    for (int i = 0; i < inpShape.size(); i++)
+        perm[i] = i;
 
-            node_proto.set_input(0, reshapeLp.name);
-            node_proto.set_output(0, avgLp.name);
-            addLayer(avgLp, node_proto);
-        }
-        else
+    bool needPermuet = false;
+    for (int i = 0; i < inpShape.size(); i++)
+    {
+        if (shouldDelete[i])
         {
-            if (inpShape.size() != 4 && inpShape.size() != 5)
-                CV_Error(Error::StsNotImplemented, "Unsupported input shape of " + layer_type + " operation.");
+            // find the first not deleted element.
+            std::vector<bool>::iterator iter = std::find(shouldDelete.begin() + i, shouldDelete.end(), false);
 
-            CV_Assert(axes.size() <= inpShape.size() - 2);
-            std::vector<int> kernel_size(inpShape.size() - 2, 1);
-            if (axes.size() == 1 && (normalize_axis(axes.get<int>(0), inpShape.size()) <= 1))
-            {
-                int axis = normalize_axis(axes.get<int>(0), inpShape.size());
-                MatShape newShape = inpShape;
-                newShape[axis + 1] = total(newShape, axis + 1);
-                newShape.resize(axis + 2);
-                newShape.insert(newShape.begin(), 2 - axis, 1);
-
-                LayerParams reshapeLp;
-                reshapeLp.type = "Reshape";
-                reshapeLp.name = layerParams.name + "/reshape";
-                CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
-                reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], newShape.size()));
-
-                node_proto.set_output(0, reshapeLp.name);
-                addLayer(reshapeLp, node_proto);
-
-                kernel_size.resize(2);
-                kernel_size[0] = inpShape[axis];
-                node_proto.set_input(0, node_proto.output(0));
-            }
-            else
+            if (iter != shouldDelete.end())
             {
-                for (int i = 0; i < axes.size(); i++) {
-                    int axis = normalize_axis(axes.get<int>(i), inpShape.size());
-                    CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
-                    kernel_size[axis - 2] = inpShape[axis];
-                }
-            }
+                int index = iter - shouldDelete.begin();
 
-            LayerParams poolLp = layerParams;
-            poolLp.name = layerParams.name + "/avg";
-            CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
-            poolLp.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
+                bool temp = shouldDelete[index];
+                shouldDelete[index] = shouldDelete[i];
+                shouldDelete[i] = temp;
 
-            node_proto.set_output(0, poolLp.name);
-            addLayer(poolLp, node_proto);
+                std::swap(perm[index], perm[i]);
+                std::swap(inpShape[index], inpShape[i]);
+                needPermuet = true;
+            }
+            else
+                break;
         }
+    }
 
-        layerParams.type = "Reshape";
-        layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
+    auto inputString= node_proto.input(0);
+    if (needPermuet)
+    {
+        LayerParams permuteLp;
+        permuteLp.name = layerParams.name + "/permute";
+        permuteLp.type = (depth == CV_8S) ? "PermuteInt8" : "Permute";
+        permuteLp.set("order", DictValue::arrayInt(perm.data(), perm.size()));
 
-        node_proto.set_input(0, node_proto.output(0));
-        node_proto.set_output(0, output_name);
+        opencv_onnx::NodeProto protoPermute;
+        protoPermute.add_input(inputString);
+        protoPermute.add_output(permuteLp.name);
+        addLayer(permuteLp, protoPermute);
+        inputString = permuteLp.name;
     }
-    else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
-    {
-        IterShape_t shapeIt = outShapes.find(node_proto.input(0));
-        CV_Assert(shapeIt != outShapes.end());
-        const size_t dims = keepdims ? shapeIt->second.size() : 1;
 
-        LayerParams reshapeLp;
-        reshapeLp.name = layerParams.name + "/reshape";
-        reshapeLp.type = "Reshape";
-        CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
-        int newShape[] = {1, 1, 1, -1};
-        reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 4));
+    std::vector<int> deletedDims;
+    for (int axis_i = 0; axis_i < inpShape.size(); ++axis_i)
+    {
+        if (shouldDelete[axis_i])
+        {
+            deletedDims.push_back(inpShape[axis_i]);
+        }
+    }
 
-        opencv_onnx::NodeProto proto;
-        proto.add_input(node_proto.input(0));
-        proto.add_output(reshapeLp.name);
-        addLayer(reshapeLp, proto);
+    LayerParams reduceLp = layerParams;
+    reduceLp.name = layerParams.name + "/reduce";
+    CV_Assert(layer_id.find(reduceLp.name) == layer_id.end());
+    reduceLp.set("deleted_dims", DictValue::arrayInt(&deletedDims[0], deletedDims.size()));
 
-        LayerParams poolLp = layerParams;
-        poolLp.name = layerParams.name + "/pool";
-        CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
+    node_proto.set_input(0, inputString);
+    node_proto.set_output(0, reduceLp.name);
+    addLayer(reduceLp, node_proto);
 
-        node_proto.set_input(0, reshapeLp.name);
-        node_proto.set_output(0, poolLp.name);
-        addLayer(poolLp, node_proto);
+    layerParams.type = (depth == CV_8S) ? "ReshapeInt8" : "Reshape";
+    layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
 
-        layerParams.type = "Reshape";
-        std::vector<int> targetShape(dims, 1);
-        layerParams.set("dim", DictValue::arrayInt(targetShape.data(), targetShape.size()));
+    node_proto.set_input(0, node_proto.output(0));
+    node_proto.set_output(0, output_name);
 
-        node_proto.set_input(0, node_proto.output(0));
-        node_proto.set_output(0, output_name);
-    }
     addLayer(layerParams, node_proto);
 }
 
@@ -3406,8 +3426,10 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
     dispatch["MaxUnpool"] = &ONNXImporter::parseMaxUnpool;
     dispatch["MaxPool"] = &ONNXImporter::parseMaxPool;
     dispatch["AveragePool"] = &ONNXImporter::parseAveragePool;
-    dispatch["GlobalAveragePool"] = dispatch["GlobalMaxPool"] = dispatch["ReduceMean"] = dispatch["ReduceSum"] =
-            dispatch["ReduceMax"] = &ONNXImporter::parseReduce;
+    dispatch["GlobalAveragePool"] = dispatch["GlobalMaxPool"] = &ONNXImporter::parseGlobalPool;
+    dispatch["ReduceMax"] = dispatch["ReduceMin"] = dispatch["ReduceMean"] = dispatch["ReduceSum"] = dispatch["ReduceMax"] =
+            dispatch["ReduceMin"] = dispatch["ReduceSumSquare"] = dispatch["ReduceProd"] = dispatch["ReduceL1"] =
+            dispatch["ReduceL2"] = dispatch["ReduceLogSum"] = dispatch["ReduceLogSumExp"] = &ONNXImporter::parseReduce;
     dispatch["Slice"] = &ONNXImporter::parseSlice;
     dispatch["Split"] = &ONNXImporter::parseSplit;
     dispatch["Add"] = dispatch["Sum"] = dispatch["Sub"] = &ONNXImporter::parseBias;
index ccd1568..c2425d4 100644 (file)
 "test_split_equal_parts_2d",
 "test_split_equal_parts_default_axis",
 "test_tan",
+"test_reduce_l2_default_axes_keepdims_example", // Expected: (normL1) <= (l1), actual: 0.00490189 vs 0.004
+"test_reduce_log_sum_exp_default_axes_keepdims_example", // Expected: (normL1) <= (l1), actual: 0.00671387 vs 0.004
+"test_reduce_prod_default_axes_keepdims_example", // Expected: (normL1) <= (l1), actual: inf vs 0.004
+"test_reduce_prod_default_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 18.6621 vs 0.004, Expected: (normInf) <= (lInf), actual: 18.6621 vs 0.02
+"test_reduce_prod_do_not_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.00436729 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0201836 vs 0.02
+"test_reduce_prod_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.00436729 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0201836 vs 0.02
+"test_reduce_prod_negative_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.00436729 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0201836 vs 0.02
+"test_reduce_sum_square_default_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.0183411 vs 0.004
+"test_reduce_sum_square_do_not_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02
+"test_reduce_sum_square_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02
+"test_reduce_sum_square_negative_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02
\ No newline at end of file
index e5d0ead..eef4214 100644 (file)
 "test_range_int32_type_negative_delta_expanded",
 "test_reciprocal",
 "test_reciprocal_example",
-"test_reduce_l1_default_axes_keepdims_example",
-"test_reduce_l1_default_axes_keepdims_random",
-"test_reduce_l1_do_not_keepdims_example",
-"test_reduce_l1_do_not_keepdims_random",
-"test_reduce_l1_keep_dims_example",
-"test_reduce_l1_keep_dims_random",
-"test_reduce_l1_negative_axes_keep_dims_example",
-"test_reduce_l1_negative_axes_keep_dims_random",
-"test_reduce_l2_default_axes_keepdims_example",
-"test_reduce_l2_default_axes_keepdims_random",
-"test_reduce_l2_do_not_keepdims_example",
-"test_reduce_l2_do_not_keepdims_random",
-"test_reduce_l2_keep_dims_example",
-"test_reduce_l2_keep_dims_random",
-"test_reduce_l2_negative_axes_keep_dims_example",
-"test_reduce_l2_negative_axes_keep_dims_random",
-"test_reduce_log_sum",
-"test_reduce_log_sum_asc_axes",
-"test_reduce_log_sum_default",
-"test_reduce_log_sum_desc_axes",
-"test_reduce_log_sum_exp_default_axes_keepdims_example",
-"test_reduce_log_sum_exp_default_axes_keepdims_random",
-"test_reduce_log_sum_exp_do_not_keepdims_example",
-"test_reduce_log_sum_exp_do_not_keepdims_random",
-"test_reduce_log_sum_exp_keepdims_example",
-"test_reduce_log_sum_exp_keepdims_random",
-"test_reduce_log_sum_exp_negative_axes_keepdims_example",
-"test_reduce_log_sum_exp_negative_axes_keepdims_random",
-"test_reduce_log_sum_negative_axes",
-"test_reduce_min_default_axes_keepdims_example",
-"test_reduce_min_default_axes_keepdims_random",
-"test_reduce_min_do_not_keepdims_example",
-"test_reduce_min_do_not_keepdims_random",
-"test_reduce_min_keepdims_example",
-"test_reduce_min_keepdims_random",
-"test_reduce_min_negative_axes_keepdims_example",
-"test_reduce_min_negative_axes_keepdims_random",
-"test_reduce_prod_default_axes_keepdims_example",
-"test_reduce_prod_default_axes_keepdims_random",
-"test_reduce_prod_do_not_keepdims_example",
-"test_reduce_prod_do_not_keepdims_random",
-"test_reduce_prod_keepdims_example",
-"test_reduce_prod_keepdims_random",
-"test_reduce_prod_negative_axes_keepdims_example",
-"test_reduce_prod_negative_axes_keepdims_random",
 "test_reduce_sum_default_axes_keepdims_example",
 "test_reduce_sum_default_axes_keepdims_random",
 "test_reduce_sum_do_not_keepdims_example",
 "test_reduce_sum_keepdims_random",
 "test_reduce_sum_negative_axes_keepdims_example",
 "test_reduce_sum_negative_axes_keepdims_random",
-"test_reduce_sum_square_default_axes_keepdims_example",
-"test_reduce_sum_square_default_axes_keepdims_random",
-"test_reduce_sum_square_do_not_keepdims_example",
-"test_reduce_sum_square_do_not_keepdims_random",
-"test_reduce_sum_square_keepdims_example",
-"test_reduce_sum_square_keepdims_random",
-"test_reduce_sum_square_negative_axes_keepdims_example",
-"test_reduce_sum_square_negative_axes_keepdims_random",
 "test_reflect_pad",
 "test_reshape_allowzero_reordered",
 "test_reshape_extended_dims",