Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_edge.cpp
index 92c8c5a..7d13d01 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
 #include <blob_factory.hpp>
 
 using namespace mkldnn;
-using namespace MKLDNNPlugin;
+namespace MKLDNNPlugin {
 
-MKLDNNPlugin::MKLDNNEdge::MKLDNNEdge(const std::shared_ptr<MKLDNNPlugin::MKLDNNNode> &parent,
-                                     const std::shared_ptr<MKLDNNPlugin::MKLDNNNode> &child) {
-    this->parent = parent;
-    this->child = child;
-}
+MKLDNNEdge::MKLDNNEdge(const MKLDNNNodePtr &parent, const MKLDNNNodePtr &child, int pr_port, int ch_port) :
+        parent(parent), child(child), parent_port(pr_port), child_port(ch_port) {}
 
-const std::shared_ptr<MKLDNNPlugin::MKLDNNNode> MKLDNNPlugin::MKLDNNEdge::getParent() const {
+const MKLDNNNodePtr MKLDNNEdge::getParent() const {
     auto parentPtr = parent.lock();
     if (!parentPtr)
         THROW_IE_EXCEPTION << "Edge contains empty parent node";
     return parentPtr;
 }
 
-const std::shared_ptr<MKLDNNPlugin::MKLDNNNode> MKLDNNPlugin::MKLDNNEdge::getChild() const {
+const MKLDNNNodePtr MKLDNNEdge::getChild() const {
     auto childPtr = child.lock();
     if (!childPtr)
         THROW_IE_EXCEPTION << "Edge contains empty child node";
     return childPtr;
 }
 
-bool MKLDNNPlugin::MKLDNNEdge::isDropped() {
-    return getInputNum() == -1 && getOutputNum() == -1;
+bool MKLDNNEdge::isDropped() {
+    bool not_in_parent = true;
+    bool not_in_child = true;
+
+    auto parent_ptr = parent.lock();
+    if (parent_ptr) {
+        for (auto &edge : parent_ptr->childEdges)
+            if (edge.lock().get() == this)
+                not_in_parent = false;
+    }
+
+    auto child_ptr = child.lock();
+    if (child_ptr) {
+        for (auto &edge : child_ptr->parentEdges)
+            if (edge.lock().get() == this)
+                not_in_child = false;
+    }
+    return not_in_parent && not_in_child;
 }
 
-bool MKLDNNPlugin::MKLDNNEdge::needReorder() {
+void MKLDNNEdge::drop() {
+    auto _drop_from = [&] (std::vector<MKLDNNEdgeWeakPtr> &list) {
+        auto myself = std::find_if(list.begin(), list.end(),
+                [&] (MKLDNNEdgeWeakPtr edge) { return edge.lock().get() == this; });
+
+        if (myself != list.end())
+            list.erase(myself);
+    };
+
+    _drop_from(getParent()->childEdges);
+    _drop_from(getChild()->parentEdges);
+}
+
+
+bool MKLDNNEdge::needReorder() {
     bool canBeInPlaceConflicts = false;
     auto parentSPD = getParent()->getSelectedPrimitiveDescriptor();
     auto childSPD = getChild()->getSelectedPrimitiveDescriptor();
     if (!parentSPD || !childSPD)
         THROW_IE_EXCEPTION << "Cannot make a decision about reorder. Primitive descriptors weren't selected.";
 
-    int inputNum = getInputNum();
+    int outNumber = getOutputNum();
+    int inNumber = getInputNum();
     bool in_place = inPlace();
-    if (in_place && !getParent()->getChildEdges().empty()) {
-        for (size_t i = 0; i < getParent()->getChildEdges().size(); i++) {
-            if (i == inputNum)
+    bool childCanChangeMem = childSPD->getConfig().outConfs.empty();
+    for (const auto conf : childSPD->getConfig().outConfs) {
+        if (conf.inPlace == outNumber && outNumber >= 0)
+            childCanChangeMem = true;
+    }
+
+    const auto& detectInPlaceChildsNum = [](const std::vector<MKLDNNEdgePtr>& edges) -> size_t {
+        size_t count = 0;
+        for (const auto& edge : edges) {
+            auto childSPD = edge->getChild()->getSelectedPrimitiveDescriptor();
+            int outNumber = edge->getOutputNum();
+            if (childSPD->getConfig().outConfs.empty())
+                count++;
+            for (const auto conf : childSPD->getConfig().outConfs) {
+                if (conf.inPlace == outNumber)
+                    count++;
+            }
+        }
+        return count;
+    };
+
+    const auto portChildEdges = getParent()->getChildEdgesAtPort(inNumber);
+    if (in_place && detectInPlaceChildsNum(portChildEdges) > 1 && childCanChangeMem)
+        canBeInPlaceConflicts = true;
+    if (!canBeInPlaceConflicts && in_place && !getParent()->getChildEdges().empty()) {
+        for (auto &p_edge_peer : portChildEdges) {
+            if (p_edge_peer.get() == this)
                 continue;
-            if (getParent()->getChildEdgeAt(i)->getChild()->getType() != Reorder && getParent()->getChildEdgeAt(i)->inPlace(LOOK_DOWN))
+            if (p_edge_peer->getChild()->getType() != Reorder && p_edge_peer->inPlace(LOOK_DOWN))
                 canBeInPlaceConflicts = true;
         }
     }
 
     if (in_place) {
-        int outNumber = getOutputNum();
-        int inNumber = getInputNum();
         if (inNumber >= 0 && inNumber < parentSPD->getConfig().outConfs.size() && parentSPD->getConfig().outConfs[inNumber].inPlace >= 0 &&
             outNumber >= 0 && outNumber < childSPD->getConfig().inConfs.size() && childSPD->getConfig().inConfs[outNumber].inPlace >= 0)
             canBeInPlaceConflicts = true;
     }
-    return !MKLDNNExtensionUtils::initTensorsAreEqual(getInputDesc(), getOutputDesc()) || canBeInPlaceConflicts;
+    return canBeInPlaceConflicts || !MKLDNNExtensionUtils::initTensorsAreEqual(getInputDesc(), getOutputDesc());
 }
 
-InferenceEngine::TensorDesc MKLDNNPlugin::MKLDNNEdge::getInputDesc() {
+InferenceEngine::TensorDesc MKLDNNEdge::getInputDesc() {
     if (inputDesc.getLayout() == InferenceEngine::Layout::ANY) {
         inputDesc = getSpecifiedInputDesc({});
     }
     return inputDesc;
 }
 
-InferenceEngine::TensorDesc MKLDNNPlugin::MKLDNNEdge::getOutputDesc() {
+InferenceEngine::TensorDesc MKLDNNEdge::getOutputDesc() {
     if (outputDesc.getLayout() == InferenceEngine::Layout::ANY) {
         outputDesc = getSpecifiedOutputDesc({});
     }
     return outputDesc;
 }
 
-InferenceEngine::TensorDesc MKLDNNPlugin::MKLDNNEdge::getDesc() {
+InferenceEngine::TensorDesc MKLDNNEdge::getDesc() {
     if (!MKLDNNExtensionUtils::initTensorsAreEqual(getInputDesc(), getOutputDesc()))
         THROW_IE_EXCEPTION << "Cannot get descriptor for edge: " << getParent()->getName() << "->"
                            << getChild()->getName();
     return getInputDesc();
 }
 
-int MKLDNNPlugin::MKLDNNEdge::getInputNum() {
-    return getAllInputNums()[0];
-}
-
-std::vector<int> MKLDNNPlugin::MKLDNNEdge::getAllInputNums() {
-    auto parentPtr = parent.lock();
-    if (!parentPtr)
-        return {-1};
-
-    std::vector<int> res;
-    for (size_t i = 0; i < parentPtr->getChildEdges().size(); i++) {
-        auto childEdge = parentPtr->getChildEdges()[i].lock();
-        if (childEdge && childEdge.get() == this) {
-            res.push_back(static_cast<int>(i));
-        }
-    }
-    return res.empty() ? std::vector<int>{-1} : res;
+int MKLDNNEdge::getInputNum() {
+    return parent_port;
 }
 
-int MKLDNNPlugin::MKLDNNEdge::getOutputNum() {
-    return getAllOutputNums()[0];
+int MKLDNNEdge::getOutputNum() {
+    return child_port;
 }
 
-std::vector<int> MKLDNNPlugin::MKLDNNEdge::getAllOutputNums() {
-    auto childPtr = child.lock();
-    if (!childPtr)
-        return {-1};
-
-    std::vector<int> res;
-    for (size_t i = 0; i < childPtr->getParentEdges().size(); i++) {
-        auto parentEdge = childPtr->getParentEdges()[i].lock();
-        if (parentEdge && parentEdge.get() == this) {
-            res.push_back(static_cast<int>(i));
-        }
-    }
-    return res.empty() ? std::vector<int>{-1} : res;
-}
-
-void MKLDNNPlugin::MKLDNNEdge::allocate(const void* mem_ptr) {
+void MKLDNNEdge::allocate(const void* mem_ptr) {
     if (status != Status::NeedAllocation)
         return;
 
@@ -142,7 +162,7 @@ void MKLDNNPlugin::MKLDNNEdge::allocate(const void* mem_ptr) {
     status = Status::Allocated;
 }
 
-void MKLDNNPlugin::MKLDNNEdge::changeStatus(MKLDNNPlugin::MKLDNNEdge::Status state) {
+void MKLDNNEdge::changeStatus(MKLDNNEdge::Status state) {
     if (state == Status::NotAllocated) {
         THROW_IE_EXCEPTION << "Incorrect behaviour! Use method sharedMemFrom()";
     }
@@ -156,7 +176,7 @@ void MKLDNNPlugin::MKLDNNEdge::changeStatus(MKLDNNPlugin::MKLDNNEdge::Status sta
     status = state;
 }
 
-MKLDNNPlugin::MKLDNNDims &MKLDNNPlugin::MKLDNNEdge::getDims() {
+const MKLDNNDims& MKLDNNEdge::getDims() {
     if (!dims.ndims()) {
         MKLDNNDims outDims;
         MKLDNNDims inDims;
@@ -196,11 +216,7 @@ MKLDNNPlugin::MKLDNNDims &MKLDNNPlugin::MKLDNNEdge::getDims() {
     return dims;
 }
 
-void MKLDNNPlugin::MKLDNNEdge::setDims(MKLDNNPlugin::MKLDNNDims &dims) {
-    this->dims = dims;
-}
-
-bool MKLDNNPlugin::MKLDNNEdge::nodeCanChangeDesc(const std::shared_ptr<MKLDNNPlugin::MKLDNNNode> &node) const {
+bool MKLDNNEdge::nodeCanChangeDesc(const MKLDNNNodePtr &node) const {
     PrimitiveDescInfo * selectedPd = node->getSelectedPrimitiveDescriptor();
     if (selectedPd == nullptr)
         THROW_IE_EXCEPTION << "Primitive descriptor for node " << node->getName() << " is not selected.";
@@ -245,7 +261,7 @@ bool MKLDNNPlugin::MKLDNNEdge::nodeCanChangeDesc(const std::shared_ptr<MKLDNNPlu
 /// In we have {any, any, any} -> {any} or {any} -> {any, any, any} or {any} -> {any} it means that
 /// layer doesn't change memory format
 /// We don't support {any, any, nchw} -> {any}
-InferenceEngine::TensorDesc MKLDNNPlugin::MKLDNNEdge::getSpecifiedInputDesc(std::map<mkldnn::memory::format, size_t> formats) {
+InferenceEngine::TensorDesc MKLDNNEdge::getSpecifiedInputDesc(std::map<mkldnn::memory::format, size_t> formats) {
     InferenceEngine::TensorDesc inDesc;
     static int enterCount = 0;
     enterCount++;
@@ -370,7 +386,7 @@ InferenceEngine::TensorDesc MKLDNNPlugin::MKLDNNEdge::getSpecifiedInputDesc(std:
     return MKLDNNMemoryDesc(getDims(), inDataType, desc);
 }
 
-InferenceEngine::TensorDesc MKLDNNPlugin::MKLDNNEdge::getSpecifiedOutputDesc(std::map<mkldnn::memory::format, size_t> formats) {
+InferenceEngine::TensorDesc MKLDNNEdge::getSpecifiedOutputDesc(std::map<mkldnn::memory::format, size_t> formats) {
     static int enterCount = 0;
     enterCount++;
     InferenceEngine::TensorDesc outDesc;
@@ -510,7 +526,7 @@ InferenceEngine::TensorDesc MKLDNNPlugin::MKLDNNEdge::getSpecifiedOutputDesc(std
     return childPtr->getSelectedPrimitiveDescriptor()->getConfig().outConfs[outputIdx].desc;
 }
 
-const MKLDNNPlugin::MKLDNNMemory &MKLDNNPlugin::MKLDNNEdge::getMemory() {
+const MKLDNNMemory &MKLDNNEdge::getMemory() {
     if (status == Status::NotAllocated) {
         memoryPtr.reset(new MKLDNNMemory(getParent()->getEngine()));
         memoryPtr->Create(MKLDNNMemoryDesc(getDesc()), getSharedEdge()->getMemoryPtr()->GetData());
@@ -521,7 +537,7 @@ const MKLDNNPlugin::MKLDNNMemory &MKLDNNPlugin::MKLDNNEdge::getMemory() {
     return *memoryPtr;
 }
 
-MKLDNNPlugin::MKLDNNMemoryPtr &MKLDNNPlugin::MKLDNNEdge::getMemoryPtr() {
+MKLDNNMemoryPtr &MKLDNNEdge::getMemoryPtr() {
     if (status == Status::NotAllocated) {
         memoryPtr.reset(new MKLDNNMemory(getParent()->getEngine()));
         memoryPtr->Create(MKLDNNMemoryDesc(getDesc()), getSharedEdge()->getMemoryPtr()->GetData());
@@ -545,12 +561,12 @@ InferenceEngine::Blob::Ptr MKLDNNEdge::getBlob() {
     return make_blob_with_precision(desc, memoryPtr->GetData());
 }
 
-void MKLDNNPlugin::MKLDNNEdge::sharedMemFrom(const MKLDNNPlugin::MKLDNNEdgePtr &edge) {
+void MKLDNNEdge::sharedMemFrom(const MKLDNNEdgePtr &edge) {
     memoryFromEdge = edge;
     status = Status::NotAllocated;
 }
 
-void MKLDNNPlugin::MKLDNNEdge::validate() {
+void MKLDNNEdge::validate() {
     if (status == Status::Validated)
         return;
     getMemory();
@@ -563,7 +579,7 @@ void MKLDNNPlugin::MKLDNNEdge::validate() {
     status = Status::Validated;
 }
 
-MKLDNNPlugin::MKLDNNEdgePtr MKLDNNPlugin::MKLDNNEdge::getSharedEdge() const {
+MKLDNNEdgePtr MKLDNNEdge::getSharedEdge() const {
     auto memoryFromEdgePtr = memoryFromEdge.lock();
     if (!memoryFromEdgePtr) {
         THROW_IE_EXCEPTION << "Cannot get memory ptr for edge(" << getParent()->getName() << "->"
@@ -578,44 +594,45 @@ void MKLDNNEdge::init() {
     MKLDNNEdgePtr edgePtr = getBaseEdge();
     if (edgePtr.get() == this) {
         changeStatus(Status::NeedAllocation);
-        if (getInputNum() > 0 && getParent()->getSelectedPrimitiveDescriptor() &&
-            getParent()->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size() <= getInputNum() &&
-            edgePtr != getParent()->getChildEdgeAt(0)) {
-            sharedMemFrom(getParent()->getChildEdgeAt(0));
+        auto port = getInputNum();
+        if (port < 0)
+            return;
+        auto edges_at_same_port = getParent()->getChildEdgesAtPort(static_cast<size_t>(port));
+        if (!edges_at_same_port.empty() &&
+            edgePtr != edges_at_same_port[0]) {
+            sharedMemFrom(edges_at_same_port[0]);
         }
     } else {
         sharedMemFrom(edgePtr);
-        if (getInputNum() > 0 && getParent()->getSelectedPrimitiveDescriptor() &&
-                getParent()->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size() <= getInputNum() &&
-                edgePtr != getParent()->getChildEdgeAt(0)) {
-            if (getParent()->getChildEdgeAt(0)->getStatus() != Status::NeedAllocation &&
-                    getParent()->getChildEdgeAt(0)->getStatus() != Status::Uninitialized) {
-                if (getParent()->getChildEdgeAt(0)->getSharedEdge() != edgePtr)
+        auto port = getInputNum();
+        if (port < 0)
+            return;
+        auto edges_at_same_port = getParent()->getChildEdgesAtPort(static_cast<size_t>(port));
+        for (auto edge : edges_at_same_port) {
+            if (edge->getStatus() != Status::NeedAllocation && edge->getStatus() != Status::Uninitialized) {
+                if (edge->getSharedEdge() != edgePtr)
                     THROW_IE_EXCEPTION << "Unsupported behavior. Cannot mark edge "
                                        << getParent()->getChildEdgeAt(0)->getParent()->getName() << "->"
                                        << getParent()->getChildEdgeAt(0)->getChild()->getName() << " as not allocated!";
             } else {
-                getParent()->getChildEdgeAt(0)->sharedMemFrom(edgePtr);
+                if (edge != edgePtr)
+                    edge->sharedMemFrom(edgePtr);
             }
         }
     }
 }
 
 /**
- * Should analize graph node dependensies, inplace node information and return root memory(edge) it view on
+ * Should analyze graph node dependencies, inplace node information and return root memory(edge) it view on
  *
  * @param type some magic enum values... description needed
  * @return root of view-on-memory subgraph
  */
-MKLDNNEdgePtr MKLDNNEdge::getBaseEdge(LOOK look) {
+MKLDNNEdgePtr MKLDNNEdge::getBaseEdge(int look) {
     auto parentConfig = getParent()->getSelectedPrimitiveDescriptor()->getConfig();
     auto childConfig = getChild()->getSelectedPrimitiveDescriptor()->getConfig();
     int inputNum = getInputNum();
     int outputNum = getOutputNum();
-    if (inputNum >= parentConfig.outConfs.size())
-        inputNum = 0;
-    if (outputNum >= childConfig.inConfs.size())
-        outputNum = 0;
 
     if (childConfig.inConfs[outputNum].inPlace >= 0 && parentConfig.outConfs[inputNum].inPlace >= 0) {
         inputNum = getInputNum();
@@ -623,37 +640,43 @@ MKLDNNEdgePtr MKLDNNEdge::getBaseEdge(LOOK look) {
     }
 
     if (childConfig.inConfs[outputNum].inPlace >= 0 && (look & LOOK_DOWN)) {
-        int next_edge_ind = childConfig.inConfs[outputNum].inPlace;
-        if (childConfig.outConfs[next_edge_ind].inPlace >= 0) {
-            childConfig.outConfs[next_edge_ind].inPlace = -1;
+        int next_port_idx = childConfig.inConfs[outputNum].inPlace;
+        if (childConfig.outConfs[next_port_idx].inPlace >= 0) {
+            childConfig.outConfs[next_port_idx].inPlace = -1;
             getChild()->initDescriptor(childConfig);
         }
 
-        // this is a WA ... :-(
-        if (childConfig.outConfs.size() <= getChild()->getChildEdges().size()) {
-            // Multiple connection to some out port.
-            // Will try to find implace consumer.
-            for (int i = 0; i< getChild()->getChildEdges().size(); i++) {
-                auto chch_edge = getChild()->getChildEdgeAt(i);
-                auto chch_conf = chch_edge->getChild()->getSelectedPrimitiveDescriptor()->getConfig();
+        auto ch_edges = getChild()->getChildEdgesAtPort(next_port_idx);
+        auto &next_ch_edge = ch_edges[0];
 
+        // Multiple connection to some out port
+        // Will try to find inplace consumer
+        for (auto &ch_edge : ch_edges) {
+            auto &chch_conf = ch_edge->getChild()->getSelectedPrimitiveDescriptor()->getConfig();
 
-                if (chch_conf.inConfs[chch_edge->getOutputNum()].inPlace >= 0) {
-                    next_edge_ind = i;
-                }
-            }
+            if (chch_conf.inConfs[ch_edge->getOutputNum()].inPlace >= 0)
+                next_ch_edge = ch_edge;
         }
-        return getChild()->getChildEdgeAt(next_edge_ind)->getBaseEdge(LOOK_DOWN);
+        return next_ch_edge->getBaseEdge(LOOK_DOWN);
     } else if (parentConfig.outConfs[inputNum].inPlace >= 0 && (look & LOOK_UP)) {
-        if (parentConfig.inConfs[parentConfig.outConfs[inputNum].inPlace].inPlace >= 0) {
-            parentConfig.inConfs[parentConfig.outConfs[inputNum].inPlace].inPlace = -1;
+        int next_port_idx = parentConfig.outConfs[inputNum].inPlace;
+        if (parentConfig.inConfs[next_port_idx].inPlace >= 0) {
+            parentConfig.inConfs[next_port_idx].inPlace = -1;
             getParent()->initDescriptor(parentConfig);
         }
-        return getParent()->getParentEdgeAt(parentConfig.outConfs[inputNum].inPlace)->getBaseEdge(LOOK_UP);
+        return getParent()->getParentEdgesAtPort(next_port_idx)[0]->getBaseEdge(LOOK_UP);
     }
 
-    inputNum = getInputNum();
-    return getParent()->getChildEdgeAt(inputNum);
+    auto edges_for_same_port = getParent()->getChildEdgesAtPort(inputNum);
+    if (!(look & LOOK_NO_RECURRENT)) {
+        for (auto edge : edges_for_same_port) {
+            if (edge.get() != this) {
+                auto base = edge->getBaseEdge(LOOK_BOTH | LOOK_NO_RECURRENT);
+                if (base != edge) return base;
+            }
+        }
+    }
+    return edges_for_same_port[0];
 }
 
 bool MKLDNNEdge::inPlace(LOOK look) {
@@ -671,18 +694,12 @@ bool MKLDNNEdge::inPlace(LOOK look) {
     if (look & LOOK_UP) {
         if (parentSPD->getConfig().outConfs[inputNum].inPlace >= 0)
             return true;
-        for (const auto &inConf : parentSPD->getConfig().inConfs) {
-            if (inConf.inPlace == inputNum)
-                return true;
-        }
     }
     if (look & LOOK_DOWN) {
         if (childSPD->getConfig().inConfs[outputNum].inPlace >= 0)
             return true;
-        for (const auto &outConf : childSPD->getConfig().outConfs) {
-            if (outConf.inPlace == inputNum)
-                return true;
-        }
     }
     return false;
 }
+
+}  // namespace MKLDNNPlugin