1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
8 #include "ie_algorithm.hpp"
9 #include "mkldnn_input_node.h"
10 #include <mkldnn_node.h>
15 namespace MKLDNNPlugin {
/**
 * @brief Common base for the memory input/output nodes: carries the memory
 * identifier `_id` and requires every concrete memory node to implement
 * setInputNode().
 */
17 class MKLDNNMemoryNode {
// Builds the node from an explicit identifier, which is copied into `_id`.
// NOTE(review): `id` is already a by-value parameter; `_id(std::move(id))` would
// avoid a second copy (assumes `_id` is a std::string — its declaration is not
// visible in this excerpt).
20 explicit MKLDNNMemoryNode(std::string id) : _id(id) {}
// Builds the node from a layer and assigns `_id` from the layer's "id" parameter
// when the layer has that parameter.
// NOTE(review): `lp` is dereferenced without a null check, so callers are assumed
// to pass a non-null layer. Confirm at the call sites.
// NOTE(review): `params` is searched twice (find here, then GetParamAsString).
// Harmless, but the iterator from find() could be reused.
21 explicit MKLDNNMemoryNode(InferenceEngine::CNNLayerPtr lp) {
22 if (lp->params.find("id") != lp->params.end()) {
23 _id = lp->GetParamAsString("id");
// Virtual, so derived memory nodes are destroyed correctly through a
// MKLDNNMemoryNode pointer.
26 virtual ~MKLDNNMemoryNode() = default;
// Pure virtual: each concrete memory node decides what to do with the sibling
// node passed here.
30 virtual void setInputNode(MKLDNNNode *) = 0;
32 class MKLDNNMemoryOutputNode;
33 class MKLDNNMemoryInputNode;
37 * TODO: ATTENTION: this is a temporary solution; this connection should be kept in the graph.
/**
 * @brief Name-keyed registry of memory nodes, backed by a single function-local
 * static std::map shared by every user of this class.
 *
 * Output and input memory nodes are added through registerOutput()/registerInput()
 * (defined out of line), presumably so that matching nodes can find each other
 * by id. Confirm in the implementation file.
 *
 * NOTE(review): no locking is visible around the shared map. Concurrent
 * register/remove calls from different threads would race, so confirm that graph
 * construction and teardown are serialized.
 */
39 class MKLDNNMemoryNodeVirtualEdge {
// Name -> registered node. The map does not own the nodes: remove() only erases
// entries and never deletes the pointed-to node.
40 using Holder = std::map<std::string, MKLDNNMemoryNode*>;
// Returns the one shared map. The function-local static is constructed on first
// use; since C++11 that initialization is thread-safe, but later accesses are not
// synchronized.
41 static Holder & getExisted() {
42 static Holder existed;
// Looks up `name` and returns the registered node pointer when it is present.
// The not-found result is presumably nullptr (verify).
// NOTE(review): `name` is taken by value; `const std::string &` would avoid a copy.
46 static MKLDNNMemoryNode * getByName(std::string name) {
47 auto result = getExisted().find(name);
48 if (result != getExisted().end()) {
49 return result->second;
// Defined out of line. Presumably they insert the node under a key derived from
// its id. Confirm in the implementation file.
55 static void registerOutput(MKLDNNMemoryOutputNode * node);
56 static void registerInput(MKLDNNMemoryInputNode * node);
// Erases every entry whose value is `node`, so a node registered under several
// keys is removed from all of them. The node object itself is not deleted.
57 static void remove(MKLDNNMemoryNode * node) {
58 InferenceEngine::details::erase_if(getExisted(), [&](const Holder::value_type & it){
59 return it.second == node;
61 // std::cout <<"[remove] " << node << ", size="<< getExisted().size() <<"\n" << std::flush;
/**
 * @brief Output side of a memory connection: a regular MKLDNNNode that also
 * carries a memory id through MKLDNNMemoryNode. Most overrides are defined out of
 * line.
 */
65 class MKLDNNMemoryOutputNode : public MKLDNNNode, public MKLDNNMemoryNode {
// Builds the node from a layer and an mkldnn engine (defined out of line).
67 MKLDNNMemoryOutputNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng);
// Defined out of line. It presumably unregisters this node from
// MKLDNNMemoryNodeVirtualEdge (verify in the implementation file).
68 ~MKLDNNMemoryOutputNode() override;
69 void getSupportedDescriptors() override;
70 void initSupportedPrimitiveDescriptors() override;
71 const MKLDNNEdgePtr getChildEdgeAt(size_t idx) const override;
// This override is empty: no primitive is created for this node here.
72 void createPrimitive() override {}
73 void execute(mkldnn::stream strm) override;
// Reports the node as created when its type is MemoryOutput.
74 bool created() const override {
75 return getType() == MemoryOutput;
// Receives the sibling node. It presumably stores `node` in `inputNode` (the body
// is not visible here, so confirm).
78 void setInputNode(MKLDNNNode* node) override {
// NOTE(review): `inputNode` is a raw pointer that stays nullptr until it is set.
// It is presumably non-owning, so confirm who owns the sibling node and that it
// outlives this one.
83 * @brief keeps reference to input sibling node
85 MKLDNNNode* inputNode = nullptr;
// Static registration object, presumably registering this node type with the
// plugin's node factory (confirm).
86 static Register<MKLDNNMemoryOutputNode> reg;
/**
 * @brief Input side of a memory connection: an MKLDNNInputNode that also carries
 * a memory id through MKLDNNMemoryNode.
 */
90 class MKLDNNMemoryInputNode : public MKLDNNInputNode, public MKLDNNMemoryNode {
// Static helpers defined out of line. Judging by their names, they extract the
// node-name part and the memory-id part of a combined name. The combined format
// is not visible here, so confirm in the implementation file.
92 static std::string nameFromCombinedName(std::string name);
93 static std::string idFromCombinedName(std::string name);
// Builds the node from a layer and an mkldnn engine (defined out of line).
95 MKLDNNMemoryInputNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng);
96 ~MKLDNNMemoryInputNode() override;
// Reports the node as created when its type is MemoryInput.
98 bool created() const override {
99 return getType() == MemoryInput;
// The input side ignores the sibling node: this override does nothing.
// NOTE(review): the named but unused parameter `node` may trigger
// -Wunused-parameter; writing `MKLDNNNode* /*node*/` would silence it.
102 void setInputNode(MKLDNNNode* node) override {}
// Static registration object, presumably registering this node type with the
// plugin's node factory (confirm).
104 static Register<MKLDNNMemoryInputNode> reg;
109 } // namespace MKLDNNPlugin