1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
20 #include "gtest/gtest.h"
21 #include "mkldnn_test_common.hpp"
// Problem-size bundle for one RNN test case:
//   l = layers, d = directions, t = timesteps, mb = minibatch,
//   slc/sic = source layer/iteration channels,
//   dlc/dic = destination layer/iteration channels.
// NOTE(review): this listing is elided — the constructor's opening line and
// the declarations of l, d, t, mb (and the closing brace) are not visible.
27 struct test_rnn_sizes_t {
29     int l, int d, int t, int mb,
30     int slc, int sic, int dlc, int dic) :
31     l(l), d(d), t(t), mb(mb),
32     slc(slc), sic(sic), dlc(dlc), dic(dic) {}
36     int slc, sic, dlc, dic;
// Memory formats requested by a test case, one per RNN tensor; the "target"
// memory descriptors in the test are built from these, and compared against
// the formats the implementation picks for format::any.
// NOTE(review): the closing brace of this struct is elided from this listing.
39 struct test_rnn_formats_t {
40     mkldnn::memory::format src_layer_fmt;
41     mkldnn::memory::format src_iter_fmt;
42     mkldnn::memory::format weights_layer_fmt;
43     mkldnn::memory::format weights_iter_fmt;
44     mkldnn::memory::format bias_fmt;
45     mkldnn::memory::format dst_layer_fmt;
46     mkldnn::memory::format dst_iter_fmt;
// Full parameter set for one value-parameterized test instance: engine,
// cell kind + activation, direction, requested formats, sizes, and the
// status expected when the case is a deliberate negative test.
// NOTE(review): an `expect_to_fail` member is referenced by the fixture
// (p.expect_to_fail) but its declaration is elided from this listing.
49 struct test_rnn_params_t {
50     const mkldnn::engine::kind engine_kind;
51     mkldnn::algorithm aalgorithm;
52     mkldnn::algorithm activation;
53     mkldnn::rnn_direction direction;
54     test_rnn_formats_t fmts;
55     test_rnn_sizes_t sizes;
57     mkldnn_status_t expected_status;
60 // We assume uniform data type across tensors for now
//
// Value-parameterized fixture: runs the same RNN forward problem twice —
// once letting the library choose memory formats (format::any, "ref") and
// once with the formats requested by the test case ("tgt") — then checks
// both runs produce the same dst_layer / dst_iter contents.
// NOTE(review): this listing is elided in places (the Test() method header,
// parts of the gate-count switch, declarations of g / s / skip_test, and
// closing braces are missing); comments describe only the visible code.
61 template <typename data_t>
62 class rnn_forward_test
63 : public ::testing::TestWithParam<test_rnn_params_t> {
// gtest entry point: wraps Test() so cases flagged expect_to_fail are
// checked to fail with p.expected_status instead of aborting the run.
65 virtual void SetUp() {
66 auto p = ::testing::TestWithParam<test_rnn_params_t>::GetParam();
67 catch_expected_failures([=](){Test();}, p.expect_to_fail,
68 p.expected_status, false);
// (Test() body follows — its declaration line is elided from this listing.)
72 auto p = ::testing::TestWithParam<test_rnn_params_t>::GetParam();
// Only the CPU engine is exercised by this test.
73 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
74 auto eng = engine(p.engine_kind, 0);
75 //@todo check algorithm is one of the supported by RNN
76 //ASSERT_EQ(p.aalgorithm, algorithm::vanilla_lstm);
78 // Initialize the data
79 memory::data_type prec = data_traits<data_t>::data_type;
// Unpack the problem sizes into short locals used for the dims below.
81 auto t = dims.t, mb = dims.mb, l = dims.l, d = dims.d;
82 auto slc = dims.slc, sic = dims.sic, dlc = dims.dlc, dic = dims.dic;
// Gate count `g` is selected per cell kind; most cases of this switch are
// elided from this listing — presumably g = 1/3/4 for rnn/gru/lstm
// (TODO confirm against the full file).
85 switch (p.aalgorithm) {
89 case gru_linear_before_reset:
// Logical tensor dims; `g` = gates, `s` = states per cell (declaration
// of `s` is elided — presumably 2 for LSTM c+h, 1 otherwise; confirm).
95 mkldnn::memory::dims weights_layer_dims = {l, d, slc, g, dic};
96 mkldnn::memory::dims weights_iter_dims = {l, d, sic, g, dic};
97 mkldnn::memory::dims bias_dims = {l, d, g, dic};
98 mkldnn::memory::dims src_layer_dims = {t, mb, slc};
99 mkldnn::memory::dims src_iter_dims = {l, d, s, mb, sic};
100 mkldnn::memory::dims dst_layer_dims = {t, mb, dlc};
101 mkldnn::memory::dims dst_iter_dims = {l, d, s, mb, dic};
// "any" descriptors: let the implementation choose its preferred layouts.
103 auto weights_layer_md_any = memory::desc({weights_layer_dims}, prec, memory::format::any);
104 auto weights_iter_md_any = memory::desc({weights_iter_dims}, prec, memory::format::any);
105 auto bias_md_any = memory::desc({bias_dims}, prec, memory::format::any);
106 auto src_layer_md_any = memory::desc({src_layer_dims}, prec, memory::format::any);
107 auto src_iter_md_any = memory::desc({src_iter_dims}, prec, memory::format::any);
108 auto dst_layer_md_any = memory::desc({dst_layer_dims}, prec, memory::format::any);
109 auto dst_iter_md_any = memory::desc({dst_iter_dims}, prec, memory::format::any);
// "tgt" descriptors: the concrete formats requested by the test case.
111 auto weights_layer_md_tgt = memory::desc({weights_layer_dims}, prec, p.fmts.weights_layer_fmt);
112 auto weights_iter_md_tgt = memory::desc({weights_iter_dims}, prec, p.fmts.weights_iter_fmt);
113 auto bias_md_tgt = memory::desc({bias_dims}, prec, p.fmts.bias_fmt);
114 auto src_layer_md_tgt = memory::desc({src_layer_dims}, prec, p.fmts.src_layer_fmt);
115 auto src_iter_md_tgt = memory::desc({src_iter_dims}, prec, p.fmts.src_iter_fmt);
116 auto dst_layer_md_tgt = memory::desc({dst_layer_dims}, prec, p.fmts.dst_layer_fmt);
117 auto dst_iter_md_tgt = memory::desc({dst_iter_dims}, prec, p.fmts.dst_iter_fmt);
120 // Create the reference descriptor
121 rnn_cell::desc cell(p.aalgorithm, p.activation);
122 auto direction = p.direction;
124 rnn_forward::desc ref_desc(prop_kind::forward_inference, cell,
125 direction, src_layer_md_any, src_iter_md_any,
126 weights_layer_md_any, weights_iter_md_any, bias_md_any,
127 dst_layer_md_any, dst_iter_md_any);
// Negative test cases are expected to throw here (invalid dims/params),
// which catch_expected_failures converts into a status check.
128 auto ref_prim_desc = rnn_forward::primitive_desc(ref_desc, eng);
130 // Query the descriptor for memory descriptors
131 auto weights_layer_md_ref = ref_prim_desc.weights_layer_primitive_desc().desc();
132 auto weights_iter_md_ref = ref_prim_desc.weights_iter_primitive_desc().desc();
133 auto bias_md_ref = ref_prim_desc.bias_primitive_desc().desc();
134 auto src_layer_md_ref = ref_prim_desc.src_layer_primitive_desc().desc();
135 auto src_iter_md_ref = ref_prim_desc.src_iter_primitive_desc().desc();
136 auto dst_layer_md_ref = ref_prim_desc.dst_layer_primitive_desc().desc();
137 auto dst_iter_md_ref = ref_prim_desc.dst_iter_primitive_desc().desc();
// Two descs are "equal" iff their primitive descriptors compare equal on
// this engine (i.e. same concrete layout, not just same logical dims).
139 auto are_equal_md = [](memory::desc a, memory::desc b, engine eng){
140 return memory::primitive_desc(a, eng)
141 == memory::primitive_desc(b, eng);
// skip_test (declaration elided): when every implementation-chosen layout
// already equals the requested one, both runs would be byte-identical and
// the comparison proves nothing — so the test bails out below.
145 are_equal_md(weights_layer_md_ref, weights_layer_md_tgt, eng)
146 && are_equal_md(weights_iter_md_ref, weights_iter_md_tgt, eng)
147 && are_equal_md(bias_md_ref, bias_md_tgt, eng)
148 && are_equal_md(src_layer_md_ref, src_layer_md_tgt, eng)
149 && are_equal_md(src_iter_md_ref, src_iter_md_tgt, eng)
150 && are_equal_md(dst_layer_md_ref, dst_layer_md_tgt, eng)
151 && are_equal_md(dst_iter_md_ref, dst_iter_md_tgt, eng);
153 if (skip_test) return;
155 /* initialize data */
156 auto weights_layer_ref = memory({weights_layer_md_ref, eng});
157 auto weights_iter_ref = memory({weights_iter_md_ref, eng});
158 auto bias_ref = memory({bias_md_ref, eng});
159 auto src_layer_ref = memory({src_layer_md_ref, eng});
160 auto src_iter_ref = memory({src_iter_md_ref, eng});
161 auto dst_layer_ref = memory({dst_layer_md_ref, eng});
162 auto dst_iter_ref = memory({dst_iter_md_ref, eng});
164 auto weights_layer_tgt = memory({weights_layer_md_tgt, eng});
165 auto weights_iter_tgt = memory({weights_iter_md_tgt, eng});
166 auto bias_tgt = memory({bias_md_tgt, eng});
167 auto src_layer_tgt = memory({src_layer_md_tgt, eng});
168 auto src_iter_tgt = memory({src_iter_md_tgt, eng});
169 auto dst_layer_tgt = memory({dst_layer_md_tgt, eng});
170 auto dst_iter_tgt = memory({dst_iter_md_tgt, eng});
// Fill `a` with a logical-index ramp (map_index converts logical index to
// the physical offset of a's layout), then reorder a's contents into b so
// both runs start from identical logical data in different layouts.
172 auto init_tensor = [&](memory a, memory b) {
173 auto a_ptr = static_cast<float *>(a.get_data_handle());
174 auto desc = a.get_primitive_desc().desc();
175 auto a_dims = desc.data.dims;
176 auto a_ndims = desc.data.ndims;
// NOTE(review): std::multiplies<float>() on a size_t accumulator forces
// the element-count product through float — exact only while the product
// fits in float's 24-bit mantissa; should be std::multiplies<size_t>().
177 auto n_elems = std::accumulate(a_dims, a_dims + a_ndims, size_t(1),
178 std::multiplies<float>());
179 for(size_t i = 0; i < n_elems; i++)
180 a_ptr[map_index(desc, i, false)] = i;
181 stream(stream::kind::eager).submit({reorder(a, b)}).wait();
// Initialize all inputs; dst tensors are left as outputs.
184 init_tensor(weights_layer_ref, weights_layer_tgt);
185 init_tensor(weights_iter_ref, weights_iter_tgt);
186 init_tensor(bias_ref, bias_tgt);
187 init_tensor(src_layer_ref, src_layer_tgt);
188 init_tensor(src_iter_ref, src_iter_tgt);
190 // run the non packed version
191 auto prim_ref = rnn_forward(ref_prim_desc, src_layer_ref, src_iter_ref,
192 weights_layer_ref, weights_iter_ref, bias_ref,
193 dst_layer_ref, dst_iter_ref, null_memory(eng));
194 stream(stream::kind::eager).submit({prim_ref}).wait();
196 // run the packed version
197 rnn_forward::desc tgt_desc(prop_kind::forward_inference, cell,
198 direction, src_layer_md_tgt, src_iter_md_tgt,
199 weights_layer_md_tgt, weights_iter_md_tgt, bias_md_tgt,
200 dst_layer_md_tgt, dst_iter_md_tgt);
201 auto tgt_prim_desc = rnn_forward::primitive_desc(tgt_desc, eng);
202 auto prim_tgt = rnn_forward(tgt_prim_desc, src_layer_tgt, src_iter_tgt,
203 weights_layer_tgt, weights_iter_tgt, bias_tgt,
204 dst_layer_tgt, dst_iter_tgt, null_memory(eng));
205 stream(stream::kind::eager).submit({prim_tgt}).wait();
207 // compare dst_layer and dst_iter
208 compare_data<data_t>(dst_layer_ref, dst_layer_tgt, 1e-5);
209 compare_data<data_t>(dst_iter_ref, dst_iter_tgt, 1e-5);
// Short aliases to keep the parameter tables below readable.
213 using eng = engine::kind;
214 using fmt = memory::format;
215 using alg = algorithm;
216 using dir = rnn_direction;
217 using rnn_forward_test_f32 = rnn_forward_test<float>;
218 using cfg_f32 = test_rnn_params_t;
// Empty body on purpose: all the work happens in SetUp()/Test() of the
// fixture; TEST_P only instantiates it per parameter set.
220 TEST_P(rnn_forward_test_f32, TestsRnn) { }
// Parameter table: positive cases (plain RNN, LSTM) followed by negative
// cases whose trailing `true, mkldnn_invalid_arguments` marks them as
// expected failures (handled by catch_expected_failures in SetUp()).
// NOTE(review): this listing is truncated — the ::testing::Values(...) line
// and the closing parentheses of this macro call are not visible here.
221 INSTANTIATE_TEST_CASE_P(TestRnn, rnn_forward_test_f32,
223     cfg_f32{eng::cpu, alg::vanilla_rnn, alg::eltwise_tanh, dir::unidirectional_left2right,
224     {fmt::tnc, fmt::ldsnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, fmt::tnc, fmt::ldsnc},
225     test_rnn_sizes_t(1, 1, 10, 16, 100, 100, 100, 100)},
226     cfg_f32{eng::cpu, alg::vanilla_lstm, alg::eltwise_tanh, dir::unidirectional_left2right,
227     {fmt::tnc, fmt::ldsnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, fmt::tnc, fmt::ldsnc},
228     test_rnn_sizes_t(1, 1, 10, 16, 100, 100, 100, 100)},
229     /* Check for invalid parameters: unsupported unrolling */
230     cfg_f32{eng::cpu, alg::vanilla_rnn, alg::eltwise_tanh, dir::unidirectional_left2right,
231     {fmt::tnc, fmt::ldsnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, fmt::tnc, fmt::ldsnc},
232     test_rnn_sizes_t(2, 1, 10, 16, 200, 100, 100, 100), true, mkldnn_invalid_arguments},
233     cfg_f32{eng::cpu, alg::vanilla_rnn, alg::eltwise_tanh, dir::unidirectional_left2right,
234     {fmt::tnc, fmt::ldsnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, fmt::tnc, fmt::ldsnc},
235     test_rnn_sizes_t(2, 1, 10, 16, 100, 200, 100, 100), true, mkldnn_invalid_arguments},
236     /* Check for invalid parameters: inconsistent dimensions */
237     cfg_f32{eng::cpu, alg::vanilla_rnn, alg::eltwise_tanh, dir::unidirectional_left2right,
238     {fmt::tnc, fmt::ldsnc, fmt::ldigo, fmt::ldigo, fmt::ldgo, fmt::tnc, fmt::ldsnc},
239     test_rnn_sizes_t(2, 1, 10, 16, 100, 100, 50, 100), true, mkldnn_invalid_arguments}