Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_memory_node.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_common.h>
8 #include "ie_algorithm.hpp"
9 #include "mkldnn_input_node.h"
10 #include <mkldnn_node.h>
11 #include <string>
12 #include <memory>
13 #include <map>
14
15 namespace MKLDNNPlugin {
16
17 class MKLDNNMemoryNode {
18     std::string _id;
19  public:
20     explicit MKLDNNMemoryNode(std::string id) : _id(id) {}
21     explicit MKLDNNMemoryNode(InferenceEngine::CNNLayerPtr lp) {
22         if (lp->params.find("id") != lp->params.end()) {
23             _id = lp->GetParamAsString("id");
24         }
25     }
26     virtual ~MKLDNNMemoryNode() = default;
27     std::string getId() {
28         return _id;
29     }
30     virtual void setInputNode(MKLDNNNode *) = 0;
31 };
32 class MKLDNNMemoryOutputNode;
33 class MKLDNNMemoryInputNode;
34
35 /**
36  * @brief
37  * TODO: ATTENTION: this is a temporary solution, this connection should be keep in graph
38  */
39 class MKLDNNMemoryNodeVirtualEdge {
40     using Holder = std::map<std::string, MKLDNNMemoryNode*>;
41     static Holder & getExisted() {
42         static Holder existed;
43         return existed;
44     }
45
46     static MKLDNNMemoryNode * getByName(std::string name) {
47         auto result = getExisted().find(name);
48         if (result != getExisted().end()) {
49             return result->second;
50         }
51         return nullptr;
52     }
53
54  public:
55     static void registerOutput(MKLDNNMemoryOutputNode * node);
56     static void registerInput(MKLDNNMemoryInputNode * node);
57     static void remove(MKLDNNMemoryNode * node) {
58         InferenceEngine::details::erase_if(getExisted(), [&](const Holder::value_type & it){
59             return it.second == node;
60         });
61         // std::cout <<"[remove]   " << node << ", size="<< getExisted().size() <<"\n" << std::flush;
62     }
63 };
64
65 class MKLDNNMemoryOutputNode : public MKLDNNNode, public MKLDNNMemoryNode {
66  public:
67     MKLDNNMemoryOutputNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng);
68     ~MKLDNNMemoryOutputNode() override;
69     void getSupportedDescriptors() override;
70     void initSupportedPrimitiveDescriptors() override;
71     const MKLDNNEdgePtr getChildEdgeAt(size_t idx) const override;
72     void createPrimitive() override {}
73     void execute(mkldnn::stream strm) override;
74     bool created() const override {
75         return getType() == MemoryOutput;
76     }
77
78     void setInputNode(MKLDNNNode* node) override {
79         inputNode = node;
80     }
81  private:
82     /**
83      * @brief keeps reference to input sibling node
84      */
85     MKLDNNNode* inputNode = nullptr;
86     static Register<MKLDNNMemoryOutputNode> reg;
87 };
88
89
90 class MKLDNNMemoryInputNode : public MKLDNNInputNode, public MKLDNNMemoryNode {
91  protected:
92     static std::string nameFromCombinedName(std::string name);
93     static std::string idFromCombinedName(std::string name);
94  public:
95     MKLDNNMemoryInputNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng);
96     ~MKLDNNMemoryInputNode() override;
97
98     bool created() const override {
99         return getType() == MemoryInput;
100     }
101
102     void setInputNode(MKLDNNNode* node) override {}
103  private:
104     static Register<MKLDNNMemoryInputNode> reg;
105 };
106
107
108
109 }  // namespace MKLDNNPlugin
110