implementation of scatter and scatternd with conformance tests enabled
authorfengyuentau <yuantao.feng@opencv.org.cn>
Sun, 18 Sep 2022 14:13:55 +0000 (22:13 +0800)
committerfengyuentau <yuantao.feng@opencv.org.cn>
Mon, 17 Oct 2022 03:30:32 +0000 (11:30 +0800)
14 files changed:
modules/dnn/include/opencv2/dnn/all_layers.hpp
modules/dnn/perf/perf_layer.cpp
modules/dnn/src/init.cpp
modules/dnn/src/layers/scatterND_layer.cpp [new file with mode: 0644]
modules/dnn/src/layers/scatter_layer.cpp [new file with mode: 0644]
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_conformance.cpp
modules/dnn/test/test_onnx_conformance_layer_filter__cuda_denylist.inl.hpp
modules/dnn/test/test_onnx_conformance_layer_filter__halide_denylist.inl.hpp
modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp
modules/dnn/test/test_onnx_conformance_layer_filter__vulkan_denylist.inl.hpp
modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp16_denylist.inl.hpp
modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp32_denylist.inl.hpp
modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp

index f87a46ba5e26f7636080b305359b3cb7a585921a..a0711b868cdd8b5710fe29b84e3999829c8d7c27 100644 (file)
@@ -1065,6 +1065,18 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<CumSumLayer> create(const LayerParams& params);
     };
 
+    class CV_EXPORTS ScatterLayer : public Layer
+    {
+    public:
+        static Ptr<ScatterLayer> create(const LayerParams& params);
+    };
+
+    class CV_EXPORTS ScatterNDLayer : public Layer
+    {
+    public:
+        static Ptr<ScatterNDLayer> create(const LayerParams& params);
+    };
+
 //! @}
 //! @}
 CV__DNN_INLINE_NS_END
index 03ba8ab0e9ede4fa465b639ae36e628f6678eb1e..f169f4e6a8181101993b0d177b50e0c8bb74829d 100644 (file)
@@ -239,7 +239,178 @@ PERF_TEST_P_(Layer_Slice, FastNeuralStyle_eccv16)
     test_slice<4>(inputShape, begin, end);
 }
 
+struct Layer_Scatter : public TestBaseWithParam<tuple<Backend, Target> >
+{
+    void test_layer(const std::vector<int>& shape, const String reduction = "none", int axis = 0)
+    {
+        int backendId = get<0>(GetParam());
+        int targetId = get<1>(GetParam());
+
+        Mat data(shape, CV_32FC1);
+        Mat indices(shape, CV_32FC1);
+        Mat updates(shape, CV_32FC1);
+
+        Scalar mean = 0.f;
+        Scalar std = 1.f;
+        randn(data, mean, std);
+        randu(indices, 0, shape[axis]);
+        randn(updates, mean, std);
+
+        indices.convertTo(indices, CV_32SC1, 1, -1);
+
+        Net net;
+        LayerParams lp;
+        lp.type = "Scatter";
+        lp.name = "testLayer";
+        lp.set("reduction", reduction);
+        lp.set("axis", axis);
+
+        int id = net.addLayerToPrev(lp.name, lp.type, lp);
+        net.connect(0, 0, id, 0);
+        net.connect(0, 1, id, 1);
+        net.connect(0, 2, id, 2);
+
+        // warmup
+        {
+            std::vector<String> inpNames(3);
+            inpNames[0] = "data";
+            inpNames[1] = "indices";
+            inpNames[2] = "updates";
+            net.setInputsNames(inpNames);
+            net.setInput(data, inpNames[0]);
+            net.setInput(indices, inpNames[1]);
+            net.setInput(updates, inpNames[2]);
+
+            net.setPreferableBackend(backendId);
+            net.setPreferableTarget(targetId);
+            Mat out = net.forward();
+        }
+
+        TEST_CYCLE()
+        {
+            Mat res = net.forward();
+        }
+
+        SANITY_CHECK_NOTHING();
+    }
+
+    int N = 8;
+    int C = 256;
+    int H = 128;
+    int W = 100;
+};
+
+PERF_TEST_P_(Layer_Scatter, DISABLED_Scatter)
+{
+    test_layer({N, C, H, W});
+}
+
+PERF_TEST_P_(Layer_Scatter, DISABLED_Scatter_add)
+{
+    test_layer({N, C, H, W}, "add");
+}
+
+struct Layer_ScatterND : public TestBaseWithParam<tuple<Backend, Target> >
+{
+    void test_layer(const std::vector<int>& shape, const String reduction = "none")
+    {
+        int backendId = get<0>(GetParam());
+        int targetId = get<1>(GetParam());
+
+        std::vector<int> indices_shape(shape);
+        indices_shape.push_back(int(shape.size()));
+        Mat data(shape, CV_32FC1);
+        Mat indices(indices_shape, CV_32FC1);
+        Mat updates(shape, CV_32FC1);
+
+        Scalar mean = 0.f;
+        Scalar std = 1.f;
+        randn(data, mean, std);
+        randn(updates, mean, std);
+
+        // initialize the indices with index tuples like [0...N, 0...C, 0...H, 0...W]
+        std::vector<int> current_index_tuple(shape.size());
+        int total = data.total();
+        std::vector<int> indices_step;
+        for (int i = 0; i < indices.dims; i++)
+        {
+            int step = indices.step.p[i] / sizeof(float);
+            indices_step.push_back(step);
+        }
+        int t, j, idx, offset_at_idx, offset;
+        for (int i = 0; i < total; i++)
+        {
+            t = i;
+            for (j = shape.size() - 1; j >= 0; j--)
+            {
+                idx = t / shape[j];
+                offset_at_idx = (int)(t - idx * shape[j]);
+                current_index_tuple[j] = offset_at_idx;
+                t = idx;
+            }
+
+            offset = 0;
+            for (j = 0; j < shape.size(); j++)
+                offset += current_index_tuple[j] * indices_step[j];
+
+            for (j = 0; j < shape.size(); j++)
+                indices.at<float>(offset + j) = current_index_tuple[j];
+        }
+
+        Net net;
+        LayerParams lp;
+        lp.type = "ScatterND";
+        lp.name = "testLayer";
+        lp.set("reduction", reduction);
+
+        int id = net.addLayerToPrev(lp.name, lp.type, lp);
+        net.connect(0, 0, id, 0);
+        net.connect(0, 1, id, 1);
+        net.connect(0, 2, id, 2);
+
+        // warmup
+        {
+            std::vector<String> inpNames(3);
+            inpNames[0] = "data";
+            inpNames[1] = "indices";
+            inpNames[2] = "updates";
+            net.setInputsNames(inpNames);
+            net.setInput(data, inpNames[0]);
+            net.setInput(indices, inpNames[1]);
+            net.setInput(updates, inpNames[2]);
+
+            net.setPreferableBackend(backendId);
+            net.setPreferableTarget(targetId);
+            Mat out = net.forward();
+        }
+
+        TEST_CYCLE()
+        {
+            Mat res = net.forward();
+        }
+
+        SANITY_CHECK_NOTHING();
+    }
+
+    int N = 8;
+    int C = 256;
+    int H = 128;
+    int W = 100;
+};
+
+PERF_TEST_P_(Layer_ScatterND, DISABLED_ScatterND)
+{
+    test_layer({N, C, H ,W});
+}
+
+PERF_TEST_P_(Layer_ScatterND, DISABLED_ScatterND_add)
+{
+    test_layer({N, C, H , W}, "add");
+}
+
 INSTANTIATE_TEST_CASE_P(/**/, Layer_Slice, dnnBackendsAndTargets(false, false));
 INSTANTIATE_TEST_CASE_P(/**/, Layer_NaryEltwise, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
+INSTANTIATE_TEST_CASE_P(/**/, Layer_Scatter, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
+INSTANTIATE_TEST_CASE_P(/**/, Layer_ScatterND, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
 
 } // namespace
index 902b6dae889ca5a40ead358f700590792ba6dcbe..63bbf2cb3f60430f523053b0b13bf313d546b28c 100644 (file)
@@ -175,6 +175,9 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(GRU,            GRULayer);
     CV_DNN_REGISTER_LAYER_CLASS(CumSum,         CumSumLayer);
 
+    CV_DNN_REGISTER_LAYER_CLASS(Scatter,        ScatterLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(ScatterND,      ScatterNDLayer);
+
     CV_DNN_REGISTER_LAYER_CLASS(Quantize,         QuantizeLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Dequantize,       DequantizeLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Requantize,       RequantizeLayer);
diff --git a/modules/dnn/src/layers/scatterND_layer.cpp b/modules/dnn/src/layers/scatterND_layer.cpp
new file mode 100644 (file)
index 0000000..648d35f
--- /dev/null
@@ -0,0 +1,202 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+#include <algorithm> // for std::max & std::min
+
+namespace cv { namespace dnn {
+
+class ScatterNDLayerImpl CV_FINAL : public ScatterNDLayer
+{
+public:
+    enum class REDUCTION
+    {
+        NONE = 1,
+        ADD,
+        MUL,
+        MAX,
+        MIN
+    } reduction;
+
+    ScatterNDLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+
+        String reduction_name = toLowerCase(params.get<String>("reduction", "none"));
+        if (reduction_name == "none")
+            reduction = REDUCTION::NONE;
+        else if (reduction_name == "add")
+            reduction = REDUCTION::ADD;
+        else if (reduction_name == "mul")
+            reduction = REDUCTION::MUL;
+        else if (reduction_name == "max")
+            reduction = REDUCTION::MAX;
+        else if (reduction_name == "min")
+            reduction = REDUCTION::MIN;
+        else
+            CV_Error(cv::Error::StsBadArg, "Unkown reduction \"" + reduction_name + "\"");
+    }
+
+    virtual bool supportBackend(int backendId) CV_OVERRIDE
+    {
+        return backendId == DNN_BACKEND_OPENCV;
+    }
+
+    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                                 const int requiredOutputs,
+                                 std::vector<MatShape> &outputs,
+                                 std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_CheckEQ(inputs.size(), 3ull, "ScatterND: require three inputs.");
+
+        size_t r = inputs[0].size(), q = inputs[1].size(), p = inputs[2].size(), k = inputs[1].back();
+        CV_CheckEQ(r + q - inputs[1].back() - 1, p, "ScatterND: updates should have rank of data.dims + indices.dims - indices.size[-1] - 1");
+        CV_CheckLE(k, r, "ScatterND: indices.shape[-1] must be less than (or equal to) the rank of input data.");
+
+        for (int i = 0; i < q - 1; i++) // np.ndindex(indices.shape[-1])
+        {
+            CV_CheckEQ(inputs[2][i], inputs[1][i], "ScatterND: updates.shape[0 : rank(indices)-1] must equal to indices.shape[0 : rank(indices)-1].");
+        }
+        for (int i = q - 1, j = k, m = 0; i + m < p; m++)
+        {
+            CV_CheckEQ(inputs[2][i + m], inputs[0][j + m], "ScatterND: updates.shape[rank(indices)-1 : ] must equal to data[indices.shape[-1] : rank(data)-1].");
+        }
+
+        outputs.assign(1, inputs[0]);
+        return false;
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        const Mat& data = inputs[0];
+        const Mat& indices = inputs[1];
+        const Mat& updates = inputs[2];
+        Mat& out = outputs[0];
+
+        typeDispatch(outputs[0].type(), data, indices, updates, out);
+    }
+
+    // NOTE: This impl does not check whether indices have duplicate entries.
+    //       The last duplicate entry will overwrite the previous.
+    template<typename T, typename Functor>
+    void forward_impl(const Functor& rd, const Mat& data, const Mat& indices, const Mat& updates, Mat& out)
+    {
+        data.copyTo(out);
+
+        const int* shape = data.size.p;
+        const size_t* step = data.step.p;
+
+        const int ind_ndims = indices.dims;
+        const int* ind_shape = indices.size.p;
+        const T* p_indices = indices.ptr<const T>();
+
+        const int upd_ndims = updates.dims;
+        const int* upd_shape = updates.size.p;
+        const T* p_updates = updates.ptr<const T>();
+
+        T* p_out = out.ptr<T>();
+
+        int k = ind_shape[ind_ndims - 1]; // last dim of indices
+        size_t total = (size_t)(indices.total() / k);
+
+        size_t updates_size = 1;
+        for (int i = ind_ndims - 1; i < upd_ndims; i++)
+            updates_size *= upd_shape[i];
+
+        size_t inp_start_offset = 0;
+        size_t ind_start_offset = 0;
+        size_t upd_start_offset = 0;
+        for (size_t i = 0; i < total; i++, ind_start_offset += k, upd_start_offset += updates_size)
+        {
+            const T* tmp_p_indices = p_indices + ind_start_offset;
+            inp_start_offset = 0;
+            for (int j = 0; j < k; j++)
+            {
+                CV_Assert(tmp_p_indices[j] < shape[j] && tmp_p_indices[j] > -shape[j]);
+                inp_start_offset += (((int)tmp_p_indices[j] + shape[j]) % shape[j]) * step[j];
+            }
+            inp_start_offset /= sizeof(T);
+
+            const T* tmp_p_updates = p_updates + upd_start_offset;
+            T* tmp_p_out = p_out + inp_start_offset;
+            for (int j = 0; j < updates_size; j++)
+                tmp_p_out[j] = rd(tmp_p_out[j], tmp_p_updates[j]);
+        }
+    }
+
+    template<typename... Args>
+    inline void typeDispatch(const int type, Args&&... args)
+    {
+        switch (type)
+        {
+            case CV_8U:
+                reductionDispatch<uint8_t>(std::forward<Args>(args)...);
+                break;
+            case CV_32S:
+                reductionDispatch<int32_t>(std::forward<Args>(args)...);
+                break;
+            case CV_32F:
+                reductionDispatch<float>(std::forward<Args>(args)...);
+                break;
+            default:
+                CV_Error(cv::Error::BadDepth, "Unsupported type.");
+        };
+    }
+
+    template<typename T, typename... Args>
+    inline void reductionDispatch(Args&&... args)
+    {
+        switch (reduction)
+        {
+            case REDUCTION::NONE:
+            {
+                auto rd = [](const T& a, const T& b) { return b; }; // a from input data, b from updates
+                forward_impl<T>(rd, std::forward<Args>(args)...);
+                break;
+            }
+            case REDUCTION::ADD:
+            {
+                auto rd = [](const T& a, const T& b) { return a + b; };
+                forward_impl<T>(rd, std::forward<Args>(args)...);
+                break;
+            }
+            case REDUCTION::MUL:
+            {
+                auto rd = [](const T& a, const T& b) { return a * b; };
+                forward_impl<T>(rd, std::forward<Args>(args)...);
+                break;
+            }
+            case REDUCTION::MAX:
+            {
+                auto rd = [](const T& a, const T& b) { return std::max(a, b); };
+                forward_impl<T>(rd, std::forward<Args>(args)...);
+                break;
+            }
+            case REDUCTION::MIN:
+            {
+                auto rd = [](const T& a, const T& b) { return std::min(a, b); };
+                forward_impl<T>(rd, std::forward<Args>(args)...);
+                break;
+            }
+            default:
+                CV_Error(Error::StsBadArg, "Unsupported reduction.");
+        };
+    }
+};
+
+Ptr<ScatterNDLayer> ScatterNDLayer::create(const LayerParams& params)
+{
+    return makePtr<ScatterNDLayerImpl>(params);
+}
+
+}} // namespace cv::dnn
diff --git a/modules/dnn/src/layers/scatter_layer.cpp b/modules/dnn/src/layers/scatter_layer.cpp
new file mode 100644 (file)
index 0000000..084eecb
--- /dev/null
@@ -0,0 +1,208 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+#include <algorithm> // for std::max & std::min
+
+namespace cv { namespace dnn {
+
+class ScatterLayerImpl CV_FINAL : public ScatterLayer
+{
+public:
+    enum class REDUCTION
+    {
+        NONE = 1,
+        ADD,
+        MUL,
+        MAX,
+        MIN
+    } reduction;
+
+    ScatterLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+
+        axis = params.get<int>("axis", 0);
+        String reduction_name = toLowerCase(params.get<String>("reduction", "none"));
+        if (reduction_name == "none")
+            reduction = REDUCTION::NONE;
+        else if (reduction_name == "add")
+            reduction = REDUCTION::ADD;
+        else if (reduction_name == "mul")
+            reduction = REDUCTION::MUL;
+        else if (reduction_name == "max")
+            reduction = REDUCTION::MAX;
+        else if (reduction_name == "min")
+            reduction = REDUCTION::MIN;
+        else
+            CV_Error(cv::Error::StsBadArg, "Unkown reduction \"" + reduction_name + "\"");
+    }
+
+    virtual bool supportBackend(int backendId) CV_OVERRIDE
+    {
+        return backendId == DNN_BACKEND_OPENCV;
+    }
+
+    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                                 const int requiredOutputs,
+                                 std::vector<MatShape> &outputs,
+                                 std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_CheckEQ(inputs.size(), 3ull, "Scatter: require three inputs.");
+        CV_CheckEQ(inputs[0].size(), inputs[1].size(), "Scatter: input data should have the same ndim with indices.");
+        CV_CheckEQ(inputs[0].size(), inputs[2].size(), "Scatter: input data should have the same ndim with updates.");
+        for (size_t i = 0; i < inputs[0].size(); i++)
+        {
+            CV_CheckGE(inputs[0][i], inputs[1][i], "Scatter: each dim of input data should be greater than (or equal to) indices'.");
+            CV_CheckEQ(inputs[1][i], inputs[2][i], "Scatter: each dim of indices should be equal to updates'.");
+        }
+        outputs.assign(1, inputs[0]);
+        return false;
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        const Mat& data = inputs[0];
+        const Mat& indices = inputs[1];
+        const Mat& updates = inputs[2];
+        Mat& out = outputs[0];
+
+        typeDispatch(outputs[0].type(), data, indices, updates, out);
+    }
+
+    template<typename T, typename Functor>
+    void forward_impl(const Functor& rd, const Mat& data, const Mat& indices, const Mat& updates, Mat& out)
+    {
+        data.copyTo(out);
+
+        const int ndims = data.dims;
+        const int* shape = data.size.p;
+        const size_t* step = data.step.p;
+
+        const int* ind_shape = indices.size.p;
+        const size_t* ind_step = indices.step.p;
+
+        size_t inp_offset = 0;
+        size_t ind_offset = 0;
+        const T* p_index = indices.ptr<const T>();
+        const T* p_update = updates.ptr<const T>();
+        T* p_out = out.ptr<T>();
+
+        size_t total = indices.total();
+
+        int j, offset_at_idx, index;
+        size_t t, idx;
+        for (size_t i = 0; i < total; i++)
+        {
+            t = i;
+            inp_offset = 0;
+            ind_offset = 0;
+            int offset_at_axis = 0;
+            for (j = ndims - 1; j >= 0; j--)
+            {
+                idx = t / ind_shape[j];
+                offset_at_idx = (int)(t - idx * ind_shape[j]);
+                ind_offset += offset_at_idx * ind_step[j];
+                inp_offset += offset_at_idx * step[j];
+                t = idx;
+                if (j == axis)
+                {
+                    offset_at_axis = offset_at_idx * step[j];
+                }
+            }
+            ind_offset /= sizeof(T);
+
+            // get index and overwrite current indices
+            const T* tmp_p_index = p_index + ind_offset;
+            index = (int)(*tmp_p_index);
+            CV_Assert(index < shape[axis] && index > -shape[axis]);
+
+            inp_offset = inp_offset - offset_at_axis + ((index + shape[axis]) % shape[axis]) * step[axis];
+            inp_offset /= sizeof(T);
+
+            const T* tmp_p_update = p_update + ind_offset;
+            T* tmp_p_out = p_out + inp_offset;
+            *tmp_p_out = rd(*tmp_p_out, *tmp_p_update);
+        }
+    }
+
+    template<typename... Args>
+    inline void typeDispatch(const int type, Args&&... args)
+    {
+        switch (type)
+        {
+            case CV_8U:
+                reductionDispatch<uint8_t>(std::forward<Args>(args)...);
+                break;
+            case CV_32S:
+                reductionDispatch<int32_t>(std::forward<Args>(args)...);
+                break;
+            case CV_32F:
+                reductionDispatch<float>(std::forward<Args>(args)...);
+                break;
+            default:
+                CV_Error(cv::Error::BadDepth, "Unsupported type.");
+        };
+    }
+
+    template<typename T, typename... Args>
+    inline void reductionDispatch(Args&&... args)
+    {
+        switch (reduction)
+        {
+            case REDUCTION::NONE:
+            {
+                auto rd = [](const T& a, const T& b) { return b; }; // a from input data, b from updates
+                forward_impl<T>(rd, std::forward<Args>(args)...);
+                break;
+            }
+            case REDUCTION::ADD:
+            {
+                auto rd = [](const T& a, const T& b) { return a + b; };
+                forward_impl<T>(rd, std::forward<Args>(args)...);
+                break;
+            }
+            case REDUCTION::MUL:
+            {
+                auto rd = [](const T& a, const T& b) { return a * b; };
+                forward_impl<T>(rd, std::forward<Args>(args)...);
+                break;
+            }
+            case REDUCTION::MAX:
+            {
+                auto rd = [](const T& a, const T& b) { return std::max(a, b); };
+                forward_impl<T>(rd, std::forward<Args>(args)...);
+                break;
+            }
+            case REDUCTION::MIN:
+            {
+                auto rd = [](const T& a, const T& b) { return std::min(a, b); };
+                forward_impl<T>(rd, std::forward<Args>(args)...);
+                break;
+            }
+            default:
+                CV_Error(Error::StsBadArg, "Unsupported reduction.");
+        };
+    }
+
+private:
+    // Attributes
+    int axis;
+};
+
+Ptr<ScatterLayer> ScatterLayer::create(const LayerParams& params)
+{
+    return makePtr<ScatterLayerImpl>(params);
+}
+
+}} // namespace cv::dnn
index 0b104d1e63717102ca2d530de44c963730a0d73f..e91534e4097f1cb5b95f49fcaea667163903320c 100644 (file)
@@ -181,6 +181,7 @@ private:
     void parseElementWise          (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseDepthToSpace         (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseRange                (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseScatter              (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseSimpleLayers         (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
 
     // Domain: com.microsoft
@@ -3106,6 +3107,58 @@ void ONNXImporter::parseRange(LayerParams& layerParams, const opencv_onnx::NodeP
     constBlobsExtraInfo.insert(std::make_pair(node_proto.output(0), TensorInfo(1)));
 }
 
+void ONNXImporter::parseScatter(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    CV_CheckEQ(node_proto.input_size(), 3, "Scatter: three inputs are required.");
+    layerParams.type = "Scatter";
+    if (node_proto.op_type() == "ScatterND")
+        layerParams.type = "ScatterND";
+
+     size_t consts = 0;
+    for (size_t i = 0; i < node_proto.input_size(); ++i)
+        if (layer_id.find(node_proto.input(i)) == layer_id.end())
+            ++consts;
+
+    if (consts == node_proto.input_size())
+    {
+        std::vector<Mat> inputs, output;
+        for (size_t i = 0; i < node_proto.input_size(); i++)
+        {
+            Mat blob = getBlob(node_proto, i);
+            if (i == 1) // indices
+                blob.convertTo(blob, CV_32F);
+            inputs.push_back(blob);
+        }
+        runLayer(layerParams, inputs, output);
+        CV_Assert(output.size() == 1);
+        addConstant(node_proto.output(0), output[0]);
+        return;
+    }
+    else if (consts > 0)
+    {
+        for (size_t i = 0; i < node_proto.input_size(); i++)
+        {
+            if (layer_id.find(node_proto.input(i)) == layer_id.end())
+            {
+                Mat blob = getBlob(node_proto, i);
+                if (i == 1) // indices, from int32/int64 to float32
+                    blob.convertTo(blob, CV_32F);
+
+                LayerParams constParams;
+                constParams.name = node_proto.input(i);
+                constParams.type = "Const";
+                constParams.blobs.push_back(blob);
+
+                opencv_onnx::NodeProto proto;
+                proto.add_output(constParams.name);
+                addLayer(constParams, proto);
+            }
+        }
+    }
+
+    addLayer(layerParams, node_proto);
+}
+
 void ONNXImporter::parseSimpleLayers(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
 {
     bool is_all_input_const = true;
@@ -3726,6 +3779,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
     dispatch["DetectionOutput"] = &ONNXImporter::parseDetectionOutput;
     dispatch["CumSum"] = &ONNXImporter::parseCumSum;
     dispatch["SpaceToDepth"] = dispatch["DepthToSpace"] = &ONNXImporter::parseDepthToSpace;
+    dispatch["ScatterElements"] = dispatch["Scatter"] = dispatch["ScatterND"] = &ONNXImporter::parseScatter;
 
     dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] =
             dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = &ONNXImporter::parseElementWise;
index e9bc0e418797efcfdfe5d46bc7229fb621bbd80f..fc766c2b81083067bee71238ccfcf5e93cbef318 100644 (file)
@@ -666,11 +666,15 @@ static const TestCase testConformanceConfig[] = {
     {"test_scatter_elements_with_axis", 3, 1},
     {"test_scatter_elements_with_duplicate_indices", 3, 1},
     {"test_scatter_elements_with_negative_indices", 3, 1},
+    {"test_scatter_elements_with_reduction_max", 3, 1},
+    {"test_scatter_elements_with_reduction_min", 3, 1},
     {"test_scatter_elements_without_axis", 3, 1},
     {"test_scatter_with_axis", 3, 1},
     {"test_scatter_without_axis", 3, 1},
     {"test_scatternd", 3, 1},
     {"test_scatternd_add", 3, 1},
+    {"test_scatternd_max", 3, 1},
+    {"test_scatternd_min", 3, 1},
     {"test_scatternd_multiply", 3, 1},
     {"test_sce_NCd1_mean_weight_negative_ii", 3, 1},
     {"test_sce_NCd1_mean_weight_negative_ii_expanded", 3, 1},
index c18ced0c5945b87ba58d320cf43ccfa633fe4379..4c05f103052492e64d3b44dddd5f22f0792fa5c1 100644 (file)
 "test_sub_uint8",
 "test_tan",  // FP16 only
 "test_upsample_nearest",
+"test_scatter_elements_with_axis",
+"test_scatter_elements_with_duplicate_indices",
+"test_scatter_elements_with_negative_indices",
+"test_scatter_elements_with_reduction_max",
+"test_scatter_elements_with_reduction_min",
+"test_scatter_elements_without_axis",
+"test_scatter_with_axis",
+"test_scatter_without_axis",
+"test_scatternd",
+"test_scatternd_add",
+"test_scatternd_max",
+"test_scatternd_min",
+"test_scatternd_multiply",
index 72900a8194ac1147aea692a8966483e5cd1d462d..4924aaf9dac01936793a3c1a05685ff2e22fc6d6 100644 (file)
 "test_sub_uint8",
 "test_tanh",
 "test_upsample_nearest",
+"test_scatter_elements_with_axis",
+"test_scatter_elements_with_duplicate_indices",
+"test_scatter_elements_with_negative_indices",
+"test_scatter_elements_with_reduction_max",
+"test_scatter_elements_with_reduction_min",
+"test_scatter_elements_without_axis",
+"test_scatter_with_axis",
+"test_scatter_without_axis",
+"test_scatternd",
+"test_scatternd_add",
+"test_scatternd_max",
+"test_scatternd_min",
+"test_scatternd_multiply",
index cad914d05ac240a0b9ca26a66fbe451acae603f0..e6a35dfab9a6c07f2bb2f641038212a88a9a83cf 100644 (file)
@@ -1588,6 +1588,10 @@ CASE(test_scatter_elements_with_duplicate_indices)
     // no filter
 CASE(test_scatter_elements_with_negative_indices)
     // no filter
+CASE(test_scatter_elements_with_reduction_max)
+    // no filter
+CASE(test_scatter_elements_with_reduction_min)
+    // no filter
 CASE(test_scatter_elements_without_axis)
     // no filter
 CASE(test_scatter_with_axis)
@@ -1598,6 +1602,10 @@ CASE(test_scatternd)
     // no filter
 CASE(test_scatternd_add)
     // no filter
+CASE(test_scatternd_max)
+    // no filter
+CASE(test_scatternd_min)
+    // no filter
 CASE(test_scatternd_multiply)
     // no filter
 CASE(test_sce_NCd1_mean_weight_negative_ii)
index 101d44cbf01f19cec5bf59f118475ba08ae7cc7f..8156686428e9f3a105a8f80b2fbae026304ec705 100644 (file)
 "test_sub_uint8",
 "test_transpose_all_permutations_0",
 "test_upsample_nearest",
+"test_scatter_elements_with_axis",
+"test_scatter_elements_with_duplicate_indices",
+"test_scatter_elements_with_negative_indices",
+"test_scatter_elements_with_reduction_max",
+"test_scatter_elements_with_reduction_min",
+"test_scatter_elements_without_axis",
+"test_scatter_with_axis",
+"test_scatter_without_axis",
+"test_scatternd",
+"test_scatternd_add",
+"test_scatternd_max",
+"test_scatternd_min",
+"test_scatternd_multiply",
index c2425d469fcb6f64397227fd50ea5cec0e48701e..9b6b2414dba1be68108f125ef6f38bf9612f749b 100644 (file)
 "test_reduce_sum_square_default_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.0183411 vs 0.004
 "test_reduce_sum_square_do_not_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02
 "test_reduce_sum_square_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02
-"test_reduce_sum_square_negative_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02
\ No newline at end of file
+"test_reduce_sum_square_negative_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02
+"test_scatter_elements_with_axis",
+"test_scatter_elements_with_duplicate_indices",
+"test_scatter_elements_with_negative_indices",
+"test_scatter_elements_with_reduction_max",
+"test_scatter_elements_with_reduction_min",
+"test_scatter_elements_without_axis",
+"test_scatter_with_axis",
+"test_scatter_without_axis",
+"test_scatternd",
+"test_scatternd_add",
+"test_scatternd_max",
+"test_scatternd_min",
+"test_scatternd_multiply",
index 9a7a21f393e7a866b46e2a24d05fd5d0f0969b12..7fe58a07fde2e4d2ef5f992d81941677681b907e 100644 (file)
@@ -1,2 +1,15 @@
 "test_averagepool_3d_default",
 "test_maxpool_3d_default",
+"test_scatter_elements_with_axis",
+"test_scatter_elements_with_duplicate_indices",
+"test_scatter_elements_with_negative_indices",
+"test_scatter_elements_with_reduction_max",
+"test_scatter_elements_with_reduction_min",
+"test_scatter_elements_without_axis",
+"test_scatter_with_axis",
+"test_scatter_without_axis",
+"test_scatternd",
+"test_scatternd_add",
+"test_scatternd_max",
+"test_scatternd_min",
+"test_scatternd_multiply",
index 1437e5475b5691270416d7554591e1817418457f..0630833b1ff3c86ef1bb62d28e878b7d7420fe94 100644 (file)
 "test_roialign_aligned_true",
 "test_scan9_sum",
 "test_scan_sum",
-"test_scatter_elements_with_axis",
-"test_scatter_elements_with_duplicate_indices",
-"test_scatter_elements_with_negative_indices",
-"test_scatter_elements_without_axis",
-"test_scatter_with_axis",
-"test_scatter_without_axis",
-"test_scatternd",
-"test_scatternd_add",
-"test_scatternd_multiply",
 "test_sce_NCd1_mean_weight_negative_ii",
 "test_sce_NCd1_mean_weight_negative_ii_expanded",
 "test_sce_NCd1_mean_weight_negative_ii_log_prob",