1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
7 #include <legacy/graph_tools.hpp>
8 #include "gna_plugin_log.hpp"
16 namespace InferenceEngine {
// Sentinel meaning "no out-data index supplied" (used as the default for
// CNNNetworkInsertLayer's outDataIndex parameter).
static constexpr size_t invalid_data_idx = std::numeric_limits<size_t>::max();
20 // compares data, for copied network and in old network
21 inline bool areEqualDatas(DataPtr source, DataPtr target) {
22 if (source.get() == target.get()) {
27 // actual dims value might be incorrect dueto syntetic case
28 // , when getbatch() size returns value not reflect in actual data
30 if (source->getTensorDesc().getDims().size() != target->getTensorDesc().getDims().size()) {
35 if (source->getName() != target->getName()) {
39 // inputTO layers are identical by design
43 /// @brief utility to locate input data idx from given outdata and given layer
44 inline std::vector<int> CNNLayerFindInsDataIdxes(DataPtr sourceData, CNNLayerPtr layer) {
45 std::vector<int> dataIdxes;
46 auto outLayers = getInputTo(sourceData);
47 for (auto & outLayer : outLayers) {
48 if (outLayer.second.get() != layer.get()) {
51 for (int j = 0; j < layer->insData.size(); j++) {
52 if (areEqualDatas(layer->insData[j].lock(), sourceData)) {
53 dataIdxes.push_back(j);
57 IE_ASSERT(!dataIdxes.empty());
62 * @brief pointer of previous layers
63 * @param idx - index in previous layer collection
66 inline InferenceEngine::CNNLayerPtr CNNNetPrevLayer(const InferenceEngine::CNNLayerPtr & layer, int idx = 0) {
67 if (CNNNetHasPrevLayer(layer.get(), idx)) {
68 auto prevData = layer->insData[idx].lock();
69 IE_ASSERT(prevData != nullptr);
70 return getCreatorLayer(prevData).lock();
72 THROW_IE_EXCEPTION << "Layer " << layer->name << " has no previous layer";
77 * @brief pointer of previous layers
78 * @param idx - index in previous layer collection
81 inline InferenceEngine::CNNLayerPtr CNNNetPrevLayer(const InferenceEngine::CNNLayer* layer, int idx = 0) {
82 IE_ASSERT(layer != nullptr);
83 if (CNNNetHasPrevLayer(layer, idx)) {
84 auto prevData = layer->insData[idx].lock();
85 return getCreatorLayer(prevData).lock();
87 THROW_IE_EXCEPTION << "Layer " << layer->name << " has no previous layer";
/// @brief helper that yields a raw pointer from either a raw pointer (identity)
///        or a std::shared_ptr (unwraps via get())
template <class T>
class ExtractRawPtr {
    T ptr;
 public:
    using raw_ptr_type = T;
    explicit ExtractRawPtr(T ptr) : ptr(ptr) {}
    T raw() const {
        return ptr;
    }
};

/// @brief specialization for std::shared_ptr: raw type is the underlying pointer
template <class U>
class ExtractRawPtr<std::shared_ptr<U>> {
    std::shared_ptr<U> ptr;
 public:
    using raw_ptr_type = U*;
    explicit ExtractRawPtr(const std::shared_ptr<U> & ptr) : ptr(ptr) {}
    U* raw() const {
        return ptr.get();
    }
};

/// @brief convenience wrapper returning the raw pointer underlying @p obj
template <class T>
inline typename ExtractRawPtr<T>::raw_ptr_type raw_ptr(T obj) {
    ExtractRawPtr<T> x(obj);
    return x.raw();
}
120 * @brief gets pointer to previous layer
121 * @param idx - index in previous layer connection - in other layers only zero idx will be used
122 * @param layer - source layer
123 * @param shouldSkip - skip kriteria
125 template <class Layer>
126 inline InferenceEngine::CNNLayerPtr CNNNetPrevLayerSkipCertain(Layer layer, int idx,
127 const std::function<bool(CNNLayerPtr)> &shouldSkip) {
128 IE_ASSERT(layer != nullptr);
129 if (!CNNNetHasPrevLayer(raw_ptr(layer), idx)) {
130 THROW_GNA_EXCEPTION << "Can't find PrevLayer. All layers are skipped.";
133 auto prev = CNNNetPrevLayer(layer, idx);
135 /// using upper search simplified version
136 if (shouldSkip(prev)) {
137 return CNNNetPrevLayerSkipCertain(prev, 0, shouldSkip);
144 * @brief returns next layer, skipping certain layers based on given functor
145 * @param layer - given start layer
146 * @param oidx - index of output data
147 * @param iidx - index of input layers for given output
148 * @param bOnlyCheck - doesn't throw exception if next layer missed
150 * @return layer pointer and it's insData index that uses to connect to previous layer in chain
153 template <class Layer>
154 inline std::pair<InferenceEngine::CNNLayerPtr, int> CNNNetCheckNextLayerSkipCertain(Layer layer, int oidx, int iidx, bool bOnlyCheck,
155 const std::function<bool(CNNLayerPtr)> &shouldSkip) {
156 if (oidx >= layer->outData.size()) {
157 if (bOnlyCheck) return {nullptr, 0};
158 THROW_GNA_LAYER_EXCEPTION(layer) << " no next output layer for outdata: " << oidx;
160 if (iidx >= getInputTo(layer->outData[oidx]).size()) {
161 if (bOnlyCheck) return {nullptr, 0};
162 THROW_GNA_LAYER_EXCEPTION(layer) << " no next output layer for outdata: " << oidx << " and inputTo index: " << iidx;
165 auto outLayer = getInputTo(layer->outData[oidx]).begin();
166 std::advance(outLayer, iidx);
168 if (!shouldSkip(outLayer->second)) {
169 auto insDataIdx = CNNLayerFindInsDataIdxes(layer->outData[oidx], outLayer->second);
170 if (insDataIdx.size() != 1) {
171 if (bOnlyCheck) return {nullptr, 0};
172 THROW_GNA_LAYER_EXCEPTION(layer) << " has multiple connection to " << oidx << " outData";
174 return {outLayer->second, insDataIdx.front()};
176 return CNNNetCheckNextLayerSkipCertain(outLayer->second, 0, 0, bOnlyCheck, shouldSkip);
180 * @brief return all layers reachable from given one
182 * @param oDataIdx - -1 means iterate over all odata indexes
186 template <class Layer>
187 inline std::vector<CNNLayerPtr> CNNNetGetAllNextLayersSkipCertain(Layer layer, int oDataIdx, const std::function<bool(CNNLayerPtr)> &shouldSkip) {
188 // TODO: need to have generic function that creates slice of the graph : starting from given layer
189 // and skipped all non functional - ending up into functional one
191 std::list<CNNLayerPtr> currentSet;
192 std::vector<CNNLayerPtr> resultSet;
194 std::vector<std::map<std::string, CNNLayerPtr>> start;
195 if (oDataIdx == -1) {
196 for (int i = 0; i != layer->outData.size(); i++) {
197 start.push_back(getInputTo(layer->outData[i]));
200 start.push_back(getInputTo(layer->outData[oDataIdx]));
203 auto separate_layers = [¤tSet, &resultSet, &shouldSkip](std::map<std::string, CNNLayerPtr>& inputTo) {
204 for (auto &&bfsLayer : inputTo) {
205 if (shouldSkip(bfsLayer.second)) {
206 currentSet.push_back(bfsLayer.second);
209 resultSet.push_back(bfsLayer.second);
213 int startIdx, endIdx;
214 if (oDataIdx == -1) {
216 endIdx = layer->outData.size();
219 endIdx = oDataIdx + 1;
222 for (int i = startIdx; i != endIdx; i++) {
223 separate_layers(getInputTo(layer->outData[i]));
226 while (!currentSet.empty()) {
227 auto currentLayer = currentSet.front();
228 currentSet.pop_front();
229 for (auto && oData : currentLayer->outData) {
230 separate_layers(getInputTo(oData));
236 /// @brief alias for strict checkNextLayer (false)
237 template <class Layer>
238 inline std::pair<InferenceEngine::CNNLayerPtr, int> CNNNetGetNextLayerSkipCertain(Layer layer, int oidx, int iidx,
239 const std::function<bool(CNNLayerPtr)> &shouldSkip) {
240 return CNNNetCheckNextLayerSkipCertain(layer, oidx, iidx, false, shouldSkip);
243 /// @brief alias for non-strict checkNextLayer (false)
244 template <class Layer>
245 inline bool CNNNetHasNextLayerSkipCertain(Layer layer, int oidx, int iidx, const std::function<bool(CNNLayerPtr)> &shouldSkip) {
246 auto l = CNNNetCheckNextLayerSkipCertain(layer, oidx, iidx, true, shouldSkip);
247 return l.first.get() != nullptr;
251 /// @brief utility to locate output data idx from given insData index and given layer
252 inline int CNNLayerFindOutDataIdx(CNNLayerPtr layer, int insDataIdx) {
253 auto prevLayer = CNNNetPrevLayer(layer, insDataIdx);
254 auto outDataToSearch = layer->insData[insDataIdx].lock();
255 auto outDataIt = std::find(prevLayer->outData.begin(), prevLayer->outData.end(), outDataToSearch);
256 return std::distance(prevLayer->outData.begin(), outDataIt);
260 * @brief swap two layer in graph - with modifying input/output references
261 * also if layers have different dimensions they are preserved, so layers should be dimensions agnostic
262 * lhs is a first node in topological order - this is current limitation to avoid passing cnnnetwork object
264 inline void CNNNetSwapLayers(InferenceEngine::CNNLayerPtr lhs,
265 InferenceEngine::CNNLayerPtr rhs) {
266 if (lhs == nullptr || rhs ==nullptr) {
267 THROW_IE_EXCEPTION << "CNNNetSwapLayers : nullptr";
269 if (lhs.get() == rhs.get())
272 if (lhs->outData.size() > 1) {
273 THROW_IE_EXCEPTION << "Unsupported layer for swap operation : " << lhs->name;
275 if (rhs->outData.size() > 1) {
276 THROW_IE_EXCEPTION << "Unsupported layer for swap operation : " << rhs->name;
279 auto &rhs_outputs = getInputTo(rhs->outData.front());
280 auto &lhs_outputs = getInputTo(lhs->outData.front());
282 // fixing input layers edges
283 for (int i = 0; true; i++) {
284 if (!CNNNetHasPrevLayer(lhs.get(), i)) break;
285 auto prev_lhs = CNNNetPrevLayer(lhs, i);
286 if (!prev_lhs) break;
287 if (prev_lhs == rhs) continue;
289 for (auto &prev_next : prev_lhs->outData) {
290 auto lhs_ptr = getInputTo(prev_next).find(lhs->name);
291 lhs_ptr->second = rhs;
295 for (int i = 0; true; i++) {
296 if (!CNNNetHasPrevLayer(rhs.get(), i)) break;
297 auto prev_rhs = CNNNetPrevLayer(rhs, i);
298 if (!prev_rhs) break;
299 if (prev_rhs == lhs) continue;
301 for (auto &prev_next : prev_rhs->outData) {
302 auto lhs_ptr = getInputTo(prev_next).find(rhs->name);
303 lhs_ptr->second = lhs;
307 // fixing output layers back edges
308 for (auto &next_lhs : lhs_outputs) {
309 if (next_lhs.second == rhs) continue;
311 bool hasHrsConnection = false;
312 for (auto &ins_for_lhs_next : next_lhs.second->insData) {
313 if (getCreatorLayer(ins_for_lhs_next.lock()).lock() != rhs ) continue;
314 hasHrsConnection = true;
317 if (!hasHrsConnection) {
318 for (auto &ins_for_lhs_next : next_lhs.second->insData) {
319 if (getCreatorLayer(ins_for_lhs_next.lock()).lock() != lhs) continue;
320 ins_for_lhs_next = rhs->outData.front();
325 for (auto &next_rhs : rhs_outputs) {
326 if (next_rhs.second == lhs) continue;
328 bool hasLHSConnection = false;
329 for (auto &ins_for_rhs_next : next_rhs.second->insData) {
330 if (getCreatorLayer(ins_for_rhs_next.lock()).lock() != lhs) continue;
331 hasLHSConnection = true;
334 if (!hasLHSConnection) {
335 for (auto &ins_for_rhs_next : next_rhs.second->insData) {
336 if (getCreatorLayer(ins_for_rhs_next.lock()).lock() != rhs) continue;
337 ins_for_rhs_next = lhs->outData.front();
342 // fixing layers itself output references
344 // c++11 lacks generic lambda
345 using inputTo_element = std::remove_reference<decltype(*lhs_outputs.begin())>::type;
347 std::remove_reference<decltype(lhs_outputs)>::type tmp;
348 bool bHadInterconnectR2L = false;
350 // 0. remove interconnect rhs->lhs
351 details::erase_if(rhs_outputs, [&bHadInterconnectR2L, &lhs](inputTo_element & element) {
352 bHadInterconnectR2L |= element.second == lhs;
353 return element.second == lhs;
356 // 1. move all output references from rhs to tmp
357 tmp.insert(std::begin(rhs_outputs), std::end(rhs_outputs));
361 // 2. removing lhs->rhs interconnect
362 bool bHadInterConnect = false;
363 details::erase_if(lhs_outputs, [&bHadInterConnect, &rhs](inputTo_element & element) {
364 bHadInterConnect |= element.second == rhs;
365 return element.second == rhs;
368 // 3. move all output references from lhs to rhs
369 rhs_outputs.insert(std::begin(lhs_outputs), std::end(lhs_outputs));
372 // 4. move from tmp to lhs
373 lhs_outputs.insert(std::begin(tmp), std::end(tmp));
375 // 5.restore interconnects
376 if (bHadInterConnect) {
377 rhs_outputs[lhs->name] = lhs;
379 if (bHadInterconnectR2L) {
380 lhs_outputs[rhs->name] = rhs;
384 // fixing layers itself input references
386 // 1. removing interconnects lhs->rhs
387 bool interConnectBackL2R = false;
388 details::erase_if(lhs->insData, [&interConnectBackL2R, &rhs](DataWeakPtr weakData) {
389 InferenceEngine::CNNLayerPtr creator = nullptr;
390 auto data = weakData.lock();
392 creator = getCreatorLayer(data).lock();
393 interConnectBackL2R |= creator == rhs;
394 return creator == rhs;
397 // 2. removing interconnects rhs->lhs
398 auto interConnectBackR2L = false;
399 if (!interConnectBackL2R) {
400 details::erase_if(rhs->insData, [&interConnectBackR2L, &lhs](DataWeakPtr weakData) {
401 auto data = weakData.lock();
402 IE_ASSERT(data != nullptr);
403 interConnectBackR2L |= getCreatorLayer(data).lock() == lhs;
404 return getCreatorLayer(data).lock() == lhs;
409 std::swap(lhs->insData, rhs->insData);
411 // 4. Restoring interconnections
412 if (interConnectBackL2R) {
413 rhs->insData.push_back(lhs->outData.front());
415 if (interConnectBackR2L) {
416 lhs->insData.push_back(rhs->outData.front());
421 // 1. step find out what layer is first in topological order
422 // 2. integrate shape infer mechanism starting from lhs
423 lhs->outData.front()->setDims(rhs->outData.front()->getDims());
429 * @@brief insertLayer between given layers
430 * @param after, insertion happened after this layer, if after is nullptr, insertion happened after all inputLayers for before layer
431 * @param before, insertion happened before layer, if before is nullptr, insertion happened before all outputLayers of after layer
432 * @param layerToInsert inserted layer
433 * @param outDataIndex index data to be used to insert layer after it. Cannot be used to specify allOutputDatas
435 inline void CNNNetworkInsertLayer(CNNLayerPtr after,
437 CNNLayerPtr layerToInsert,
438 size_t outDataIndex = invalid_data_idx) {
439 if (after == nullptr && before == nullptr) {
440 THROW_IE_EXCEPTION << "Cannot Insert Layer: before or after layers should be valid layer pointers";
443 bool bLocated = false;
444 bool hasOutputIndex = outDataIndex != invalid_data_idx;
445 if (after != nullptr) {
446 for (auto && data : after->outData) {
447 if (hasOutputIndex && outDataIndex) {
451 auto inputTo = getInputTo(data);
452 for (auto inputIt = inputTo.begin(); inputIt != inputTo.end(); ++inputIt) {
453 auto input = inputIt->second;
454 if (before != nullptr && input.get() != before.get())
458 for (auto x : CNNLayerFindInsDataIdxes(data, input)) {
459 input->insData[x] = layerToInsert->outData.front();
462 getInputTo(layerToInsert->outData.front())[inputIt->first] = input;
466 // erasing only one particular connection
467 getInputTo(data).erase(inputIt->first);
468 if (before != nullptr) {
472 if (getInputTo(data).empty()) {
476 // erasing all connection
477 if (before == nullptr) {
478 getInputTo(data).clear();
481 getInputTo(data)[layerToInsert->outData.front()->getName()] = layerToInsert;
482 layerToInsert->insData.push_back(data);
485 if (hasOutputIndex) {
490 // if given outputDataIndex is not correct, lets find index that matches *before* layer
492 if (before != nullptr) {
493 IE_ASSERT(before->insData.size() == 1);
494 auto prevLayer = after;
495 for (auto idx = prevLayer->outData.begin(); idx != prevLayer->outData.end(); idx++) {
496 auto &outputports = getInputTo(*idx);
497 for (auto ll = outputports.begin(); ll != outputports.end(); ll++) {
498 if (ll->second.get() == before.get()) {
499 // looks we found where need to remove
500 outputports.erase(ll);
501 before->insData.clear();
502 before->insData.push_back(layerToInsert->outData.front());
503 getInputTo(layerToInsert->outData.front())[before->name] = before;
513 // now we have a before layer without inputs
516 // inserting into node that doesnt have child
517 IE_ASSERT(!after->outData.empty());
518 for (auto &&next : after->outData) {
519 if (!getInputTo(next).empty()) continue;
520 getInputTo(next)[layerToInsert->name] = layerToInsert;
521 layerToInsert->insData.push_back(next);
527 THROW_IE_EXCEPTION << "Cannot insert layer between: " <<
528 ((after == nullptr) ? std::string("nullptr") : after->name) << " and " <<
529 ((before == nullptr) ? std::string("nullptr") : before->name);
534 * @brief returns previous layers and outData index for it
537 * @param acceptanceCriteria
541 std::vector<std::pair<CNNLayerPtr, int> > CNNNetGetPrevLayersSkip(CNNLayerPtr origin, const T &acceptanceCriteria, int idx = -1) {
542 std::vector<std::pair<CNNLayerPtr, int> > prevLayers;
543 for (int i = idx == -1 ? 0 : idx; CNNNetHasPrevLayer(origin.get(), i) && (idx == -1 || i == idx); i++) {
544 auto prevLayer = CNNNetPrevLayer(origin, i);
545 if (acceptanceCriteria(prevLayer)) {
546 prevLayers.push_back({prevLayer, CNNLayerFindOutDataIdx(origin, i)});
548 // if for some input we need to look in upper layers - original index not used here intentionally
549 auto prevPrevLayers = CNNNetGetPrevLayersSkip(prevLayer, acceptanceCriteria);
550 prevLayers.insert(prevLayers.end(), prevPrevLayers.begin(), prevPrevLayers.end());
558 * @brief remove given layer from topology, currently only layers with one input data and one output data supported
560 inline void CNNNetworkRemoveLayer(CNNLayerPtr layer) {
562 THROW_IE_EXCEPTION << "Cannot remove layer pointed to NULL";
564 if (layer->insData.size() != 1) {
565 THROW_IE_EXCEPTION << "Cannot remove layer : "<< layer->name <<" that has not 1 input";
567 if (layer->outData.size() != 1) {
568 THROW_IE_EXCEPTION << "Cannot remove layer : "<< layer->name <<" that has not 1 output";
571 auto isp = layer->insData.front().lock();
573 THROW_IE_EXCEPTION << "Cannot remove layer : "<< layer->name <<" cannot get it's input";
575 // if dimensions of input layer not equal target dimensions - shape infer or reshape layer required, so skipping those cases
576 auto osp = layer->outData.front();
577 if (isp->getDims() != osp->getDims()) {
578 THROW_IE_EXCEPTION << "Cannot remove layer : "<< layer->name <<" its input layer("
579 << isp->getName() << ") and output(" << osp->getName() << ") have incompatible dimensions";
582 // remove isp->layer connection
583 for (auto i = getInputTo(isp).begin(); i != getInputTo(isp).end(); i++) {
584 if (i->second.get() == layer.get()) {
585 getInputTo(isp).erase(i);
590 // remove osp->layer connection
591 for (auto && outData : getInputTo(osp)) {
592 for (auto i = outData.second->insData.begin(); i != outData.second->insData.end(); i++) {
593 auto insData = i->lock();
595 THROW_IE_EXCEPTION << "Cannot remove layer : "<< layer->name <<", its output layer(" <<
596 outData.first << " has invalid input configuration";
598 auto creator = getCreatorLayer(insData).lock();
600 THROW_IE_EXCEPTION << "Cannot remove layer : "<< layer->name <<", its output layer(" <<
601 outData.first << " has invalid input configuration";
604 // found layer that need to be removed
605 if (creator.get() == layer.get()) {
606 outData.second->insData.erase(i);
612 // add isp->osp connections
613 for (auto && outData : getInputTo(osp)) {
614 // new syntetic name to avoid duplicates in map
615 getInputTo(isp)[layer->name + "_" + outData.first] = outData.second;
618 // add osp->isp connections
619 for (auto && outData : getInputTo(osp)) {
620 outData.second->insData.push_back(isp);
623 // removing layer->osp, and layer->isp connection not necessary - layer will delete it by itself
626 } // namespace InferenceEngine