CNNLayerSet inputLayers;
std::unordered_set<CNNLayer*> allLayers;
+ // define any layer connected to provided Data object (consumer or creator)
+ auto findConnectedLayer = [] (const DataPtr& data) -> CNNLayerPtr {
+ auto consumerLayers = getInputTo(data);
+ if (!consumerLayers.empty())
+ return consumerLayers.begin()->second;
+
+ auto creator = getCreatorLayer(data).lock();
+ if (creator != nullptr)
+ return creator;
+
+ return nullptr;
+ };
+
// Define all start layers
for (const auto& data : heads) {
- auto& secondLayers = getInputTo(data);
+ auto entryLayer = findConnectedLayer(data);
- if (secondLayers.empty()) continue;
+ if (entryLayer == nullptr) continue;
details::UnorderedDFS(
- allLayers, secondLayers.begin()->second,
- [&](CNNLayerPtr layer) {
+ allLayers, entryLayer,
+ [&inputLayers](const CNNLayerPtr &layer) {
if (layer->insData.empty()) {
inputLayers.insert(layer);
}
std::vector<CNNLayerPtr> TIBodySortTopologically(const TensorIterator::Body& body) {
std::vector<CNNLayerPtr> all_layers;
- auto all_input_layers = getAllInputs(body.inputs);
+ // In case of graph with several connected component
+ // total entry point is a union of [inputs]U[outputs]
+ // All internal nodes are achievable starting from this.
+ auto total_entry_point = body.inputs;
+ total_entry_point.insert(total_entry_point.end(),
+ body.outputs.begin(), body.outputs.end());
+
+ auto all_input_layers = getAllInputs(total_entry_point);
CNNNetForestDFS(
all_input_layers,
- [&](CNNLayerPtr current) {
+ [&all_layers](const CNNLayerPtr ¤t) {
all_layers.push_back(current);
},
false);