Merge pull request #16735 from l-bat:flatten_const_onnx
author Liubov Batanina <piccione-mail@yandex.ru>
Sat, 14 Mar 2020 11:05:49 +0000 (14:05 +0300)
committer GitHub <noreply@github.com>
Sat, 14 Mar 2020 11:05:49 +0000 (11:05 +0000)
* Supported Flatten for constant nodes

* Added default axis

* Refactoring

* Refactoring

* Added cast layer

* Fix comments

* Add Cast for layers

modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 0ca9095..e08f7b0 100644 (file)
@@ -431,9 +431,20 @@ void ONNXImporter::populateNet(Net dstNet)
         {
             bool isSub = layer_type == "Sub";
             CV_CheckEQ(node_proto.input_size(), 2, "");
-            if (layer_id.find(node_proto.input(1)) == layer_id.end())
+            bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end();
+            bool is_const_1 = layer_id.find(node_proto.input(1)) == layer_id.end();
+            if (is_const_0 && is_const_1)
             {
-                Mat blob = getBlob(node_proto, constBlobs, 1);
+                Mat blob_0 = getBlob(node_proto, constBlobs, 0);
+                Mat blob_1 = getBlob(node_proto, constBlobs, 1);
+                CV_Assert(blob_0.size == blob_1.size);
+                Mat output = isSub ? (blob_0 - blob_1) : (blob_0 + blob_1);
+                constBlobs.insert(std::make_pair(layerParams.name, output));
+                continue;
+            }
+            else if (is_const_0 || is_const_1)
+            {
+                Mat blob = getBlob(node_proto, constBlobs, is_const_0 ? 0 : 1);
                 blob = blob.reshape(1, 1);
                 if (blob.total() == 1) {
                     layerParams.type = "Power";
@@ -808,6 +819,21 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.set("end_axis", axis);
             layerParams.type = "Flatten";
         }
+        else if (layer_type == "Flatten")
+        {
+            CV_CheckEQ(node_proto.input_size(), 1, "");
+            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
+            {
+                Mat input = getBlob(node_proto, constBlobs, 0);
+                int axis = clamp(layerParams.get<int>("axis", 1), input.dims);
+
+                std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
+                out_size.push_back(input.total(axis));
+                Mat output = input.reshape(1, out_size);
+                constBlobs.insert(std::make_pair(layerParams.name, output));
+                continue;
+            }
+        }
         else if (layer_type == "Unsqueeze")
         {
             CV_Assert(node_proto.input_size() == 1);
@@ -896,6 +922,31 @@ void ONNXImporter::populateNet(Net dstNet)
             constBlobs.insert(std::make_pair(layerParams.name, shapeMat));
             continue;
         }
+        else if (layer_type == "Cast")
+        {
+            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
+            {
+                Mat blob = getBlob(node_proto, constBlobs, 0);
+                int type;
+                switch (layerParams.get<int>("to"))
+                {
+                    case opencv_onnx::TensorProto_DataType_FLOAT:   type = CV_32F; break;
+                    case opencv_onnx::TensorProto_DataType_UINT8:   type = CV_8U; break;
+                    case opencv_onnx::TensorProto_DataType_UINT16:  type = CV_16U; break;
+                    case opencv_onnx::TensorProto_DataType_FLOAT16: type = CV_16S; break;
+                    case opencv_onnx::TensorProto_DataType_INT8:
+                    case opencv_onnx::TensorProto_DataType_INT16:
+                    case opencv_onnx::TensorProto_DataType_INT32:
+                    case opencv_onnx::TensorProto_DataType_INT64:   type = CV_32S; break;
+                    default: type = blob.type();
+                }
+                blob.convertTo(blob, type);
+                constBlobs.insert(std::make_pair(layerParams.name, blob));
+                continue;
+            }
+            else
+                layerParams.type = "Identity";
+        }
         else if (layer_type == "Gather")
         {
             CV_Assert(node_proto.input_size() == 2);
index f284eed..769862d 100644 (file)
@@ -187,6 +187,11 @@ TEST_P(Test_ONNX_layers, MaxPooling_Sigmoid)
     testONNXModels("maxpooling_sigmoid");
 }
 
+TEST_P(Test_ONNX_layers, Cast)
+{
+    testONNXModels("cast");
+}
+
 TEST_P(Test_ONNX_layers, Concatenation)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
@@ -377,6 +382,7 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
     testONNXModels("dynamic_reshape");
     testONNXModels("dynamic_reshape_opset_11");
     testONNXModels("flatten_by_prod");
+    testONNXModels("flatten_const");
 }
 
 TEST_P(Test_ONNX_layers, Reshape)