From bbbec300a6f262faa36ea735b7cbf686f3b1f339 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Mon, 4 Dec 2017 12:57:21 +0300 Subject: [PATCH] nn.BatchNormalization and nn.Dropout layers from Torch --- modules/dnn/src/layers/batch_norm_layer.cpp | 5 +++-- modules/dnn/src/torch/torch_importer.cpp | 20 ++++++++++++++------ modules/dnn/test/test_torch_importer.cpp | 5 +++++ 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/modules/dnn/src/layers/batch_norm_layer.cpp b/modules/dnn/src/layers/batch_norm_layer.cpp index dc4a4b3..8d9c639 100644 --- a/modules/dnn/src/layers/batch_norm_layer.cpp +++ b/modules/dnn/src/layers/batch_norm_layer.cpp @@ -119,8 +119,9 @@ public: CV_Assert(inputs.size() == 1); Mat &inpBlob = *inputs[0]; - int rows = inpBlob.size[2]; - int cols = inpBlob.size[3]; + CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4); + int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1; + int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1; for (size_t ii = 0; ii < outputs.size(); ii++) { diff --git a/modules/dnn/src/torch/torch_importer.cpp b/modules/dnn/src/torch/torch_importer.cpp index 8438eb2..df5f1db 100644 --- a/modules/dnn/src/torch/torch_importer.cpp +++ b/modules/dnn/src/torch/torch_importer.cpp @@ -617,7 +617,8 @@ struct TorchImporter : public ::cv::dnn::Importer curModule->modules.push_back(cv::Ptr(new Module(nnName, "Sigmoid"))); readObject(); } - else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization") + else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization" || + nnName == "BatchNormalization") { newModule->apiType = "BatchNorm"; readTorchTable(scalarParams, tensorParams); @@ -700,17 +701,24 @@ struct TorchImporter : public ::cv::dnn::Importer curModule->modules.push_back(newModule); } - else if (nnName == "SpatialDropout") + else if (nnName == "SpatialDropout" || nnName == "Dropout") { readTorchTable(scalarParams, tensorParams); CV_Assert(scalarParams.has("p")); - float scale = 1 - scalarParams.get("p"); + if (scalarParams.has("v2") && scalarParams.get("v2")) + { + newModule->apiType = "Identity"; + } + else + { + float scale = 1 - scalarParams.get("p"); - CV_Assert(scale > 0); + CV_Assert(scale > 0); - newModule->apiType = "Power"; - layerParams.set("scale", scale); + newModule->apiType = "Power"; + layerParams.set("scale", scale); + } curModule->modules.push_back(newModule); } // TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style diff --git a/modules/dnn/test/test_torch_importer.cpp b/modules/dnn/test/test_torch_importer.cpp index 691a028..c90ddeb 100644 --- a/modules/dnn/test/test_torch_importer.cpp +++ b/modules/dnn/test/test_torch_importer.cpp @@ -234,6 +234,11 @@ TEST(Torch_Importer, net_padding) runTorchNet("net_spatial_reflection_padding", DNN_TARGET_CPU, "", false, true); } +TEST(Torch_Importer, net_non_spatial) +{ + runTorchNet("net_non_spatial", DNN_TARGET_CPU, "", false, true); +} + TEST(Torch_Importer, ENet_accuracy) { Net net; -- 2.7.4