Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_edge.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "mkldnn_edge.h"
6 #include "mkldnn_node.h"
7 #include "mkldnn_extension_utils.h"
8 #include <blob_factory.hpp>
9
10 using namespace mkldnn;
11 namespace MKLDNNPlugin {
12
13 MKLDNNEdge::MKLDNNEdge(const MKLDNNNodePtr &parent, const MKLDNNNodePtr &child, int pr_port, int ch_port) :
14         parent(parent), child(child), parent_port(pr_port), child_port(ch_port) {}
15
16 const MKLDNNNodePtr MKLDNNEdge::getParent() const {
17     auto parentPtr = parent.lock();
18     if (!parentPtr)
19         THROW_IE_EXCEPTION << "Edge contains empty parent node";
20     return parentPtr;
21 }
22
23 const MKLDNNNodePtr MKLDNNEdge::getChild() const {
24     auto childPtr = child.lock();
25     if (!childPtr)
26         THROW_IE_EXCEPTION << "Edge contains empty child node";
27     return childPtr;
28 }
29
30 bool MKLDNNEdge::isDropped() {
31     bool not_in_parent = true;
32     bool not_in_child = true;
33
34     auto parent_ptr = parent.lock();
35     if (parent_ptr) {
36         for (auto &edge : parent_ptr->childEdges)
37             if (edge.lock().get() == this)
38                 not_in_parent = false;
39     }
40
41     auto child_ptr = child.lock();
42     if (child_ptr) {
43         for (auto &edge : child_ptr->parentEdges)
44             if (edge.lock().get() == this)
45                 not_in_child = false;
46     }
47     return not_in_parent && not_in_child;
48 }
49
50 void MKLDNNEdge::drop() {
51     auto _drop_from = [&] (std::vector<MKLDNNEdgeWeakPtr> &list) {
52         auto myself = std::find_if(list.begin(), list.end(),
53                 [&] (MKLDNNEdgeWeakPtr edge) { return edge.lock().get() == this; });
54
55         if (myself != list.end())
56             list.erase(myself);
57     };
58
59     _drop_from(getParent()->childEdges);
60     _drop_from(getChild()->parentEdges);
61 }
62
63
64 bool MKLDNNEdge::needReorder() {
65     bool canBeInPlaceConflicts = false;
66     auto parentSPD = getParent()->getSelectedPrimitiveDescriptor();
67     auto childSPD = getChild()->getSelectedPrimitiveDescriptor();
68     if (!parentSPD || !childSPD)
69         THROW_IE_EXCEPTION << "Cannot make a decision about reorder. Primitive descriptors weren't selected.";
70
71     int outNumber = getOutputNum();
72     int inNumber = getInputNum();
73     bool in_place = inPlace();
74     bool childCanChangeMem = childSPD->getConfig().outConfs.empty();
75     for (const auto conf : childSPD->getConfig().outConfs) {
76         if (conf.inPlace == outNumber && outNumber >= 0)
77             childCanChangeMem = true;
78     }
79
80     const auto& detectInPlaceChildsNum = [](const std::vector<MKLDNNEdgePtr>& edges) -> size_t {
81         size_t count = 0;
82         for (const auto& edge : edges) {
83             auto childSPD = edge->getChild()->getSelectedPrimitiveDescriptor();
84             int outNumber = edge->getOutputNum();
85             if (childSPD->getConfig().outConfs.empty())
86                 count++;
87             for (const auto conf : childSPD->getConfig().outConfs) {
88                 if (conf.inPlace == outNumber)
89                     count++;
90             }
91         }
92         return count;
93     };
94
95     const auto portChildEdges = getParent()->getChildEdgesAtPort(inNumber);
96     if (in_place && detectInPlaceChildsNum(portChildEdges) > 1 && childCanChangeMem)
97         canBeInPlaceConflicts = true;
98     if (!canBeInPlaceConflicts && in_place && !getParent()->getChildEdges().empty()) {
99         for (auto &p_edge_peer : portChildEdges) {
100             if (p_edge_peer.get() == this)
101                 continue;
102             if (p_edge_peer->getChild()->getType() != Reorder && p_edge_peer->inPlace(LOOK_DOWN))
103                 canBeInPlaceConflicts = true;
104         }
105     }
106
107     if (in_place) {
108         if (inNumber >= 0 && inNumber < parentSPD->getConfig().outConfs.size() && parentSPD->getConfig().outConfs[inNumber].inPlace >= 0 &&
109             outNumber >= 0 && outNumber < childSPD->getConfig().inConfs.size() && childSPD->getConfig().inConfs[outNumber].inPlace >= 0)
110             canBeInPlaceConflicts = true;
111     }
112     return canBeInPlaceConflicts || !MKLDNNExtensionUtils::initTensorsAreEqual(getInputDesc(), getOutputDesc());
113 }
114
115 InferenceEngine::TensorDesc MKLDNNEdge::getInputDesc() {
116     if (inputDesc.getLayout() == InferenceEngine::Layout::ANY) {
117         inputDesc = getSpecifiedInputDesc({});
118     }
119     return inputDesc;
120 }
121
122 InferenceEngine::TensorDesc MKLDNNEdge::getOutputDesc() {
123     if (outputDesc.getLayout() == InferenceEngine::Layout::ANY) {
124         outputDesc = getSpecifiedOutputDesc({});
125     }
126     return outputDesc;
127 }
128
129 InferenceEngine::TensorDesc MKLDNNEdge::getDesc() {
130     if (!MKLDNNExtensionUtils::initTensorsAreEqual(getInputDesc(), getOutputDesc()))
131         THROW_IE_EXCEPTION << "Cannot get descriptor for edge: " << getParent()->getName() << "->"
132                            << getChild()->getName();
133     return getInputDesc();
134 }
135
136 int MKLDNNEdge::getInputNum() {
137     return parent_port;
138 }
139
140 int MKLDNNEdge::getOutputNum() {
141     return child_port;
142 }
143
144 void MKLDNNEdge::allocate(const void* mem_ptr) {
145     if (status != Status::NeedAllocation)
146         return;
147
148     if (memoryPtr)
149         THROW_IE_EXCEPTION << "Unexpected behaviour: status == NeedAllocation but memory is already allocated.";
150
151     auto inputDesc = getInputDesc();
152     auto outputDesc = getOutputDesc();
153     if (!MKLDNNExtensionUtils::initTensorsAreEqual(outputDesc, inputDesc) ||
154             (inputDesc.getDims()[0] != 1 && inputDesc != outputDesc))
155         THROW_IE_EXCEPTION << "Cannot allocate memory. Nodes have primitive descriptors with different formats.";
156     if (inputDesc.getLayout() == InferenceEngine::Layout::ANY)
157         THROW_IE_EXCEPTION << "Cannot get input descriptor!";
158
159     auto parentPtr = getParent();
160     memoryPtr.reset(new MKLDNNMemory(parentPtr->getEngine()));
161     memoryPtr->Create(MKLDNNMemoryDesc(inputDesc), mem_ptr);
162     status = Status::Allocated;
163 }
164
165 void MKLDNNEdge::changeStatus(MKLDNNEdge::Status state) {
166     if (state == Status::NotAllocated) {
167         THROW_IE_EXCEPTION << "Incorrect behaviour! Use method sharedMemFrom()";
168     }
169     if (state == Status::Validated) {
170         THROW_IE_EXCEPTION << "Incorrect behaviour! Use method validate()";
171     }
172     if (status != Status::Uninitialized && state == Status::NeedAllocation)
173         return;
174     if (status == Status::NotAllocated)
175         memoryFromEdge.reset();
176     status = state;
177 }
178
179 const MKLDNNDims& MKLDNNEdge::getDims() {
180     if (!dims.ndims()) {
181         MKLDNNDims outDims;
182         MKLDNNDims inDims;
183         auto childPtr = getChild();
184         auto parentPtr = getParent();
185
186         int inNum = getOutputNum();
187         if (inNum < 0) {
188             THROW_IE_EXCEPTION << "Error cannot find input data for " << child.lock()->getName()
189                                << " from " << parent.lock()->getName();
190         }
191         if (inNum < childPtr->inDims.size()) {
192             outDims = childPtr->inDims[inNum];
193         }
194
195         int outNum = getInputNum();
196         if (outNum < 0) {
197             THROW_IE_EXCEPTION << "Error cannot find output data for " << parent.lock()->getName()
198                                << " to " << child.lock()->getName();
199         }
200         if (outNum >= parentPtr->outDims.size())
201             outNum = 0;
202         if (outNum < parentPtr->outDims.size()) {
203             inDims = parentPtr->outDims[outNum];
204         }
205
206         if (inDims.ndims() && outDims.ndims() && inDims.ndims() != outDims.ndims() && inDims.size() != outDims.size())
207             THROW_IE_EXCEPTION << "Nodes " << getParent()->getName() << " and " << getChild()->getName()
208                                << " have incompatible dimensions!";
209
210         dims = outDims.ndims() ? outDims : inDims;
211
212         if (!dims.ndims())
213             THROW_IE_EXCEPTION << "Cannot detect right dims for nodes " << getParent()->getName()
214                                << " and " << getChild()->getName();
215     }
216     return dims;
217 }
218
219 bool MKLDNNEdge::nodeCanChangeDesc(const MKLDNNNodePtr &node) const {
220     PrimitiveDescInfo * selectedPd = node->getSelectedPrimitiveDescriptor();
221     if (selectedPd == nullptr)
222         THROW_IE_EXCEPTION << "Primitive descriptor for node " << node->getName() << " is not selected.";
223
224     for (auto &inputDesc : selectedPd->getConfig().inConfs) {
225         if (inputDesc.desc.getLayout() != InferenceEngine::Layout::ANY) {
226             return true;
227         }
228     }
229
230     for (auto &outDesc : selectedPd->getConfig().outConfs) {
231         if (outDesc.desc.getLayout() != InferenceEngine::Layout::ANY) {
232             return true;
233         }
234     }
235
236     MKLDNNDims inputDims;
237     for (size_t i = 0; i < node->getParentEdges().size(); i++) {
238         if (inputDims.size() == 1 && inputDims.ndims() == 0) {
239             inputDims = node->getParentEdgeAt(i)->getDims();
240             continue;
241         }
242
243         if (inputDims.ndims() != node->getParentEdgeAt(i)->getDims().ndims()) {
244             return true;
245         }
246     }
247     for (size_t i = 0; i < node->getChildEdges().size(); i++) {
248         if (inputDims.size() == 1 && inputDims.ndims() == 0) {
249             inputDims = node->getChildEdgeAt(i)->getDims();
250             continue;
251         }
252
253         if (inputDims.ndims() != node->getChildEdgeAt(i)->getDims().ndims()) {
254             return true;
255         }
256     }
257
258     return false;
259 }
260
261 /// In we have {any, any, any} -> {any} or {any} -> {any, any, any} or {any} -> {any} it means that
262 /// layer doesn't change memory format
263 /// We don't support {any, any, nchw} -> {any}
264 InferenceEngine::TensorDesc MKLDNNEdge::getSpecifiedInputDesc(std::map<mkldnn::memory::format, size_t> formats) {
265     InferenceEngine::TensorDesc inDesc;
266     static int enterCount = 0;
267     enterCount++;
268
269     if (inputDesc.getLayout() != InferenceEngine::Layout::ANY) {
270         --enterCount;
271         return inputDesc;
272     }
273
274     auto parentPtr = getParent();
275     if (parentPtr->getSelectedPrimitiveDescriptor() == nullptr)
276         THROW_IE_EXCEPTION << "Primitive descriptor for node " << parentPtr->getName() << " is not selected.";
277
278     int inputIdx = getInputNum();
279     if (inputIdx < 0)
280         THROW_IE_EXCEPTION << "Edge cannot be found for node" << parentPtr->getName() << ".";
281
282     if (inputIdx >= parentPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size())
283         inputIdx = 0;
284     inDesc = parentPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[inputIdx].desc;
285
286     if (inDesc.getLayout() != InferenceEngine::Layout::ANY) {
287         --enterCount;
288         return inDesc;
289     }
290
291     bool isFormatChanging = nodeCanChangeDesc(parentPtr);
292
293     if (!isFormatChanging && inputIdx < parentPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs.size() &&
294             parentPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs[inputIdx].desc.getLayout() != InferenceEngine::Layout::ANY) {
295         inDesc = parentPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs[inputIdx].desc;
296         parentPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[inputIdx].desc = inDesc;
297         --enterCount;
298         return inDesc;
299     }
300
301     for (size_t i = 0; i < parentPtr->getChildEdges().size(); i++) {
302         auto childEdge = parentPtr->getChildEdgeAt(i);
303         auto child = childEdge->getChild();
304         int childIdx = childEdge->getOutputNum();
305         if (!child->getSelectedPrimitiveDescriptor() || childIdx < 0 ||
306                 childEdge->getDims().ndims() != getDims().ndims()) {
307             continue;
308         }
309         if (child->getSelectedPrimitiveDescriptor()->getConfig().inConfs.size() <= childIdx)
310             childIdx = 0;
311         memory::format childInDesc = MKLDNNMemoryDesc(child->getSelectedPrimitiveDescriptor()->getConfig().inConfs[childIdx].desc).getFormat();
312         if (childInDesc != memory::any && childInDesc != memory::format_undef) {
313             if (formats.find(childInDesc) == formats.end())
314                 formats[childInDesc] = 1;
315             else
316                 formats[childInDesc] += 1;
317             continue;
318         }
319         if (nodeCanChangeDesc(child))
320             continue;
321
322         if (enterCount < 2) {
323             childInDesc = MKLDNNMemoryDesc(childEdge->getSpecifiedOutputDesc(formats)).getFormat();
324             if (childInDesc != memory::any && childInDesc != memory::format_undef) {
325                 if (formats.find(childInDesc) == formats.end())
326                     formats[childInDesc] = 1;
327                 else
328                     formats[childInDesc] += 1;
329             }
330         }
331     }
332
333     if (!isFormatChanging) {
334         for (size_t i = 0; i < parentPtr->getParentEdges().size(); i++) {
335             auto parentEdge = parentPtr->getParentEdgeAt(i);
336             auto parent = parentEdge->getParent();
337             int parentIdx = parentEdge->getInputNum();
338             if (!parent->getSelectedPrimitiveDescriptor() || parentIdx < 0 ||
339                     parentEdge->getDims().ndims() != getDims().ndims()) {
340                 continue;
341             }
342             if (parent->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size() <= parentIdx) {
343                 parentIdx = 0;
344             }
345             memory::format parentOutDesc = MKLDNNMemoryDesc(parent->getSelectedPrimitiveDescriptor()->getConfig().outConfs[parentIdx].desc).getFormat();
346             if (parentOutDesc != memory::any && parentOutDesc != memory::format_undef) {
347                 if (formats.find(parentOutDesc) == formats.end())
348                     formats[parentOutDesc] = 1;
349                 else
350                     formats[parentOutDesc] += 1;
351                 continue;
352             }
353             if (nodeCanChangeDesc(parent))
354                 continue;
355
356             if (enterCount < 2) {
357                 parentOutDesc = MKLDNNMemoryDesc(parentEdge->getSpecifiedInputDesc(formats)).getFormat();
358                 if (parentOutDesc != memory::any && parentOutDesc != memory::format_undef) {
359                     if (formats.find(parentOutDesc) == formats.end())
360                         formats[parentOutDesc] = 1;
361                     else
362                         formats[parentOutDesc] += 1;
363                 }
364             }
365         }
366     }
367
368     size_t maxFormatCount = 0;
369     memory::format desc =  MKLDNNMemory::GetPlainFormat(getDims());
370     for (auto &it : formats) {
371         if (maxFormatCount < it.second && MKLDNNMemory::isConsistant(getDims(), it.first)) {
372             maxFormatCount = it.second;
373             desc = it.first;
374         }
375     }
376
377     auto inDataType = MKLDNNMemoryDesc(parentPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[inputIdx].desc).getDataType();
378     parentPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[inputIdx].desc = MKLDNNMemoryDesc(getDims(), inDataType, desc);
379     if (!isFormatChanging && inputIdx < parentPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs.size() &&
380             parentPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs[inputIdx].desc.getLayout() == InferenceEngine::Layout::ANY) {
381         parentPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs[inputIdx].desc =
382                 MKLDNNExtensionUtils::getUninitTensorDesc(MKLDNNMemoryDesc(getDims(), inDataType, desc));
383     }
384
385     --enterCount;
386     return MKLDNNMemoryDesc(getDims(), inDataType, desc);
387 }
388
389 InferenceEngine::TensorDesc MKLDNNEdge::getSpecifiedOutputDesc(std::map<mkldnn::memory::format, size_t> formats) {
390     static int enterCount = 0;
391     enterCount++;
392     InferenceEngine::TensorDesc outDesc;
393
394     if (outputDesc.getLayout() != InferenceEngine::Layout::ANY) {
395         enterCount--;
396         return outputDesc;
397     }
398
399     auto childPtr = getChild();
400     auto parentPtr = getParent();
401
402     if (childPtr->getSelectedPrimitiveDescriptor() == nullptr)
403         THROW_IE_EXCEPTION << "Primitive descriptor for node " << childPtr->getName() << " is not selected.";
404
405     int outputIdx = getOutputNum();
406     int inputIdx = getInputNum();
407     if (outputIdx < 0) {
408         THROW_IE_EXCEPTION << "Edge cannot be found for node" << childPtr->getName() << ".";
409     }
410     if (outputIdx >= childPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs.size())
411         outputIdx = 0;
412     outDesc = childPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs[outputIdx].desc;
413
414     if (outDesc.getLayout() != InferenceEngine::Layout::ANY) {
415         enterCount--;
416         return outDesc;
417     }
418
419     if (inputIdx >= parentPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size())
420         inputIdx = 0;
421
422     bool isFormatChanging = nodeCanChangeDesc(childPtr);
423
424     if ((!isFormatChanging && outputIdx < childPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size() &&
425             childPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[outputIdx].desc.getLayout() != InferenceEngine::Layout::ANY) ||
426             (isFormatChanging && inputIdx >= 0 &&
427                     parentPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[inputIdx].desc.getLayout() != InferenceEngine::Layout::ANY)) {
428         auto inputDataType = childPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs[outputIdx].desc.getPrecision();
429         if (!isFormatChanging)
430             outDesc = childPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[outputIdx].desc;
431         else
432             outDesc = parentPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[inputIdx].desc;
433         childPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs[outputIdx].desc = InferenceEngine::TensorDesc(inputDataType, getDims().ToSizeVector(),
434                                                     {outDesc.getBlockingDesc().getBlockDims(),
435                                                      outDesc.getBlockingDesc().getOrder()});
436         enterCount--;
437         return childPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs[outputIdx].desc;
438     }
439
440     for (size_t i = 0; i < childPtr->getParentEdges().size(); i++) {
441         auto parentEdge = childPtr->getParentEdgeAt(i);
442         auto parent = parentEdge->getParent();
443         int parentIdx = parentEdge->getInputNum();
444         if (!parent->getSelectedPrimitiveDescriptor() || parentIdx < 0 ||
445                 parentEdge->getDims().ndims() != getDims().ndims()) {
446             continue;
447         }
448         if (parent->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size() <= parentIdx) {
449             parentIdx = 0;
450         }
451         memory::format parentOutDesc = MKLDNNMemoryDesc(parent->getSelectedPrimitiveDescriptor()->getConfig().outConfs[parentIdx].desc).getFormat();
452         if (parentOutDesc != memory::any && parentOutDesc != memory::format_undef) {
453             if (formats.find(parentOutDesc) == formats.end())
454                 formats[parentOutDesc] = 1;
455             else
456                 formats[parentOutDesc] += 1;
457             continue;
458         }
459         if (nodeCanChangeDesc(parent))
460             continue;
461
462         if (enterCount < 2) {
463             parentOutDesc = MKLDNNMemoryDesc(parentEdge->getSpecifiedInputDesc(formats)).getFormat();
464             if (parentOutDesc != memory::any && parentOutDesc != memory::format_undef) {
465                 if (formats.find(parentOutDesc) == formats.end())
466                     formats[parentOutDesc] = 1;
467                 else
468                     formats[parentOutDesc] += 1;
469             }
470         }
471     }
472
473     if (!isFormatChanging) {
474         for (size_t i = 0; i < childPtr->getChildEdges().size(); i++) {
475             auto childEdge = childPtr->getChildEdgeAt(i);
476             auto child = childEdge->getChild();
477             int childIdx = childEdge->getOutputNum();
478             if (!child->getSelectedPrimitiveDescriptor() || childIdx < 0 ||
479                     childEdge->getDims().ndims() != getDims().ndims()) {
480                 continue;
481             }
482             if (child->getSelectedPrimitiveDescriptor()->getConfig().inConfs.size() <= childIdx) {
483                 childIdx = 0;
484             }
485             memory::format childInDesc = MKLDNNMemoryDesc(child->getSelectedPrimitiveDescriptor()->getConfig().inConfs[childIdx].desc).getFormat();
486             if (childInDesc != memory::any && childInDesc != memory::format_undef) {
487                 if (formats.find(childInDesc) == formats.end())
488                     formats[childInDesc] = 1;
489                 else
490                     formats[childInDesc] += 1;
491                 continue;
492             }
493             if (nodeCanChangeDesc(child))
494                 continue;
495
496             if (enterCount < 2) {
497                 childInDesc = MKLDNNMemoryDesc(childEdge->getSpecifiedOutputDesc(formats)).getFormat();
498                 if (childInDesc != memory::any && childInDesc != memory::format_undef) {
499                     if (formats.find(childInDesc) == formats.end())
500                         formats[childInDesc] = 1;
501                     else
502                         formats[childInDesc] += 1;
503                 }
504             }
505         }
506     }
507
508     size_t maxFormatCount = 0;
509     memory::format format =  MKLDNNMemory::GetPlainFormat(getDims());
510     for (auto &it : formats) {
511         if (maxFormatCount < it.second && MKLDNNMemory::isConsistant(getDims(), it.first)) {
512             maxFormatCount = it.second;
513             format = it.first;
514         }
515     }
516
517     auto inDataType = MKLDNNMemoryDesc(childPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs[getOutputNum()].desc).getDataType();
518     childPtr->getSelectedPrimitiveDescriptor()->getConfig().inConfs[outputIdx].desc = MKLDNNMemoryDesc(getDims(), inDataType, format);
519     if (!isFormatChanging && outputIdx < childPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size() &&
520             childPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[outputIdx].desc.getLayout() == InferenceEngine::Layout::ANY) {
521         childPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[outputIdx].desc =
522                 MKLDNNExtensionUtils::getUninitTensorDesc(MKLDNNMemoryDesc(getDims(), inDataType, format));
523     }
524
525     enterCount--;
526     return childPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[outputIdx].desc;
527 }
528
529 const MKLDNNMemory &MKLDNNEdge::getMemory() {
530     if (status == Status::NotAllocated) {
531         memoryPtr.reset(new MKLDNNMemory(getParent()->getEngine()));
532         memoryPtr->Create(MKLDNNMemoryDesc(getDesc()), getSharedEdge()->getMemoryPtr()->GetData());
533         memoryFromEdge.reset();
534         changeStatus(Status::Allocated);
535     }
536
537     return *memoryPtr;
538 }
539
540 MKLDNNMemoryPtr &MKLDNNEdge::getMemoryPtr() {
541     if (status == Status::NotAllocated) {
542         memoryPtr.reset(new MKLDNNMemory(getParent()->getEngine()));
543         memoryPtr->Create(MKLDNNMemoryDesc(getDesc()), getSharedEdge()->getMemoryPtr()->GetData());
544         memoryFromEdge.reset();
545         changeStatus(Status::Allocated);
546     }
547
548     return memoryPtr;
549 }
550
551 InferenceEngine::Blob::Ptr MKLDNNEdge::getBlob() {
552     if (!memoryPtr || !dims.ndims())
553         THROW_IE_EXCEPTION << "Cannot get blob! Edge isn't initialized.";
554     InferenceEngine::TensorDesc desc = getDesc();
555
556     if (desc.getLayout() == InferenceEngine::Layout::ANY)
557         desc = InferenceEngine::TensorDesc(desc.getPrecision(), dims.ToSizeVector(), desc.getLayout());
558     else
559         desc = InferenceEngine::TensorDesc(desc.getPrecision(), dims.ToSizeVector(), desc.getBlockingDesc());
560
561     return make_blob_with_precision(desc, memoryPtr->GetData());
562 }
563
564 void MKLDNNEdge::sharedMemFrom(const MKLDNNEdgePtr &edge) {
565     memoryFromEdge = edge;
566     status = Status::NotAllocated;
567 }
568
569 void MKLDNNEdge::validate() {
570     if (status == Status::Validated)
571         return;
572     getMemory();
573     getParent();
574     getChild();
575     getDims();
576     if (status != Status::Allocated) {
577         THROW_IE_EXCEPTION << "Error memory is not allocated!";
578     }
579     status = Status::Validated;
580 }
581
582 MKLDNNEdgePtr MKLDNNEdge::getSharedEdge() const {
583     auto memoryFromEdgePtr = memoryFromEdge.lock();
584     if (!memoryFromEdgePtr) {
585         THROW_IE_EXCEPTION << "Cannot get memory ptr for edge(" << getParent()->getName() << "->"
586                            << getChild()->getName() << "). The pointer on the edge with memory is empty!";
587     }
588     return memoryFromEdgePtr;
589 }
590
591 void MKLDNNEdge::init() {
592     if (status != Status::NeedAllocation && status != Status::Uninitialized)
593         return;
594     MKLDNNEdgePtr edgePtr = getBaseEdge();
595     if (edgePtr.get() == this) {
596         changeStatus(Status::NeedAllocation);
597         auto port = getInputNum();
598         if (port < 0)
599             return;
600         auto edges_at_same_port = getParent()->getChildEdgesAtPort(static_cast<size_t>(port));
601         if (!edges_at_same_port.empty() &&
602             edgePtr != edges_at_same_port[0]) {
603             sharedMemFrom(edges_at_same_port[0]);
604         }
605     } else {
606         sharedMemFrom(edgePtr);
607         auto port = getInputNum();
608         if (port < 0)
609             return;
610         auto edges_at_same_port = getParent()->getChildEdgesAtPort(static_cast<size_t>(port));
611         for (auto edge : edges_at_same_port) {
612             if (edge->getStatus() != Status::NeedAllocation && edge->getStatus() != Status::Uninitialized) {
613                 if (edge->getSharedEdge() != edgePtr)
614                     THROW_IE_EXCEPTION << "Unsupported behavior. Cannot mark edge "
615                                        << getParent()->getChildEdgeAt(0)->getParent()->getName() << "->"
616                                        << getParent()->getChildEdgeAt(0)->getChild()->getName() << " as not allocated!";
617             } else {
618                 if (edge != edgePtr)
619                     edge->sharedMemFrom(edgePtr);
620             }
621         }
622     }
623 }
624
625 /**
626  * Should analyze graph node dependencies, inplace node information and return root memory(edge) it view on
627  *
628  * @param type some magic enum values... description needed
629  * @return root of view-on-memory subgraph
630  */
631 MKLDNNEdgePtr MKLDNNEdge::getBaseEdge(int look) {
632     auto parentConfig = getParent()->getSelectedPrimitiveDescriptor()->getConfig();
633     auto childConfig = getChild()->getSelectedPrimitiveDescriptor()->getConfig();
634     int inputNum = getInputNum();
635     int outputNum = getOutputNum();
636
637     if (childConfig.inConfs[outputNum].inPlace >= 0 && parentConfig.outConfs[inputNum].inPlace >= 0) {
638         inputNum = getInputNum();
639         return getParent()->getChildEdgeAt(inputNum);
640     }
641
642     if (childConfig.inConfs[outputNum].inPlace >= 0 && (look & LOOK_DOWN)) {
643         int next_port_idx = childConfig.inConfs[outputNum].inPlace;
644         if (childConfig.outConfs[next_port_idx].inPlace >= 0) {
645             childConfig.outConfs[next_port_idx].inPlace = -1;
646             getChild()->initDescriptor(childConfig);
647         }
648
649         auto ch_edges = getChild()->getChildEdgesAtPort(next_port_idx);
650         auto &next_ch_edge = ch_edges[0];
651
652         // Multiple connection to some out port
653         // Will try to find inplace consumer
654         for (auto &ch_edge : ch_edges) {
655             auto &chch_conf = ch_edge->getChild()->getSelectedPrimitiveDescriptor()->getConfig();
656
657             if (chch_conf.inConfs[ch_edge->getOutputNum()].inPlace >= 0)
658                 next_ch_edge = ch_edge;
659         }
660         return next_ch_edge->getBaseEdge(LOOK_DOWN);
661     } else if (parentConfig.outConfs[inputNum].inPlace >= 0 && (look & LOOK_UP)) {
662         int next_port_idx = parentConfig.outConfs[inputNum].inPlace;
663         if (parentConfig.inConfs[next_port_idx].inPlace >= 0) {
664             parentConfig.inConfs[next_port_idx].inPlace = -1;
665             getParent()->initDescriptor(parentConfig);
666         }
667         return getParent()->getParentEdgesAtPort(next_port_idx)[0]->getBaseEdge(LOOK_UP);
668     }
669
670     auto edges_for_same_port = getParent()->getChildEdgesAtPort(inputNum);
671     if (!(look & LOOK_NO_RECURRENT)) {
672         for (auto edge : edges_for_same_port) {
673             if (edge.get() != this) {
674                 auto base = edge->getBaseEdge(LOOK_BOTH | LOOK_NO_RECURRENT);
675                 if (base != edge) return base;
676             }
677         }
678     }
679     return edges_for_same_port[0];
680 }
681
682 bool MKLDNNEdge::inPlace(LOOK look) {
683     auto parentSPD = getParent()->getSelectedPrimitiveDescriptor();
684     auto childSPD = getChild()->getSelectedPrimitiveDescriptor();
685     if (!parentSPD || !childSPD)
686         THROW_IE_EXCEPTION << "Cannot make a decision about reorder. Primitive descriptors weren't selected.";
687     int inputNum = getInputNum();
688     int outputNum = getOutputNum();
689     if (inputNum >= parentSPD->getConfig().outConfs.size())
690         inputNum = 0;
691     if (outputNum >= childSPD->getConfig().inConfs.size())
692         outputNum = 0;
693
694     if (look & LOOK_UP) {
695         if (parentSPD->getConfig().outConfs[inputNum].inPlace >= 0)
696             return true;
697     }
698     if (look & LOOK_DOWN) {
699         if (childSPD->getConfig().inConfs[outputNum].inPlace >= 0)
700             return true;
701     }
702     return false;
703 }
704
705 }  // namespace MKLDNNPlugin