Improve CopyTIBody util to cover disconnected graphs
authorAlexander Peskov <alexander.peskov@intel.com>
Mon, 19 Oct 2020 17:16:34 +0000 (20:16 +0300)
committerAlexander Peskov <alexander.peskov@intel.com>
Mon, 2 Nov 2020 09:37:48 +0000 (12:37 +0300)
Signed-off-by: Alexander Peskov <alexander.peskov@intel.com>
inference-engine/src/legacy_api/src/net_pass.cpp

index dafe806..c3c0c39 100644 (file)
@@ -46,15 +46,28 @@ static std::vector<DataPtr> getAllInputs(const std::vector<DataPtr>& heads) {
     CNNLayerSet inputLayers;
     std::unordered_set<CNNLayer*> allLayers;
 
+    // define any layer connected to provided Data object (consumer or creator)
+    auto findConnectedLayer = [] (const DataPtr& data) -> CNNLayerPtr {
+        auto consumerLayers = getInputTo(data);
+        if (!consumerLayers.empty())
+            return consumerLayers.begin()->second;
+
+        auto creator = getCreatorLayer(data).lock();
+        if (creator != nullptr)
+            return creator;
+
+        return nullptr;
+    };
+
     // Define all start layers
     for (const auto& data : heads) {
-        auto& secondLayers = getInputTo(data);
+        auto entryLayer = findConnectedLayer(data);
 
-        if (secondLayers.empty()) continue;
+        if (entryLayer == nullptr) continue;
 
         details::UnorderedDFS(
-            allLayers, secondLayers.begin()->second,
-            [&](CNNLayerPtr layer) {
+            allLayers, entryLayer,
+            [&inputLayers](const CNNLayerPtr &layer) {
                 if (layer->insData.empty()) {
                     inputLayers.insert(layer);
                 }
@@ -77,10 +90,17 @@ static std::vector<DataPtr> getAllInputs(const std::vector<DataPtr>& heads) {
 std::vector<CNNLayerPtr> TIBodySortTopologically(const TensorIterator::Body& body) {
     std::vector<CNNLayerPtr> all_layers;
 
-    auto all_input_layers = getAllInputs(body.inputs);
+    // In case of graph with several connected component
+    // total entry point is a union of [inputs]U[outputs]
+    // All internal nodes are achievable starting from this.
+    auto total_entry_point = body.inputs;
+    total_entry_point.insert(total_entry_point.end(),
+                             body.outputs.begin(), body.outputs.end());
+
+    auto all_input_layers = getAllInputs(total_entry_point);
     CNNNetForestDFS(
         all_input_layers,
-        [&](CNNLayerPtr current) {
+        [&all_layers](const CNNLayerPtr &current) {
             all_layers.push_back(current);
         },
         false);