From 0ce7c33bc865948aae471e4d4f953018dc773b3b Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Wed, 16 Aug 2017 21:11:59 +0300 Subject: [PATCH] Torch's Concat and ConcatTable doesn't use Split layer --- modules/dnn/src/layers/split_layer.cpp | 5 ++--- modules/dnn/src/torch/torch_importer.cpp | 23 +++++++---------------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/modules/dnn/src/layers/split_layer.cpp b/modules/dnn/src/layers/split_layer.cpp index bae6c87..435d4bd 100644 --- a/modules/dnn/src/layers/split_layer.cpp +++ b/modules/dnn/src/layers/split_layer.cpp @@ -75,7 +75,7 @@ public: Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs), outputs, internals); - return true; + return false; } void forward(std::vector &inputs, std::vector &outputs, std::vector &internals) @@ -86,8 +86,7 @@ public: for (size_t i = 0; i < outputs.size(); i++) { CV_Assert(inputs[0]->total() == outputs[i].total()); - if (outputs[i].data != inputs[0]->data) - inputs[0]->copyTo(outputs[i]); + inputs[0]->copyTo(outputs[i]); } } }; diff --git a/modules/dnn/src/torch/torch_importer.cpp b/modules/dnn/src/torch/torch_importer.cpp index 44fcd8c..ecaf054 100644 --- a/modules/dnn/src/torch/torch_importer.cpp +++ b/modules/dnn/src/torch/torch_importer.cpp @@ -827,20 +827,18 @@ struct TorchImporter : public ::cv::dnn::Importer } else if (module->thName == "Concat") { - int newId, splitId, mergeId; - LayerParams mergeParams, splitParams; + int newId, mergeId; + LayerParams mergeParams; mergeParams.set("axis", module->params.get("dimension") - 1); - splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams); - net.connect(prevLayerId, prevOutNum, splitId, 0); - std::vector branchIds; for (int i = 0; i < (int)module->modules.size(); i++) { - newId = fill(module->modules[i], addedModules, splitId, i); + newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum); branchIds.push_back(newId); } + moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384. mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams); for (int i = 0; i < branchIds.size(); i++) @@ -884,19 +882,12 @@ struct TorchImporter : public ::cv::dnn::Importer return mergeId; } else if (module->thName == "ConcatTable") { - int newId = -1, splitId; - LayerParams splitParams; - - splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams); - net.connect(prevLayerId, prevOutNum, splitId, 0); - - addedModules.push_back(std::make_pair(splitId, module)); - + int newId = -1; + moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384. for (int i = 0; i < (int)module->modules.size(); i++) { - newId = fill(module->modules[i], addedModules, splitId, i); + newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum); } - return newId; } else if (module->thName == "JoinTable") { -- 2.7.4