Fix import of embedded (nested) Torch nn.ConcatTable
author Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Sat, 31 Mar 2018 08:11:10 +0000 (11:11 +0300)
committer Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Sat, 31 Mar 2018 08:11:10 +0000 (11:11 +0300)
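
Before this change, a merge layer produced for JoinTable or CAddTable was connected to every layer accumulated in `ids`, so a ConcatTable embedded in another scope also consumed outputs that belonged to enclosing modules. The importer now keeps a stack, `numUnconnectedLayers`: each scope pushes the number of layers it produced, and the matching merge layer pops that count and connects only that many of the most recent outputs. A new `net_residual` test covers the embedded case typical of residual blocks.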
modules/dnn/src/torch/torch_importer.cpp
modules/dnn/test/test_torch_importer.cpp

index db660ff..83e4a48 100644
@@ -101,6 +101,8 @@ struct TorchImporter
     std::set<int> readedIndexes;
     std::map<int, Mat> storages;
     std::map<int, Mat> tensors;
+    // Stack of per-scope counts of not-yet-connected layers (Sequential, ConcatTable, etc.)
+    std::vector<int> numUnconnectedLayers;
 
     struct Module
     {
@@ -489,15 +491,7 @@ struct TorchImporter
                     layerParams.set("inputDimension", scalarParams.get<int>("inputDimension"));
                     layerParams.set("outputDimension", scalarParams.get<int>("outputDimension"));
                 }
-                if (nnName == "Concat")
-                {
-                    layerParams.set("dimension", scalarParams.get<int>("dimension"));
-                }
-                if (nnName == "JoinTable")
-                {
-                    layerParams.set("dimension", scalarParams.get<int>("dimension"));
-                }
-                if (nnName == "DepthConcat")
+                else if (nnName == "Concat" || nnName == "JoinTable" || nnName == "DepthConcat")
                 {
                     layerParams.set("dimension", scalarParams.get<int>("dimension"));
                 }
@@ -1096,6 +1090,7 @@ struct TorchImporter
                 {
                     newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
                 }
+                numUnconnectedLayers.push_back(module->modules.size());
                 return newId;
             }
             else if (module->thName == "JoinTable") {
@@ -1108,9 +1103,14 @@ struct TorchImporter
                 mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
                 addedModules.push_back(std::make_pair(mergeId, module));
 
-                for (int i = 0; i < ids.size(); i++)
+                // Connect only the most recently added unconnected layers; their count is popped from the stack.
+                CV_Assert(!numUnconnectedLayers.empty());
+                const int numInputs = numUnconnectedLayers.back();
+                numUnconnectedLayers.pop_back();
+                CV_Assert(numInputs <= ids.size());
+                for (int i = 0; i < numInputs; i++)
                 {
-                    net.connect(ids[i], 0, mergeId, i);
+                    net.connect(ids[ids.size() - numInputs + i], 0, mergeId, i);
                 }
 
                 return mergeId;
@@ -1124,9 +1124,14 @@ struct TorchImporter
 
                 int id = net.addLayer(name, "Eltwise", params);
 
-                for (int i = 0; i < ids.size(); i++)
+                // Connect only the most recently added unconnected layers; their count is popped from the stack.
+                CV_Assert(!numUnconnectedLayers.empty());
+                const int numInputs = numUnconnectedLayers.back();
+                numUnconnectedLayers.pop_back();
+                CV_Assert(numInputs <= ids.size());
+                for (int i = 0; i < numInputs; i++)
                 {
-                    net.connect(ids[i], 0, id, i);
+                    net.connect(ids[ids.size() - numInputs + i], 0, id, i);
                 }
 
                 addedModules.push_back(std::make_pair(id, module));
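
The loops in the two hunks above index from the tail of `ids`. A minimal standalone sketch of that bookkeeping (hypothetical layer ids, not the importer's actual code):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main()
    {
        // Outputs gathered so far while walking the module tree (hypothetical ids).
        std::vector<int> ids = {10, 11, 12, 13};
        // Per-scope counts of unconnected layers, as in the patch.
        std::vector<int> numUnconnectedLayers;

        // A nested ConcatTable with two branches pushes its branch count...
        numUnconnectedLayers.push_back(2);

        // ...and the matching JoinTable/CAddTable pops it, consuming only the
        // two newest outputs (12 and 13) instead of all four.
        assert(!numUnconnectedLayers.empty());
        const int numInputs = numUnconnectedLayers.back();
        numUnconnectedLayers.pop_back();
        assert(numInputs <= (int)ids.size());
        for (int i = 0; i < numInputs; i++)
            printf("connect layer %d -> merge input %d\n",
                   ids[ids.size() - numInputs + i], i);
        return 0;
    }
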
index 621d3ef..2edb79a 100644
@@ -320,4 +320,9 @@ TEST(Torch_Importer, DISABLED_run_paralel)
     runTorchNet("net_parallel", DNN_TARGET_OPENCL, "l5_torchMerge");
 }
 
+TEST(Torch_Importer, net_residual)
+{
+    runTorchNet("net_residual", DNN_TARGET_CPU, "", false, true);
+}
+
 }
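
The new test exercises a residual topology, i.e. a ConcatTable whose branches are summed by a CAddTable (imported as Eltwise). A hedged usage sketch of loading such a model through the public API; the file name net_residual.t7 is an assumption here, since the test resolves its data path through the test harness:

    #include <opencv2/core.hpp>
    #include <opencv2/dnn.hpp>

    int main()
    {
        // Hypothetical file name; the actual test locates its data itself.
        cv::dnn::Net net = cv::dnn::readNetFromTorch("net_residual.t7");

        // A dummy 3-channel input; a residual model forwards it through the
        // ConcatTable branches and sums them in the CAddTable/Eltwise layer.
        cv::Mat blob = cv::dnn::blobFromImage(cv::Mat::zeros(8, 8, CV_32FC3));
        net.setInput(blob);
        cv::Mat out = net.forward();
        return 0;
    }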