1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
9 #include "mkldnn_memory.h"
10 #include "mkldnn_dims.h"
14 namespace MKLDNNPlugin {
19 using MKLDNNEdgePtr = std::shared_ptr<MKLDNNEdge>;
20 using MKLDNNEdgeWeakPtr = std::weak_ptr<MKLDNNEdge>;
// An edge in the MKLDNN execution graph: connects one output port of a parent
// node to one input port of a child node and owns/tracks the memory that flows
// between them. Non-copyable (inherits InferenceEngine::details::no_copy).
// NOTE(review): several original lines (access specifiers, the Status enum
// definition, the closing brace) are elided from this view of the file.
22 class MKLDNNEdge : public InferenceEngine::details::no_copy {
// Builds an edge from `parent` to `child`.
// pr_port/ch_port: presumably the parent's output-port index and the child's
// input-port index, respectively (default 0) — TODO confirm against the .cpp.
24     MKLDNNEdge(const std::shared_ptr<MKLDNNNode>& parent,
25                const std::shared_ptr<MKLDNNNode>& child,
26                int pr_port = 0, int ch_port = 0);
// Current lifecycle state of the edge (starts as Status::Uninitialized, see
// the `status` member below). noexcept: safe to call from any context.
36     inline Status getStatus() const noexcept {
// Transitions the edge to `state`; presumably validates the transition — confirm in .cpp.
40     void changeStatus(Status state);
// Allocates the edge's memory. When `mem_ptr` is non-null, presumably wraps
// the caller-provided buffer instead of allocating — TODO confirm.
43     virtual void allocate(const void* mem_ptr = nullptr);
// Checks the edge is in a consistent, usable state; virtual so subclasses can extend.
44     virtual void validate();
// Accessors for the endpoint nodes; presumably lock the weak `parent`/`child`
// members below and may throw/fail if the node is gone — confirm in .cpp.
47     const std::shared_ptr<MKLDNNNode> getParent() const;
48     const std::shared_ptr<MKLDNNNode> getChild() const;
// Views of the data carried by this edge, as an IE blob / tensor descriptor.
50     InferenceEngine::Blob::Ptr getBlob();
51     InferenceEngine::TensorDesc getDesc();
// Shape and underlying MKLDNN memory of the edge.
53     const MKLDNNDims &getDims();
54     const MKLDNNMemory& getMemory();
55     MKLDNNMemoryPtr& getMemoryPtr();
// In-place memory sharing: record that this edge reuses `edge`'s memory
// (stored weakly in `memoryFromEdge`), and query the shared source edge.
63     void sharedMemFrom(const MKLDNNEdgePtr& edge);
64     MKLDNNEdgePtr getSharedEdge() const;
// Endpoints are held weakly — presumably to break node<->edge ownership cycles.
67     std::weak_ptr<MKLDNNNode> parent;
68     std::weak_ptr<MKLDNNNode> child;
// Edge whose memory this edge shares, if any (weak: sharing must not extend lifetime).
72     MKLDNNEdgeWeakPtr memoryFromEdge;
// Memory owned/referenced by this edge; state starts uninitialized.
74     MKLDNNMemoryPtr memoryPtr;
75     Status status = Status::Uninitialized;
// Descriptor resolution helpers: the "Specified" variants pick a concrete
// descriptor given a map of candidate mkldnn formats with scores/counts —
// exact semantics of the size_t values not visible here; TODO confirm.
77     InferenceEngine::TensorDesc getInputDesc();
78     InferenceEngine::TensorDesc getOutputDesc();
79     InferenceEngine::TensorDesc getSpecifiedInputDesc(std::map<mkldnn::memory::format, size_t> formats);
80     InferenceEngine::TensorDesc getSpecifiedOutputDesc(std::map<mkldnn::memory::format, size_t> formats);
// Cached input/output descriptors for this edge.
82     InferenceEngine::TensorDesc inputDesc;
83     InferenceEngine::TensorDesc outputDesc;
// True if `node` may change the tensor descriptor flowing through it — confirm criteria in .cpp.
85     bool nodeCanChangeDesc(const std::shared_ptr<MKLDNNPlugin::MKLDNNNode>& node) const;
// Bit flags selecting the search direction when walking in-place chains:
// LOOK_UP toward the parent, LOOK_DOWN toward the child, LOOK_BOTH = both,
// LOOK_NO_RECURRENT presumably stops the walk at recurrent links — TODO confirm.
87     enum LOOK { LOOK_UP = 1, LOOK_DOWN = 2, LOOK_BOTH = LOOK_UP | LOOK_DOWN, LOOK_NO_RECURRENT = 4 };
// Finds the root edge of an in-place sharing chain / reports whether this
// edge participates in in-place execution, in the direction(s) given by `look`.
89     MKLDNNEdgePtr getBaseEdge(int look = LOOK_BOTH);
90     bool inPlace(LOOK look = LOOK_BOTH);
// The graph drives allocation/validation and needs access to internals.
91     friend class MKLDNNGraph;
94 } // namespace MKLDNNPlugin