From 598039c0edf4d4424e67c3ee82ed206241c2c228 Mon Sep 17 00:00:00 2001
From: Dmitry Kurtaev
Date: Sat, 31 Mar 2018 11:11:10 +0300
Subject: [PATCH] Fix embedded Torch's nn.ConcatTable

---
 modules/dnn/src/torch/torch_importer.cpp | 31 ++++++++++++++++++-------------
 modules/dnn/test/test_torch_importer.cpp |  5 +++++
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/modules/dnn/src/torch/torch_importer.cpp b/modules/dnn/src/torch/torch_importer.cpp
index db660ff..83e4a48 100644
--- a/modules/dnn/src/torch/torch_importer.cpp
+++ b/modules/dnn/src/torch/torch_importer.cpp
@@ -101,6 +101,8 @@ struct TorchImporter
     std::set<int> readedIndexes;
     std::map<int, Mat> storages;
     std::map<int, Mat> tensors;
+    // Stack with numbers of unconnected layers per scope (Sequential, ConcatTable etc.)
+    std::vector<int> numUnconnectedLayers;

     struct Module
     {
@@ -489,15 +491,7 @@ struct TorchImporter
             layerParams.set("inputDimension", scalarParams.get<int>("inputDimension"));
             layerParams.set("outputDimension", scalarParams.get<int>("outputDimension"));
         }
-        if (nnName == "Concat")
-        {
-            layerParams.set("dimension", scalarParams.get<int>("dimension"));
-        }
-        if (nnName == "JoinTable")
-        {
-            layerParams.set("dimension", scalarParams.get<int>("dimension"));
-        }
-        if (nnName == "DepthConcat")
+        else if (nnName == "Concat" || nnName == "JoinTable" || nnName == "DepthConcat")
         {
             layerParams.set("dimension", scalarParams.get<int>("dimension"));
         }
@@ -1096,6 +1090,7 @@ struct TorchImporter
             {
                 newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
             }
+            numUnconnectedLayers.push_back(module->modules.size());
             return newId;
         }
         else if (module->thName == "JoinTable") {
@@ -1108,9 +1103,14 @@ struct TorchImporter
             mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
             addedModules.push_back(std::make_pair(mergeId, module));

-            for (int i = 0; i < ids.size(); i++)
+            // Connect to the last number of unconnected layers.
+            CV_Assert(!numUnconnectedLayers.empty());
+            const int numInputs = numUnconnectedLayers.back();
+            numUnconnectedLayers.pop_back();
+            CV_Assert(numInputs <= ids.size());
+            for (int i = 0; i < numInputs; i++)
             {
-                net.connect(ids[i], 0, mergeId, i);
+                net.connect(ids[ids.size() - numInputs + i], 0, mergeId, i);
             }

             return mergeId;
@@ -1124,9 +1124,14 @@ struct TorchImporter

             int id = net.addLayer(name, "Eltwise", params);

-            for (int i = 0; i < ids.size(); i++)
+            // Connect to the last number of unconnected layers.
+            CV_Assert(!numUnconnectedLayers.empty());
+            const int numInputs = numUnconnectedLayers.back();
+            numUnconnectedLayers.pop_back();
+            CV_Assert(numInputs <= ids.size());
+            for (int i = 0; i < numInputs; i++)
             {
-                net.connect(ids[i], 0, id, i);
+                net.connect(ids[ids.size() - numInputs + i], 0, id, i);
             }

             addedModules.push_back(std::make_pair(id, module));
diff --git a/modules/dnn/test/test_torch_importer.cpp b/modules/dnn/test/test_torch_importer.cpp
index 621d3ef..2edb79a 100644
--- a/modules/dnn/test/test_torch_importer.cpp
+++ b/modules/dnn/test/test_torch_importer.cpp
@@ -320,4 +320,9 @@ TEST(Torch_Importer, DISABLED_run_paralel)
     runTorchNet("net_parallel", DNN_TARGET_OPENCL, "l5_torchMerge");
 }

+TEST(Torch_Importer, net_residual)
+{
+    runTorchNet("net_residual", DNN_TARGET_CPU, "", false, true);
+}
+
 }
-- 
2.7.4
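
For context, below is a minimal standalone sketch of the scoping idea the
patch implements (it is not part of the patch; the `unconnected` vector and
the hard-coded layer ids are illustrative stand-ins for what
net.getUnconnectedOutLayers() returns inside the importer). Each table
container pushes its branch count onto numUnconnectedLayers, and the matching
merge layer pops that count and consumes only the *last* N unconnected
outputs, leaving earlier ones, such as a residual identity branch, for an
outer scope. The pre-patch code connected every unconnected output, which
broke nested nn.ConcatTable constructs like residual blocks.

    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main()
    {
        // Unconsumed layer outputs, oldest first; stands in for
        // net.getUnconnectedOutLayers() from the importer.
        std::vector<int> unconnected;
        unconnected.push_back(0);              // layer 0: residual identity branch

        std::vector<int> numUnconnectedLayers; // stack of branch counts per scope

        // A ConcatTable with two branches finishes: it produced layers 1 and 2.
        unconnected.push_back(1);
        unconnected.push_back(2);
        numUnconnectedLayers.push_back(2);     // the next merge must take 2 inputs

        // The matching CAddTable: consume only the last 2 outputs (layers 1, 2).
        // Connecting all three, as before the patch, would swallow layer 0 too.
        assert(!numUnconnectedLayers.empty());
        const int numInputs = numUnconnectedLayers.back();
        numUnconnectedLayers.pop_back();
        assert(numInputs <= (int)unconnected.size());
        for (int i = 0; i < numInputs; i++)
            printf("connect layer %d -> merge input %d\n",
                   unconnected[unconnected.size() - numInputs + i], i);
        return 0;
    }

Running the sketch prints connections for layers 1 and 2 only, mirroring how
the patched importer wires a merge layer to its own ConcatTable branches
while layer 0 stays available for the enclosing Sequential.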