1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
8 #include <mkldnn_node.h>
13 namespace MKLDNNPlugin {
// RNN node for the MKLDNN plugin: wraps an mkldnn RNN primitive
// (vanilla RNN / GRU / LSTM, selected via cell_desc) behind the common
// MKLDNNNode interface.
// NOTE(review): this chunk appears incomplete — the access specifiers
// (public:/private:), at least one member (the Cell/Seq flag documented
// below), and the closing "};" of the class are not visible here, and each
// line carries a stray leading number from extraction; confirm against the
// full file before relying on structure.
15 class MKLDNNRNN : public MKLDNNNode {
// Builds the node from the InferenceEngine layer description on the given engine.
17 MKLDNNRNN(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng);
18 ~MKLDNNRNN() override = default;
// MKLDNNNode interface: descriptor negotiation and primitive lifecycle.
20 void getSupportedDescriptors() override;
21 void createPrimitive() override;
22 bool created() const override;
// Creates the mkldnn RNN descriptor from the chosen input/output tensor descs.
24 void createDescriptor(const std::vector<InferenceEngine::TensorDesc>& inputDesc,
25 const std::vector<InferenceEngine::TensorDesc>& outputDesc) override;
// Runs the node on the given stream; presumably submits the reorders in
// exec_before / exec_after around the RNN primitive — confirm in the .cpp.
27 void execute(mkldnn::stream strm) override;
// Static self-registration of this node type (Register<> is declared elsewhere).
34 static Register<MKLDNNRNN> reg;
36 /** Specify mode Cell or Seq. true - Cell, false - Seq */
// NOTE(review): the boolean member this comment documents is not visible in
// this chunk (it sits in the gap above line 39 of the original file).
39 /** Native order is [batch, seq, data]; the other case is [seq, batch, data] */
40 bool nativeOrder = true;
42 /** Direction of iteration through the sequence dimension */
43 mkldnn::rnn_direction direction = mkldnn::unidirectional;
45 /** RNN cell descriptor (cell type / activation algorithm / clip); defaults to vanilla LSTM */
46 mkldnn::rnn_cell::desc cell_desc { mkldnn::algorithm::vanilla_lstm };
48 // Internal attributes: dimension shorthands used to build the mkldnn descriptors.
49 ptrdiff_t N = 0; /**< Batch size */
50 ptrdiff_t T = 0; /**< Sequence length */
51 ptrdiff_t DC = 0; /**< Input data channel size */
52 ptrdiff_t SC = 0; /**< State channel size */
53 ptrdiff_t G = 0; /**< Number of gates. LSTM - 4, GRU - 3, RNN - 1 */
54 ptrdiff_t Gb = 0; /**< Number of gates for biases. Gb = GRU_lbr ? G+1 : G */
55 ptrdiff_t S = 2; /**< Number of state tensors. LSTM - 2, GRU & RNN - 1 */
56 const ptrdiff_t L = 1; /**< Layer-count dimension of mkldnn's RNN weights layout; presumably the number of stacked layers — always 1 in this impl (TODO confirm against mkldnn docs) */
57 const ptrdiff_t D = 1; /**< Number of directions (mkldnn allows 1 or 2); fixed to 1 here, matching the unidirectional default above */
// Memory descriptors for the primitive's data arguments.
59 MKLDNNMemoryDesc in_data_d;   /**< input data */
60 MKLDNNMemoryDesc out_data_d;  /**< output data */
62 MKLDNNMemoryDesc in_state_d;  /**< input state */
63 MKLDNNMemoryDesc out_state_d; /**< output state */
// Memory descriptors for the weight/bias arguments.
65 MKLDNNMemoryDesc w_data_d;  /**< input-to-hidden (data) weights */
66 MKLDNNMemoryDesc w_state_d; /**< recurrent (state) weights */
67 MKLDNNMemoryDesc w_bias_d;  /**< bias */
69 // In/out reorders, populated only when layout conversion is required.
70 std::vector<mkldnn::reorder> exec_before; /**< reorders to run before the RNN primitive */
71 std::vector<mkldnn::reorder> exec_after;  /**< reorders to run after the RNN primitive */
74 } // namespace MKLDNNPlugin