Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_activation_node.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_common.h>
8 #include <mkldnn_node.h>
9 #include "details/caseless.hpp"
10 #include <string>
11 #include <memory>
12 #include <vector>
13
14 namespace MKLDNNPlugin {
15
16 class MKLDNNActivationNode : public MKLDNNNode {
17 public:
18     MKLDNNActivationNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng);
19     ~MKLDNNActivationNode() override = default;
20
21     void getSupportedDescriptors() override;
22     void initOptimalPrimitiveDescriptor() override;
23     void createDescriptor(const std::vector<InferenceEngine::TensorDesc>& inputDesc,
24                           const std::vector<InferenceEngine::TensorDesc>& outputDesc) override;
25     void createPrimitive() override;
26     bool created() const override;
27
28     mkldnn::algorithm getAlgorithm() {
29         if (!initialized)
30             initValues();
31         return algorithm;
32     }
33
34     float getAlpha() {
35         if (!initialized)
36             initValues();
37         return alpha;
38     }
39
40     float getBeta() {
41         if (!initialized)
42             initValues();
43         return beta;
44     }
45
46     MKLDNNMemoryDesc getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) override;
47     MKLDNNMemoryDesc getDstMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) override;
48
49 private:
50     void initValues();
51     static Register<MKLDNNActivationNode> reg;
52     bool initialized = false;
53     float alpha = 0.0f;
54     float beta = 0.0f;
55     static InferenceEngine::details::caseless_map<std::string,
56             std::function<void(InferenceEngine::GenericLayer*, mkldnn::algorithm&, float&, float&)>> initializers;
57     mkldnn::algorithm algorithm;
58 };
59
60 }  // namespace MKLDNNPlugin
61