add ArgMax and ArgMin layers
author Smirnov Egor <s.e.a.98@yandex.ru>
Mon, 6 Dec 2021 16:33:59 +0000 (19:33 +0300)
committer Smirnov Egor <s.e.a.98@yandex.ru>
Mon, 6 Dec 2021 17:49:54 +0000 (20:49 +0300)
modules/dnn/include/opencv2/dnn/all_layers.hpp
modules/dnn/src/init.cpp
modules/dnn/src/layers/arg_layer.cpp [new file with mode: 0644]
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 0bec9e3..4593527 100644 (file)
@@ -284,6 +284,16 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<LRNLayer> create(const LayerParams& params);
     };
 
+
+    /** @brief ArgMax/ArgMin layer: yields the index of the maximum/minimum element along one axis.
+     * @note Indices are returned as floats, so only the range [-2^24; 2^24] is exactly representable.
+     */
+    class CV_EXPORTS ArgLayer : public Layer
+    {
+    public:
+        static Ptr<ArgLayer> create(const LayerParams& params);  //!< Builds the layer from its parameters.
+    };
+
     class CV_EXPORTS PoolingLayer : public Layer
     {
     public:
index affaa1a..443d1ea 100644 (file)
@@ -123,6 +123,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Identity,       BlankLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Silence,        BlankLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Const,          ConstLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(Arg,            ArgLayer);
 
     CV_DNN_REGISTER_LAYER_CLASS(Crop,           CropLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Eltwise,        EltwiseLayer);
diff --git a/modules/dnn/src/layers/arg_layer.cpp b/modules/dnn/src/layers/arg_layer.cpp
new file mode 100644 (file)
index 0000000..94af458
--- /dev/null
@@ -0,0 +1,120 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+
+namespace cv { namespace dnn {
+
+class ArgLayerImpl CV_FINAL : public ArgLayer  // Computes argmin/argmax indices of its single input along one axis.
+{
+public:
+    enum class ArgOp  // Selects which reduction forward() runs.
+    {
+        MIN = 0,
+        MAX = 1,
+    };
+
+    ArgLayerImpl(const LayerParams& params)  // Reads "axis", "keepdims", "select_last_index" and "op" from params.
+    {
+        setParamsFrom(params);
+
+        axis = params.get<int>("axis", 0);
+        keepdims = (params.get<int>("keepdims", 1) == 1);
+        select_last_index = (params.get<int>("select_last_index", 0) == 1);
+
+        const std::string& argOp = params.get<std::string>("op");  // no default value is supplied for "op"
+
+        if (argOp == "max")
+        {
+            op = ArgOp::MAX;
+        }
+        else if (argOp == "min")
+        {
+            op = ArgOp::MIN;
+        }
+        else
+        {
+            CV_Error(Error::StsBadArg, "Unsupported operation");  // any "op" other than "max"/"min" is rejected
+        }
+    }
+
+    virtual bool supportBackend(int backendId) CV_OVERRIDE
+    {
+        return backendId == DNN_BACKEND_OPENCV && preferableTarget == DNN_TARGET_CPU;  // OpenCV backend on the CPU target only
+    }
+
+    void handleKeepDims(MatShape& shape, const int axis_) const  // Edits `shape` in place; callers pass an already normalized axis_.
+    {
+        if (keepdims)
+        {
+            shape[axis_] = 1;  // reduced axis stays, with extent 1
+        }
+        else
+        {
+            shape.erase(shape.begin() + axis_);  // reduced axis is dropped entirely
+        }
+    }
+
+    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                                 const int requiredOutputs,
+                                 std::vector<MatShape> &outputs,
+                                 std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        MatShape inpShape = inputs[0];  // working copy; turned into the output shape below
+
+        const int axis_ = normalize_axis(axis, inpShape);
+        handleKeepDims(inpShape, axis_);
+        outputs.assign(1, inpShape);  // exactly one output shape is reported
+
+        return false;  // NOTE(review): presumably means in-place execution is not possible - confirm against Layer::getMemoryShapes
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        CV_Assert_N(inputs.size() == 1, outputs.size() == 1);  // exactly one input and one output blob
+        std::vector<int> outShape = shape(outputs[0]);  // shape of the output blob handed to this call
+        Mat output(outShape, CV_32SC1);  // 32-bit integer scratch buffer for the indices
+
+        switch (op)
+        {
+        case ArgOp::MIN:
+            cv::reduceArgMin(inputs[0], output, axis, select_last_index);
+            break;
+        case ArgOp::MAX:
+            cv::reduceArgMax(inputs[0], output, axis, select_last_index);
+            break;
+        default:
+            CV_Error(Error::StsBadArg, "Unsupported operation.");
+        }
+
+        output = output.reshape(1, outShape);  // force the result back to the output blob's shape
+        output.convertTo(outputs[0], CV_32FC1);  // indices are delivered as 32-bit floats
+    }
+
+private:
+    // Axis along which the arg indices are computed. Accepted range is [-r, r-1] where r = rank(data).
+    int axis;
+    // Whether the reduced dimension is kept (with size 1) or removed from the output shape.
+    bool keepdims;
+    // Whether to select the last index of Max/Min instead of the first one.
+    bool select_last_index;
+    // Operation to be performed (MIN or MAX).
+    ArgOp op;
+};
+
+Ptr<ArgLayer> ArgLayer::create(const LayerParams& params)  // Factory entry point for the "Arg" layer.
+{
+    return Ptr<ArgLayer>(new ArgLayerImpl(params));  // the concrete ArgLayerImpl is handed out as Ptr<ArgLayer>
+}
+
+}}  // namespace cv::dnn
index b0d7d4b..85c4479 100644 (file)
@@ -100,6 +100,7 @@ private:
     const DispatchMap dispatch;
     static const DispatchMap buildDispatchMap();
 
+    void parseArg                  (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseMaxPool              (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseAveragePool          (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseReduce               (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
@@ -768,6 +769,14 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
     }
 }
 
+void ONNXImporter::parseArg(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)  // Handles ONNX ArgMax and ArgMin nodes.
+{
+    const std::string& layer_type = node_proto.op_type();
+    layerParams.type = "Arg";  // both ONNX ops map onto the single "Arg" layer type
+    layerParams.set("op", layer_type == "ArgMax" ? "max" : "min");  // any op type other than "ArgMax" becomes "min"
+    addLayer(layerParams, node_proto);
+}
+
 void setCeilMode(LayerParams& layerParams)
 {
     // auto_pad attribute is deprecated and uses ceil
@@ -2986,6 +2995,7 @@ const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
 {
     DispatchMap dispatch;
 
+    dispatch["ArgMax"] = dispatch["ArgMin"] = &ONNXImporter::parseArg;
     dispatch["MaxPool"] = &ONNXImporter::parseMaxPool;
     dispatch["AveragePool"] = &ONNXImporter::parseAveragePool;
     dispatch["GlobalAveragePool"] = dispatch["GlobalMaxPool"] = dispatch["ReduceMean"] = dispatch["ReduceSum"] =
index f2deaf1..bfea855 100644 (file)
@@ -355,6 +355,15 @@ TEST_P(Test_ONNX_layers, Min)
     testONNXModels("min", npy, 0, 0, false, true, 2);
 }
 
+TEST_P(Test_ONNX_layers, ArgLayer)  // Runs the "argmax" and "argmin" ONNX test models.
+{
+    if (backend != DNN_BACKEND_OPENCV || target != DNN_TARGET_CPU)  // every other backend/target combination is skipped
+        throw SkipTestException("Only CPU is supported");  // FIXIT use tags
+
+    testONNXModels("argmax");
+    testONNXModels("argmin");
+}
+
 TEST_P(Test_ONNX_layers, Scale)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)