Import Upsample and Unsqueeze from ONNX
author     Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
           Thu, 21 Feb 2019 16:48:46 +0000 (19:48 +0300)
committer  Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
           Thu, 21 Feb 2019 17:17:28 +0000 (20:17 +0300)
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index 218775b..98c3563 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -392,10 +392,10 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.set("ceil_mode", isCeilMode(layerParams));
             layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
         }
-        else if (layer_type == "GlobalAveragePool")
+        else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool")
         {
             layerParams.type = "Pooling";
-            layerParams.set("pool", "AVE");
+            layerParams.set("pool", layer_type == "GlobalAveragePool" ? "AVE" : "MAX");
             layerParams.set("global_pooling", true);
         }
         else if (layer_type == "Add" || layer_type == "Sum")
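
For reference, a minimal standalone sketch (not part of the patch) of the Pooling layer this branch now produces for an ONNX GlobalMaxPool node; the layer name and the 1x2x4x4 input shape are only illustrative:

    #include <opencv2/dnn.hpp>
    using namespace cv;
    using namespace cv::dnn;

    int main()
    {
        // Same parameters the importer now sets for GlobalMaxPool.
        LayerParams lp;
        lp.name = "global_max_pool";
        lp.type = "Pooling";
        lp.set("pool", "MAX");
        lp.set("global_pooling", true);

        Net net;
        net.addLayerToPrev(lp.name, lp.type, lp);

        int sz[] = {1, 2, 4, 4};              // NCHW
        Mat input(4, sz, CV_32F, Scalar(1));
        net.setInput(input);
        Mat out = net.forward();              // 1x2x1x1: one maximum per channel
        return 0;
    }
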
@@ -448,6 +448,11 @@ void ONNXImporter::populateNet(Net dstNet)
                 layerParams.set("bias_term", false);
             }
         }
+        else if (layer_type == "Neg")
+        {
+            layerParams.type = "Power";
+            layerParams.set("scale", -1);
+        }
         else if (layer_type == "Constant")
         {
             CV_Assert(node_proto.input_size() == 0);
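
OpenCV's Power layer computes f(x) = (shift + scale*x)^power with shift = 0 and power = 1 by default, so scale = -1 reproduces Neg exactly. A minimal standalone sketch of that mapping (not part of the patch; layer name and input are illustrative):

    #include <opencv2/dnn.hpp>
    using namespace cv;
    using namespace cv::dnn;

    int main()
    {
        // Same parameters the importer now sets for an ONNX Neg node.
        LayerParams lp;
        lp.name = "neg";
        lp.type = "Power";
        lp.set("scale", -1);   // f(x) = (0 + (-1) * x)^1 = -x

        Net net;
        net.addLayerToPrev(lp.name, lp.type, lp);

        int sz[] = {1, 1, 2, 2};
        Mat input(4, sz, CV_32F, Scalar(3));
        net.setInput(input);
        Mat out = net.forward();   // every element becomes -3
        return 0;
    }
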
@@ -595,21 +600,35 @@ void ONNXImporter::populateNet(Net dstNet)
         else if (layer_type == "Unsqueeze")
         {
             CV_Assert(node_proto.input_size() == 1);
-            Mat input = getBlob(node_proto, constBlobs, 0);
-
             DictValue axes = layerParams.get("axes");
-            std::vector<int> dims;
-            for (int j = 0; j < input.dims; j++) {
-                dims.push_back(input.size[j]);
-            }
-            CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
-            for (int j = 0; j < axes.size(); j++) {
-                dims.insert(dims.begin() + axes.getIntValue(j), 1);
+            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
+            {
+                // Constant input.
+                Mat input = getBlob(node_proto, constBlobs, 0);
+
+                std::vector<int> dims;
+                for (int j = 0; j < input.dims; j++) {
+                    dims.push_back(input.size[j]);
+                }
+                CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
+                for (int j = 0; j < axes.size(); j++) {
+                    dims.insert(dims.begin() + axes.getIntValue(j), 1);
+                }
+
+                Mat out = input.reshape(0, dims);
+                constBlobs.insert(std::make_pair(layerParams.name, out));
+                continue;
             }
 
-            Mat out = input.reshape(0, dims);
-            constBlobs.insert(std::make_pair(layerParams.name, out));
-            continue;
+            // Variable input.
+            if (axes.size() != 1)
+                CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");
+
+            int dims[] = {1, -1};
+            layerParams.type = "Reshape";
+            layerParams.set("axis", axes.getIntValue(0));
+            layerParams.set("num_axes", 1);
+            layerParams.set("dim", DictValue::arrayInt(&dims[0], 2));
         }
         else if (layer_type == "Reshape")
         {
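
For a variable input the node now becomes OpenCV's Reshape layer: axis = k with num_axes = 1 selects only the k-th input dimension, and dim = [1, -1] replaces it with [1, d_k], i.e. inserts a new unit axis in front of it. A minimal standalone sketch of that configuration for axes = [1] (not part of the patch; layer name and input shape are illustrative):

    #include <opencv2/dnn.hpp>
    using namespace cv;
    using namespace cv::dnn;

    int main()
    {
        // Same parameters the importer now sets for ONNX Unsqueeze with axes = [1].
        int dim[] = {1, -1};
        LayerParams lp;
        lp.name = "unsqueeze";
        lp.type = "Reshape";
        lp.set("axis", 1);
        lp.set("num_axes", 1);
        lp.set("dim", DictValue::arrayInt(&dim[0], 2));

        Net net;
        net.addLayerToPrev(lp.name, lp.type, lp);

        int sz[] = {2, 3, 4};
        Mat input(3, sz, CV_32F, Scalar(0));
        net.setInput(input);
        Mat out = net.forward();   // shape 2x1x3x4: unit axis inserted at position 1
        return 0;
    }
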
@@ -707,6 +726,25 @@ void ONNXImporter::populateNet(Net dstNet)
                 continue;
             }
         }
+        else if (layer_type == "Upsample")
+        {
+            layerParams.type = "Resize";
+            if (layerParams.has("scales"))
+            {
+                // Pytorch layer
+                DictValue scales = layerParams.get("scales");
+                CV_Assert(scales.size() == 4);
+                layerParams.set("zoom_factor_y", scales.getIntValue(2));
+                layerParams.set("zoom_factor_x", scales.getIntValue(3));
+            }
+            else
+            {
+                // Caffe2 layer
+                replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
+                replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
+            }
+            replaceLayerParam(layerParams, "mode", "interpolation");
+        }
         else
         {
             for (int j = 0; j < node_proto.input_size(); j++) {
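
The scales attribute (PyTorch export) or height_scale/width_scale (Caffe2 export) become the Resize layer's zoom factors, and mode maps to interpolation. A minimal standalone sketch of the resulting layer for a 2x nearest-neighbor upsample (not part of the patch; layer name and input shape are illustrative):

    #include <opencv2/dnn.hpp>
    using namespace cv;
    using namespace cv::dnn;

    int main()
    {
        // What the importer now emits for ONNX Upsample with
        // scales = [1, 1, 2, 2] and mode = "nearest".
        LayerParams lp;
        lp.name = "upsample";
        lp.type = "Resize";
        lp.set("zoom_factor_y", 2);
        lp.set("zoom_factor_x", 2);
        lp.set("interpolation", "nearest");

        Net net;
        net.addLayerToPrev(lp.name, lp.type, lp);

        int sz[] = {1, 3, 4, 4};              // NCHW
        Mat input(4, sz, CV_32F, Scalar(1));
        net.setInput(input);
        Mat out = net.forward();              // 1x3x8x8
        return 0;
    }
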
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index 217ef34..72112d2 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -140,6 +140,11 @@ TEST_P(Test_ONNX_layers, Padding)
     testONNXModels("padding");
 }
 
+TEST_P(Test_ONNX_layers, Resize)
+{
+    testONNXModels("resize_nearest");
+}
+
 TEST_P(Test_ONNX_layers, MultyInputs)
 {
     const String model =  _tf("models/multy_inputs.onnx");
@@ -169,6 +174,11 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
     testONNXModels("dynamic_reshape");
 }
 
+TEST_P(Test_ONNX_layers, Reshape)
+{
+    testONNXModels("unsqueeze");
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
 
 class Test_ONNX_nets : public Test_ONNX_layers {};