[IE] Improve Network topological sort in case of disconnected graph
authorAlexander Peskov <alexander.peskov@intel.com>
Wed, 28 Oct 2020 22:37:20 +0000 (01:37 +0300)
committerAlexander Peskov <alexander.peskov@intel.com>
Mon, 2 Nov 2020 09:37:48 +0000 (12:37 +0300)
Signed-off-by: Alexander Peskov <alexander.peskov@intel.com>
inference-engine/src/legacy_api/include/legacy/graph_tools.hpp

index b364e0d..c03f409 100644 (file)
@@ -371,19 +371,42 @@ inline CNNLayerSet CNNNetGetAllInputLayers(const ICNNNetwork& network) {
     InputsDataMap inputs;
     network.getInputsInfo(inputs);
 
+    OutputsDataMap outputs;
+    network.getOutputsInfo(outputs);
+
+    std::vector<DataPtr> entryDataSet;
+    entryDataSet.reserve(inputs.size() + outputs.size());
+    for (const auto &kvp : inputs)
+        entryDataSet.push_back(kvp.second->getInputData());
+    for (const auto &kvp : outputs)
+        entryDataSet.push_back(kvp.second);
+
     CNNLayerSet inputLayers;
     std::unordered_set<CNNLayer*> allLayers;
 
-    if (inputs.empty()) return inputLayers;
+    if (entryDataSet.empty()) return inputLayers;
 
-    for (const auto& input : inputs) {
-        auto& secondLayers = getInputTo(input.second->getInputData());
+    // define any layer connected to provided Data object (consumer or creator)
+    auto findConnectedLayer = [] (const DataPtr& data) -> CNNLayerPtr {
+        auto consumerLayers = getInputTo(data);
+        if (!consumerLayers.empty())
+            return consumerLayers.begin()->second;
 
-        if (secondLayers.empty()) continue;
+        auto creator = getCreatorLayer(data).lock();
+        if (creator != nullptr)
+            return creator;
+
+        return nullptr;
+    };
+
+    for (const auto& data : entryDataSet) {
+        auto entryLayer = findConnectedLayer(data);
+
+        if (entryLayer == nullptr) continue;
 
         details::UnorderedDFS(
-            allLayers, secondLayers.begin()->second,
-            [&](CNNLayerPtr layer) {
+            allLayers, entryLayer,
+            [&inputLayers](const CNNLayerPtr& layer) {
                 if (layer->insData.empty()) {
                     inputLayers.insert(layer);
                 }