nn.BatchNormalization and nn.Dropout layers from Torch
author Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 4 Dec 2017 09:57:21 +0000 (12:57 +0300)
committer Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 4 Dec 2017 09:57:21 +0000 (12:57 +0300)
modules/dnn/src/layers/batch_norm_layer.cpp
modules/dnn/src/torch/torch_importer.cpp
modules/dnn/test/test_torch_importer.cpp

index dc4a4b3..8d9c639 100644 (file)
@@ -119,8 +119,9 @@ public:
         CV_Assert(inputs.size() == 1);
 
         Mat &inpBlob = *inputs[0];
-        int rows = inpBlob.size[2];
-        int cols = inpBlob.size[3];
+        CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
+        int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;
+        int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;
 
         for (size_t ii = 0; ii < outputs.size(); ii++)
         {
index 8438eb2..df5f1db 100644 (file)
@@ -617,7 +617,8 @@ struct TorchImporter : public ::cv::dnn::Importer
                 curModule->modules.push_back(cv::Ptr<Module>(new Module(nnName, "Sigmoid")));
                 readObject();
             }
-            else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization")
+            else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization" ||
+                     nnName == "BatchNormalization")
             {
                 newModule->apiType = "BatchNorm";
                 readTorchTable(scalarParams, tensorParams);
@@ -700,17 +701,24 @@ struct TorchImporter : public ::cv::dnn::Importer
 
                 curModule->modules.push_back(newModule);
             }
-            else if (nnName == "SpatialDropout")
+            else if (nnName == "SpatialDropout" || nnName == "Dropout")
             {
                 readTorchTable(scalarParams, tensorParams);
                 CV_Assert(scalarParams.has("p"));
 
-                float scale = 1 -  scalarParams.get<double>("p");
+                if (scalarParams.has("v2") && scalarParams.get<bool>("v2"))
+                {
+                    newModule->apiType = "Identity";
+                }
+                else
+                {
+                    float scale = 1 -  scalarParams.get<double>("p");
 
-                CV_Assert(scale > 0);
+                    CV_Assert(scale > 0);
 
-                newModule->apiType = "Power";
-                layerParams.set("scale", scale);
+                    newModule->apiType = "Power";
+                    layerParams.set("scale", scale);
+                }
                 curModule->modules.push_back(newModule);
             }
             // TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style
index 691a028..c90ddeb 100644 (file)
@@ -234,6 +234,11 @@ TEST(Torch_Importer, net_padding)
     runTorchNet("net_spatial_reflection_padding", DNN_TARGET_CPU, "", false, true);
 }
 
+TEST(Torch_Importer, net_non_spatial)
+{
+    runTorchNet("net_non_spatial", DNN_TARGET_CPU, "", false, true);
+}
+
 TEST(Torch_Importer, ENet_accuracy)
 {
     Net net;