From 9510551c637cdec06c11aff63466352f336e2802 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Tue, 26 Jun 2018 13:32:28 +0300 Subject: [PATCH] Multiple inputs for TensorFlow models --- modules/dnn/src/tensorflow/tf_importer.cpp | 28 ++++++++++++++++++++-------- modules/dnn/test/test_tf_importer.cpp | 16 ++++++++++++++++ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 9140368..25b2a73 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -375,6 +375,8 @@ private: // and may be used to build the network using binary format only as a weights storage. // This approach is similar to Caffe's `.prorotxt` and `.caffemodel`. tensorflow::GraphDef netTxt; + + std::vector netInputsNames; }; TFImporter::TFImporter(const char *model, const char *config) @@ -442,7 +444,14 @@ void TFImporter::connect(const std::map& layers_name_id_map, Net& n std::map::const_iterator it = layers_name_id_map.find(outPin.name); if (it == layers_name_id_map.end()) CV_Error(Error::StsError, "Input layer not found: " + outPin.name); - network.connect(it->second, outPin.blobIndex, input_layer_id, input_blob_id); + + std::vector::iterator inpNameIt = std::find(netInputsNames.begin(), netInputsNames.end(), outPin.name); + int blobIndex; + if (inpNameIt == netInputsNames.end()) + blobIndex = outPin.blobIndex; + else + blobIndex = inpNameIt - netInputsNames.begin(); + network.connect(it->second, blobIndex, input_layer_id, input_blob_id); } void TFImporter::connectToAllBlobs(const std::map& layer_id, Net& network, const Pin& outPin, @@ -778,7 +787,7 @@ void TFImporter::populateNet(Net dstNet) Pin inp = parsePin(layer.input(ii)); if (layer_id.find(inp.name) == layer_id.end()) CV_Error(Error::StsError, "Input layer not found: " + inp.name); - dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii); + connect(layer_id, dstNet, inp, id, ii); } } } @@ -1028,7 +1037,7 @@ void TFImporter::populateNet(Net dstNet) Pin inp = parsePin(layer.input(ii)); if (layer_id.find(inp.name) == layer_id.end()) CV_Error(Error::StsError, "Input layer not found: " + inp.name); - dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii - from); + connect(layer_id, dstNet, inp, id, ii - from); } } else if (type == "MaxPool") @@ -1060,10 +1069,12 @@ void TFImporter::populateNet(Net dstNet) } else if (type == "Placeholder") { - std::vector netInputs(1); - netInputs[0] = name; - layer_id[name] = 0; - dstNet.setInputsNames(netInputs); + if (!hasLayerAttr(layer, "dtype") || + getLayerAttr(layer, "dtype").type() != tensorflow::DT_BOOL) // If input is not a train/test flag. + { + netInputsNames.push_back(name); + layer_id[name] = 0; + } } else if (type == "Split") { // TODO: determining axis index remapping by input dimensions order of input blob @@ -1201,7 +1212,7 @@ void TFImporter::populateNet(Net dstNet) Pin inp = parsePin(layer.input(ii)); if (layer_id.find(inp.name) == layer_id.end()) CV_Error(Error::StsError, "Input layer not found: " + inp.name); - dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii); + connect(layer_id, dstNet, inp, id, ii); } } } @@ -1719,6 +1730,7 @@ void TFImporter::populateNet(Net dstNet) } } } + dstNet.setInputsNames(netInputsNames); } } // namespace diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 5ac8890..895ee9d 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -440,4 +440,20 @@ TEST(Test_TensorFlow, resize_bilinear) runTensorFlowNet("resize_bilinear_factor"); } +TEST(Test_TensorFlow, two_inputs) +{ + Net net = readNet(path("two_inputs_net.pbtxt")); + net.setPreferableBackend(DNN_BACKEND_OPENCV); + + Mat firstInput(2, 3, CV_32FC1), secondInput(2, 3, CV_32FC1); + randu(firstInput, -1, 1); + randu(secondInput, -1, 1); + + net.setInput(firstInput, "first_input"); + net.setInput(secondInput, "second_input"); + Mat out = net.forward(); + + normAssert(out, firstInput + secondInput); +} + } -- 2.7.4