Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_edge.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_api.h>
8 #include <memory>
9 #include "mkldnn_memory.h"
10 #include "mkldnn_dims.h"
11 #include <map>
12 #include <vector>
13
14 namespace MKLDNNPlugin {
15
16 class MKLDNNNode;
17 class MKLDNNEdge;
18
19 using MKLDNNEdgePtr = std::shared_ptr<MKLDNNEdge>;
20 using MKLDNNEdgeWeakPtr = std::weak_ptr<MKLDNNEdge>;
21
22 class MKLDNNEdge : public InferenceEngine::details::no_copy {
23 public:
24     MKLDNNEdge(const std::shared_ptr<MKLDNNNode>& parent,
25                const std::shared_ptr<MKLDNNNode>& child,
26                int pr_port = 0, int ch_port = 0);
27
28     enum class Status {
29         Uninitialized,
30         NeedAllocation,
31         NotAllocated,
32         Allocated,
33         Validated
34     };
35
36     inline Status getStatus() const noexcept {
37         return status;
38     }
39
40     void changeStatus(Status state);
41
42     virtual void init();
43     virtual void allocate(const void* mem_ptr = nullptr);
44     virtual void validate();
45     void drop();
46
47     const std::shared_ptr<MKLDNNNode> getParent() const;
48     const std::shared_ptr<MKLDNNNode> getChild() const;
49
50     InferenceEngine::Blob::Ptr getBlob();
51     InferenceEngine::TensorDesc getDesc();
52
53     const MKLDNNDims &getDims();
54     const MKLDNNMemory& getMemory();
55     MKLDNNMemoryPtr& getMemoryPtr();
56
57     bool needReorder();
58     bool isDropped();
59
60     int getInputNum();
61     int getOutputNum();
62
63     void sharedMemFrom(const MKLDNNEdgePtr& edge);
64     MKLDNNEdgePtr getSharedEdge() const;
65
66 private:
67     std::weak_ptr<MKLDNNNode> parent;
68     std::weak_ptr<MKLDNNNode> child;
69     int parent_port;
70     int child_port;
71
72     MKLDNNEdgeWeakPtr memoryFromEdge;
73     MKLDNNDims dims;
74     MKLDNNMemoryPtr memoryPtr;
75     Status status = Status::Uninitialized;
76
77     InferenceEngine::TensorDesc getInputDesc();
78     InferenceEngine::TensorDesc getOutputDesc();
79     InferenceEngine::TensorDesc getSpecifiedInputDesc(std::map<mkldnn::memory::format, size_t> formats);
80     InferenceEngine::TensorDesc getSpecifiedOutputDesc(std::map<mkldnn::memory::format, size_t> formats);
81
82     InferenceEngine::TensorDesc inputDesc;
83     InferenceEngine::TensorDesc outputDesc;
84
85     bool nodeCanChangeDesc(const std::shared_ptr<MKLDNNPlugin::MKLDNNNode>& node) const;
86
87     enum LOOK { LOOK_UP = 1, LOOK_DOWN = 2, LOOK_BOTH = LOOK_UP | LOOK_DOWN, LOOK_NO_RECURRENT = 4 };
88
89     MKLDNNEdgePtr getBaseEdge(int look = LOOK_BOTH);
90     bool inPlace(LOOK look = LOOK_BOTH);
91     friend class MKLDNNGraph;
92 };
93
94 }  // namespace MKLDNNPlugin