Multiple inputs for TensorFlow models
author Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 26 Jun 2018 10:32:28 +0000 (13:32 +0300)
committer Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 26 Jun 2018 11:03:59 +0000 (14:03 +0300)
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index 9140368..25b2a73 100644 (file)
@@ -375,6 +375,8 @@ private:
     // and may be used to build the network using binary format only as a weights storage.
     // This approach is similar to Caffe's `.prorotxt` and `.caffemodel`.
     tensorflow::GraphDef netTxt;
+
+    std::vector<String> netInputsNames;
 };
 
 TFImporter::TFImporter(const char *model, const char *config)
@@ -442,7 +444,14 @@ void TFImporter::connect(const std::map<String, int>& layers_name_id_map, Net& n
     std::map<String, int>::const_iterator it = layers_name_id_map.find(outPin.name);
     if (it == layers_name_id_map.end())
         CV_Error(Error::StsError, "Input layer not found: " + outPin.name);
-    network.connect(it->second, outPin.blobIndex, input_layer_id, input_blob_id);
+
+    std::vector<String>::iterator inpNameIt = std::find(netInputsNames.begin(), netInputsNames.end(), outPin.name);
+    int blobIndex;
+    if (inpNameIt == netInputsNames.end())
+        blobIndex = outPin.blobIndex;
+    else
+        blobIndex = inpNameIt - netInputsNames.begin();
+    network.connect(it->second, blobIndex, input_layer_id, input_blob_id);
 }
 
 void TFImporter::connectToAllBlobs(const std::map<String, int>& layer_id, Net& network, const Pin& outPin,
@@ -778,7 +787,7 @@ void TFImporter::populateNet(Net dstNet)
                     Pin inp = parsePin(layer.input(ii));
                     if (layer_id.find(inp.name) == layer_id.end())
                         CV_Error(Error::StsError, "Input layer not found: " + inp.name);
-                    dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii);
+                    connect(layer_id, dstNet, inp, id, ii);
                 }
             }
         }
@@ -1028,7 +1037,7 @@ void TFImporter::populateNet(Net dstNet)
                 Pin inp = parsePin(layer.input(ii));
                 if (layer_id.find(inp.name) == layer_id.end())
                     CV_Error(Error::StsError, "Input layer not found: " + inp.name);
-                dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii - from);
+                connect(layer_id, dstNet, inp, id, ii - from);
             }
         }
         else if (type == "MaxPool")
@@ -1060,10 +1069,12 @@ void TFImporter::populateNet(Net dstNet)
         }
         else if (type == "Placeholder")
         {
-            std::vector<String> netInputs(1);
-            netInputs[0] = name;
-            layer_id[name] = 0;
-            dstNet.setInputsNames(netInputs);
+            if (!hasLayerAttr(layer, "dtype") ||
+                getLayerAttr(layer, "dtype").type() != tensorflow::DT_BOOL)  // If input is not a train/test flag.
+            {
+                netInputsNames.push_back(name);
+                layer_id[name] = 0;
+            }
         }
         else if (type == "Split") {
             // TODO: determining axis index remapping by input dimensions order of input blob
@@ -1201,7 +1212,7 @@ void TFImporter::populateNet(Net dstNet)
                     Pin inp = parsePin(layer.input(ii));
                     if (layer_id.find(inp.name) == layer_id.end())
                         CV_Error(Error::StsError, "Input layer not found: " + inp.name);
-                    dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii);
+                    connect(layer_id, dstNet, inp, id, ii);
                 }
             }
         }
@@ -1719,6 +1730,7 @@ void TFImporter::populateNet(Net dstNet)
             }
         }
     }
+    dstNet.setInputsNames(netInputsNames);
 }
 
 } // namespace
index 5ac8890..895ee9d 100644 (file)
@@ -440,4 +440,20 @@ TEST(Test_TensorFlow, resize_bilinear)
     runTensorFlowNet("resize_bilinear_factor");
 }
 
+TEST(Test_TensorFlow, two_inputs)  // Loads a graph with two named inputs, feeds one blob per input, and checks the output.
+{
+    Net net = readNet(path("two_inputs_net.pbtxt"));  // Only a .pbtxt path is passed to readNet here.
+    net.setPreferableBackend(DNN_BACKEND_OPENCV);  // Run on the DNN_BACKEND_OPENCV backend.
+
+    Mat firstInput(2, 3, CV_32FC1), secondInput(2, 3, CV_32FC1);  // Two 2x3 single-channel 32-bit float matrices.
+    randu(firstInput, -1, 1);  // Fill with random values bounded by -1 and 1.
+    randu(secondInput, -1, 1);  // Second input is randomized independently with the same bounds.
+
+    net.setInput(firstInput, "first_input");  // Bind each blob to an input by name rather than by position.
+    net.setInput(secondInput, "second_input");
+    Mat out = net.forward();  // Run the network and take the output of forward().
+
+    normAssert(out, firstInput + secondInput);  // Reference result is the element-wise sum of the two inputs.
+}
+
 }