Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_rnn_forward.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <utility>
18 #include <numeric>
19
20 #include "gtest/gtest.h"
21 #include "mkldnn_test_common.hpp"
22
23 #include "mkldnn.hpp"
24
25 namespace mkldnn {
26
/* Problem sizes for one RNN test case.
 * In the fixture below they fill these tensor shapes:
 *   src_layer {t, mb, slc}, dst_layer {t, mb, dlc},
 *   src_iter {l, d, s, mb, sic}, dst_iter {l, d, s, mb, dic},
 *   weights_layer {l, d, slc, g, dic}, weights_iter {l, d, sic, g, dic}. */
struct test_rnn_sizes_t {
    int l;   // leading dim of weights/bias/iter tensors (presumably layers)
    int d;   // second dim of weights/bias/iter tensors (presumably directions)
    int t;   // leading dim of src_layer/dst_layer (presumably time steps)
    int mb;  // minibatch dim shared by layer and iter tensors
    int slc; // src_layer channels
    int sic; // src_iter channels
    int dlc; // dst_layer channels
    int dic; // dst_iter channels

    test_rnn_sizes_t(int l_, int d_, int t_, int mb_,
            int slc_, int sic_, int dlc_, int dic_)
        : l{l_}, d{d_}, t{t_}, mb{mb_},
          slc{slc_}, sic{sic_}, dlc{dlc_}, dic{dic_} {}
};
38
39 struct test_rnn_formats_t {
40     mkldnn::memory::format src_layer_fmt;
41     mkldnn::memory::format src_iter_fmt;
42     mkldnn::memory::format weights_layer_fmt;
43     mkldnn::memory::format weights_iter_fmt;
44     mkldnn::memory::format bias_fmt;
45     mkldnn::memory::format dst_layer_fmt;
46     mkldnn::memory::format dst_iter_fmt;
47 };
48
49 struct test_rnn_params_t {
50     const mkldnn::engine::kind engine_kind;
51     mkldnn::algorithm aalgorithm;
52     mkldnn::algorithm activation;
53     mkldnn::rnn_direction direction;
54     test_rnn_formats_t fmts;
55     test_rnn_sizes_t sizes;
56     bool expect_to_fail;
57     mkldnn_status_t expected_status;
58 };
59
60 // We assume uniform data type accross tensors for now
61 template <typename data_t>
62 class rnn_forward_test
63     : public ::testing::TestWithParam<test_rnn_params_t> {
64 protected:
65     virtual void SetUp() {
66         auto p = ::testing::TestWithParam<test_rnn_params_t>::GetParam();
67         catch_expected_failures([=](){Test();}, p.expect_to_fail,
68                 p.expected_status, false);
69     }
70
71     void Test() {
72         auto p = ::testing::TestWithParam<test_rnn_params_t>::GetParam();
73         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
74         auto eng = engine(p.engine_kind, 0);
75         //@todo check algorithm is one of the supported by RNN
76         //ASSERT_EQ(p.aalgorithm, algorithm::vanilla_lstm);
77
78         // Initialize the data
79         memory::data_type prec = data_traits<data_t>::data_type;
80         auto dims = p.sizes;
81         auto t = dims.t, mb = dims.mb, l = dims.l, d = dims.d;
82         auto slc = dims.slc, sic = dims.sic, dlc = dims.dlc, dic = dims.dic;
83         int s, g;
84
85         switch (p.aalgorithm) {
86         case vanilla_lstm:
87             g = 4; s = 2; break;
88         case vanilla_gru:
89         case gru_linear_before_reset:
90             g = 3; s = 1; break;
91         default:
92             g = 1; s = 1; break;
93         };
94
95         mkldnn::memory::dims weights_layer_dims = {l, d, slc, g, dic};
96         mkldnn::memory::dims weights_iter_dims = {l, d, sic, g, dic};
97         mkldnn::memory::dims bias_dims = {l, d, g, dic};
98         mkldnn::memory::dims src_layer_dims = {t, mb, slc};
99         mkldnn::memory::dims src_iter_dims = {l, d, s, mb, sic};
100         mkldnn::memory::dims dst_layer_dims = {t, mb, dlc};
101         mkldnn::memory::dims dst_iter_dims = {l, d, s, mb, dic};
102
103         auto weights_layer_md_any = memory::desc({weights_layer_dims}, prec, memory::format::any);
104         auto weights_iter_md_any = memory::desc({weights_iter_dims}, prec, memory::format::any);
105         auto bias_md_any = memory::desc({bias_dims}, prec, memory::format::any);
106         auto src_layer_md_any = memory::desc({src_layer_dims}, prec, memory::format::any);
107         auto src_iter_md_any = memory::desc({src_iter_dims}, prec, memory::format::any);
108         auto dst_layer_md_any = memory::desc({dst_layer_dims}, prec, memory::format::any);
109         auto dst_iter_md_any = memory::desc({dst_iter_dims}, prec, memory::format::any);
110
111         auto weights_layer_md_tgt = memory::desc({weights_layer_dims}, prec, p.fmts.weights_layer_fmt);
112         auto weights_iter_md_tgt = memory::desc({weights_iter_dims}, prec, p.fmts.weights_iter_fmt);
113         auto bias_md_tgt = memory::desc({bias_dims}, prec, p.fmts.bias_fmt);
114         auto src_layer_md_tgt = memory::desc({src_layer_dims}, prec, p.fmts.src_layer_fmt);
115         auto src_iter_md_tgt = memory::desc({src_iter_dims}, prec, p.fmts.src_iter_fmt);
116         auto dst_layer_md_tgt = memory::desc({dst_layer_dims}, prec, p.fmts.dst_layer_fmt);
117         auto dst_iter_md_tgt = memory::desc({dst_iter_dims}, prec, p.fmts.dst_iter_fmt);
118
119
120         // Create the reference descriptor
121         rnn_cell::desc cell(p.aalgorithm, p.activation);
122         auto direction = p.direction;
123
124         rnn_forward::desc ref_desc(prop_kind::forward_inference, cell,
125                 direction, src_layer_md_any, src_iter_md_any,
126                 weights_layer_md_any, weights_iter_md_any, bias_md_any,
127                 dst_layer_md_any, dst_iter_md_any);
128         auto ref_prim_desc = rnn_forward::primitive_desc(ref_desc, eng);
129
130         // Query the descriptor for memory descriptors
131         auto weights_layer_md_ref = ref_prim_desc.weights_layer_primitive_desc().desc();
132         auto weights_iter_md_ref = ref_prim_desc.weights_iter_primitive_desc().desc();
133         auto bias_md_ref = ref_prim_desc.bias_primitive_desc().desc();
134         auto src_layer_md_ref = ref_prim_desc.src_layer_primitive_desc().desc();
135         auto src_iter_md_ref = ref_prim_desc.src_iter_primitive_desc().desc();
136         auto dst_layer_md_ref = ref_prim_desc.dst_layer_primitive_desc().desc();
137         auto dst_iter_md_ref = ref_prim_desc.dst_iter_primitive_desc().desc();
138
139         auto are_equal_md = [](memory::desc a, memory::desc b, engine eng){
140             return memory::primitive_desc(a, eng)
141                     == memory::primitive_desc(b, eng);
142         };
143
144         bool skip_test =
145             are_equal_md(weights_layer_md_ref, weights_layer_md_tgt, eng)
146             && are_equal_md(weights_iter_md_ref, weights_iter_md_tgt, eng)
147             && are_equal_md(bias_md_ref, bias_md_tgt, eng)
148             && are_equal_md(src_layer_md_ref, src_layer_md_tgt, eng)
149             && are_equal_md(src_iter_md_ref, src_iter_md_tgt, eng)
150             && are_equal_md(dst_layer_md_ref, dst_layer_md_tgt, eng)
151             && are_equal_md(dst_iter_md_ref, dst_iter_md_tgt, eng);
152
153         if (skip_test) return;
154
155         /* initialize data */
156         auto weights_layer_ref = memory({weights_layer_md_ref, eng});
157         auto weights_iter_ref = memory({weights_iter_md_ref, eng});
158         auto bias_ref = memory({bias_md_ref, eng});
159         auto src_layer_ref = memory({src_layer_md_ref, eng});
160         auto src_iter_ref = memory({src_iter_md_ref, eng});
161         auto dst_layer_ref = memory({dst_layer_md_ref, eng});
162         auto dst_iter_ref = memory({dst_iter_md_ref, eng});
163
164         auto weights_layer_tgt = memory({weights_layer_md_tgt, eng});
165         auto weights_iter_tgt = memory({weights_iter_md_tgt, eng});
166         auto bias_tgt = memory({bias_md_tgt, eng});
167         auto src_layer_tgt = memory({src_layer_md_tgt, eng});
168         auto src_iter_tgt = memory({src_iter_md_tgt, eng});
169         auto dst_layer_tgt = memory({dst_layer_md_tgt, eng});
170         auto dst_iter_tgt = memory({dst_iter_md_tgt, eng});
171
172         auto init_tensor = [&](memory a, memory b) {
173             auto a_ptr = static_cast<float *>(a.get_data_handle());
174             auto desc = a.get_primitive_desc().desc();
175             auto a_dims = desc.data.dims;
176             auto a_ndims = desc.data.ndims;
177             auto n_elems = std::accumulate(a_dims, a_dims + a_ndims, size_t(1),
178                     std::multiplies<float>());
179             for(size_t i = 0; i < n_elems; i++)
180                 a_ptr[map_index(desc, i, false)] = i;
181             stream(stream::kind::eager).submit({reorder(a, b)}).wait();
182         };
183
184         init_tensor(weights_layer_ref, weights_layer_tgt);
185         init_tensor(weights_iter_ref, weights_iter_tgt);
186         init_tensor(bias_ref, bias_tgt);
187         init_tensor(src_layer_ref, src_layer_tgt);
188         init_tensor(src_iter_ref, src_iter_tgt);
189
190         // run the non packed version
191         auto prim_ref = rnn_forward(ref_prim_desc, src_layer_ref, src_iter_ref,
192                 weights_layer_ref, weights_iter_ref, bias_ref,
193                 dst_layer_ref, dst_iter_ref, null_memory(eng));
194         stream(stream::kind::eager).submit({prim_ref}).wait();
195
196         // run the packed version
197         rnn_forward::desc tgt_desc(prop_kind::forward_inference, cell,
198                 direction, src_layer_md_tgt, src_iter_md_tgt,
199                 weights_layer_md_tgt, weights_iter_md_tgt, bias_md_tgt,
200                 dst_layer_md_tgt, dst_iter_md_tgt);
201         auto tgt_prim_desc = rnn_forward::primitive_desc(tgt_desc, eng);
202         auto prim_tgt = rnn_forward(tgt_prim_desc, src_layer_tgt, src_iter_tgt,
203                 weights_layer_tgt, weights_iter_tgt, bias_tgt,
204                 dst_layer_tgt, dst_iter_tgt, null_memory(eng));
205         stream(stream::kind::eager).submit({prim_tgt}).wait();
206
207         // compare dst_layer and dst_iter
208         compare_data<data_t>(dst_layer_ref, dst_layer_tgt, 1e-5);
209         compare_data<data_t>(dst_iter_ref, dst_iter_tgt, 1e-5);
210     }
211 };
212
213     using eng = engine::kind;
214     using fmt = memory::format;
215     using alg = algorithm;
216     using dir = rnn_direction;
217     using rnn_forward_test_f32 = rnn_forward_test<float>;
218     using cfg_f32 = test_rnn_params_t;
219
220 TEST_P(rnn_forward_test_f32, TestsRnn) { }
221 INSTANTIATE_TEST_CASE_P(TestRnn, rnn_forward_test_f32,
222         ::testing::Values(
223             cfg_f32{eng::cpu, alg::vanilla_rnn, alg::eltwise_tanh, dir::unidirectional_left2right,
224                 {fmt::tnc, fmt::ldsnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, fmt::tnc, fmt::ldsnc},
225                     test_rnn_sizes_t(1, 1, 10, 16, 100, 100, 100, 100)},
226             cfg_f32{eng::cpu, alg::vanilla_lstm, alg::eltwise_tanh, dir::unidirectional_left2right,
227                 {fmt::tnc, fmt::ldsnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, fmt::tnc, fmt::ldsnc},
228                     test_rnn_sizes_t(1, 1, 10, 16, 100, 100, 100, 100)},
229             /* Check for invalid parameters: unsupported unrolling */
230             cfg_f32{eng::cpu, alg::vanilla_rnn, alg::eltwise_tanh, dir::unidirectional_left2right,
231                 {fmt::tnc, fmt::ldsnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, fmt::tnc, fmt::ldsnc},
232                     test_rnn_sizes_t(2, 1, 10, 16, 200, 100, 100, 100), true, mkldnn_invalid_arguments},
233             cfg_f32{eng::cpu, alg::vanilla_rnn, alg::eltwise_tanh, dir::unidirectional_left2right,
234                 {fmt::tnc, fmt::ldsnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, fmt::tnc, fmt::ldsnc},
235                     test_rnn_sizes_t(2, 1, 10, 16, 100, 200, 100, 100), true, mkldnn_invalid_arguments},
236             /* Check for invalid parameters: inconsistent dimensions */
237             cfg_f32{eng::cpu, alg::vanilla_rnn, alg::eltwise_tanh, dir::unidirectional_left2right,
238                 {fmt::tnc, fmt::ldsnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, fmt::tnc, fmt::ldsnc},
239                     test_rnn_sizes_t(2, 1, 10, 16, 100, 100, 50, 100), true, mkldnn_invalid_arguments}
240             )
241     );
242
243 }