Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_rnn.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_common.h>
8 #include <mkldnn_node.h>
9 #include <string>
10 #include <memory>
11 #include <vector>
12
13 namespace MKLDNNPlugin {
14
15 class MKLDNNRNN : public MKLDNNNode {
16 public:
17     MKLDNNRNN(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng);
18     ~MKLDNNRNN() override = default;
19
20     void getSupportedDescriptors() override;
21     void createPrimitive() override;
22     bool created() const override;
23
24     void createDescriptor(const std::vector<InferenceEngine::TensorDesc>& inputDesc,
25                           const std::vector<InferenceEngine::TensorDesc>& outputDesc) override;
26
27     void execute(mkldnn::stream strm) override;
28
29 private:
30     void fillCellDesc();
31     void fillSeqDesc();
32
33 private:
34     static Register<MKLDNNRNN> reg;
35
36     /** Specify mode Cell or Seq. true - Cell, false - Seq */
37     bool is_cell = false;
38
39     /** Native order if [batch, seq, data], other case is [seq, batch, data] */
40     bool nativeOrder = true;
41
42     /** Direction of iteration through sequence dimension */
43     mkldnn::rnn_direction direction = mkldnn::unidirectional;
44
45     /** RNN Cell desc (type/activation_alg/clip)*/
46     mkldnn::rnn_cell::desc cell_desc { mkldnn::algorithm::vanilla_lstm };
47
48     // Internal attributes
49     ptrdiff_t N = 0;   /**< Batch value */
50     ptrdiff_t T = 0;   /**< Sequence value */
51     ptrdiff_t DC = 0;  /**< Input data channel size */
52     ptrdiff_t SC = 0;  /**< State channel size value */
53     ptrdiff_t G = 0;   /**< Gate size. LSTM - 4, GRU - 3, RNN - 1 */
54     ptrdiff_t Gb = 0;  /**< Gate size for biases. Gb = GRU_lbr ? G+1 : G */
55     ptrdiff_t S = 2;   /**< Num of state. LSTM - 2, GRU & RNN - 1 */
56     const ptrdiff_t L = 1;   /**< What is it??. Constant for mkldnn impl */
57     const ptrdiff_t D = 1;   /**< Num of direction. 1 or 2 */
58
59     MKLDNNMemoryDesc in_data_d;
60     MKLDNNMemoryDesc out_data_d;
61
62     MKLDNNMemoryDesc in_state_d;
63     MKLDNNMemoryDesc out_state_d;
64
65     MKLDNNMemoryDesc w_data_d;
66     MKLDNNMemoryDesc w_state_d;
67     MKLDNNMemoryDesc w_bias_d;
68
69     // List of in/out reorders if required
70     std::vector<mkldnn::reorder> exec_before;
71     std::vector<mkldnn::reorder> exec_after;
72 };
73
74 }  // namespace MKLDNNPlugin
75