1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "mkldnn_test_common.hpp"
20 #include "gtest/gtest.h"
// Normalization-region selector stored in test_lrn_desc_t::kind.
// ACROSS: the window spans neighbouring channels (1x1 spatially).
// WITHIN: the window is spatial (local_size x local_size) inside one channel.
enum {
    ACROSS = 0,
    WITHIN = 1
};
// Problem descriptor for one LRN test case.
// The checkers and the fixture read mb, c, h, w, alpha, beta, k and local_size
// from it. `kind` selects the normalization region: ACROSS (0) or WITHIN (1).
28 struct test_lrn_desc_t {
33 int kind; // 0 ac, 1 wc
// Parameters for one instantiation of lrn_test:
// - the engine to run on;
// - the memory format used for src/dst (data_format);
// - the memory format used for diff_src/diff_dst (diff_data_format);
// - the LRN problem descriptor;
// - the status the test expects.
// The fixture also reads aprop_kind, aalgorithm and expect_to_fail from it.
36 struct lrn_test_params {
38 engine::kind engine_kind;
40 memory::format data_format;
41 memory::format diff_data_format;
42 test_lrn_desc_t test_ld;
44 mkldnn_status_t expected_status;
// Reference check for the LRN forward result.
//
// For every (n, c, h, w) position of dst, including channels in the padded
// range [C, padded_c), this recomputes the expected value from src. It then
// compares it with the primitive's output using EXPECT_NEAR and a relative
// tolerance.
//
// p:   test parameters. test_ld supplies the shape, alpha, k, local_size and
//      kind (ACROSS/WITHIN).
// src: forward input.
// dst: forward output written by the primitive under test.
47 template <typename data_t>
48 void check_lrn_fwd(const lrn_test_params &p, const memory &src, const memory &dst)
50 data_t *src_ptr = (data_t *)src.get_data_handle();
51 data_t *dst_ptr = (data_t *)dst.get_data_handle();
// Window extents:
// - ACROSS: local_size channels by a 1x1 spatial window.
// - WITHIN: 1 channel by a local_size x local_size spatial window.
// `summands` is the count the squared sum is divided by.
53 const int C = p.test_ld.c;
54 const int H = p.test_ld.h;
55 const int W = p.test_ld.w;
56 const int size = p.test_ld.local_size;
57 const int CSIZE = p.test_ld.kind == ACROSS ? size : 1;
58 const int HWSIZE = size + 1 - CSIZE;
59 const int summands = p.test_ld.kind == ACROSS ? size : size*size;
// Channel count including layout padding, taken from the src layout.
60 const int padded_c = src.get_primitive_desc().desc().data.layout_desc.blocking.padding_dims[1];
62 const memory::desc src_d = src.get_primitive_desc().desc();
63 const memory::desc dst_d = dst.get_primitive_desc().desc();
// Dense logical offset over (n, padded_c, h, w). It is always passed through
// map_index(desc, ...) before indexing a buffer. That helper presumably maps
// logical to physical layout; confirm in mkldnn_test_common.hpp.
65 auto off = [=](int n, int c, int h, int w)
67 return ((n * padded_c + c) * p.test_ld.h + h) * p.test_ld.w + w;
// Computes the reference value at (n, oc, oh, ow) and compares it with the
// value at d.
// The window covers CSIZE channels and HWSIZE x HWSIZE pixels. It starts
// (CSIZE - 1) / 2 channels and (HWSIZE - 1) / 2 pixels before
// (oc, oh, ow). Source positions outside [0, C) x [0, H) x [0, W) are
// skipped.
70 auto ker = [=](data_t *d, int n, int oc, int oh, int ow)
73 for (int c = oc; c < oc + CSIZE; ++c) {
74 if (c < (CSIZE - 1) / 2) continue;
75 if (c >= C + (CSIZE - 1) / 2) continue;
76 for (int h = oh; h < oh + HWSIZE; ++h) {
77 if (h < (HWSIZE - 1) / 2) continue;
78 if (h >= H + (HWSIZE - 1) / 2) continue;
79 for (int w = ow; w < ow + HWSIZE; ++w) {
80 if (w < (HWSIZE - 1) / 2) continue;
81 if (w >= W + (HWSIZE - 1) / 2) continue;
82 data_t s = src_ptr[map_index(src_d,off(n, c - (CSIZE - 1) / 2, h - (HWSIZE - 1) / 2, w - (HWSIZE - 1) / 2))];
// Expected value: src at (n, oc, oh, ow) divided by std::pow of
// (k + alpha * sum / summands).
88 auto const norm_coef = std::pow(p.test_ld.k + p.test_ld.alpha * sum / summands,
90 data_t ref_out = static_cast<data_t>(src_ptr[map_index(src_d, off(n, oc, oh, ow))]/norm_coef);
// The tolerance grows with the number of summands. It is relative to
// max(|out|, |ref_out|), and becomes absolute when that magnitude is
// below eps.
91 data_t eps = static_cast<data_t>(1.e-7f*(2*summands+5));
93 data_t norm_max = std::max(fabs(out), fabs(ref_out));
94 if (norm_max < eps) norm_max = 1.;
95 EXPECT_NEAR(out, ref_out, eps*norm_max);
// Check all positions in parallel (OpenMP, four loops collapsed). The
// channel loop runs up to padded_c, so padded channels are checked too.
98 const int N = p.test_ld.mb;
99 # pragma omp parallel for collapse(4) schedule(static)
100 for (int n = 0; n < N; ++n) {
101 for (int c = 0; c < padded_c; ++c) {
102 for (int h = 0; h < H; ++h) {
103 for (int w = 0; w < W; ++w) {
104 ker(&dst_ptr[map_index(dst_d,off(n, c, h, w))], n, c, h, w);
// Reference check for the across-channel LRN backward result.
//
// For every (mb, c, h, w) with c < C, this recomputes the expected diff_src
// into a scratch buffer. Unlike check_lrn_fwd, padded channels are not
// visited. Each value is then compared with the primitive's diff_src using
// EXPECT_NEAR.
//
// p:        test parameters (shape, alpha, beta, k, local_size).
// src:      forward input.
// diff_dst: incoming gradient.
// diff_src: gradient computed by the primitive under test.
111 template <typename data_t>
112 void check_lrn_bwd(const lrn_test_params &p, const memory &src,
113 const memory &diff_dst, const memory &diff_src)
115 data_t *src_ptr = (data_t *)src.get_data_handle();
116 data_t *diff_dst_ptr = (data_t *)diff_dst.get_data_handle();
117 data_t *diff_src_ptr = (data_t *)diff_src.get_data_handle();
119 const int MB = p.test_ld.mb;
120 const int C = p.test_ld.c;
121 const int H = p.test_ld.h;
122 const int W = p.test_ld.w;
123 const int local_size = p.test_ld.local_size;
// NOTE(review): padded_c is size_t here but int in check_lrn_fwd, so off()
// below computes in size_t. Consider making the two checkers consistent.
124 size_t padded_c = src.get_primitive_desc().desc().data.layout_desc.blocking.padding_dims[1];
// Scratch buffer for the reference gradients. It is indexed like diff_src,
// through map_index(diff_src_d, ...), and freed with delete[] at the end.
// NOTE(review): a std::vector<data_t> would make the release automatic
// instead of relying on the trailing delete[].
126 data_t *ref_diff_src_ptr = new data_t[MB*(padded_c)*H*W];
128 const memory::desc src_d = src.get_primitive_desc().desc();
129 const memory::desc diff_dst_d = diff_dst.get_primitive_desc().desc();
130 const memory::desc diff_src_d = diff_src.get_primitive_desc().desc();
// Dense logical offset over (n, padded_c, H, W). It is combined with
// map_index to index the actual buffers.
132 auto off = [=](int n, int c, int h, int w)
134 return ((n * padded_c + c) * H + h) * W + w;
// Sum of squares of src over channels
// [max(0, c - (kernel_size - 1) / 2), min(C, c + kernel_size - (kernel_size - 1) / 2)),
// scaled by alpha / kernel_size.
// The caller passes k as c_k and the outer C as C. The parameter shadows
// the outer C.
137 auto get_omega = [=](data_t c_k, int kernel_size, float alpha, int C,
138 const data_t *src, int n, int c, int h, int w) {
141 int half_kernel_size = (kernel_size - 1) / 2;
142 int c_start = (c < half_kernel_size) ? 0 : c - half_kernel_size;
143 int c_end = c + kernel_size - half_kernel_size;
144 c_end = c_end < C ? c_end : C;
145 for (int i = c_start; i < c_end; ++i) {
146 data_t value = src[map_index(src_d, off(n, i, h, w))];
147 sum += value * value;
149 sum *= alpha / kernel_size;
// Reference gradient at (mb, oc, oh, ow), written to d.
// The ks range [ks_start, ks_stop) keeps the window channel
// _t = oc + ks - kernel_size / 2 inside [0, C). For each such _t, with
// omega = get_omega(k, ...) at _t, it accumulates:
//   B += diff_dst(_t) * src(_t) / omega^(beta + 1)
// A and B are then combined into the value at d, presumably as A - B;
// confirm.
153 auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) {
154 const float alpha = p.test_ld.alpha;
155 const float beta = p.test_ld.beta;
156 const float k = p.test_ld.k;
157 const int kernel_size = p.test_ld.local_size;
158 int ks_start = kernel_size/2 > oc ? kernel_size/2 - oc : 0;
159 int ks_stop = C - oc <= kernel_size/2 ? C - oc + kernel_size/2 : kernel_size;
161 data_t A = 0, B = 0, omega_mid = 0;
163 for (int ks = ks_start; ks < ks_stop; ks++) {
164 int _t = oc + ks - (kernel_size/2);
165 data_t omega = get_omega(static_cast<data_t>(k), kernel_size, alpha, C,
166 src_ptr, mb, _t, oh, ow);
168 if (ks == kernel_size/2) omega_mid = omega;
170 data_t t = src_ptr[map_index(src_d, off(mb, _t, oh, ow))] / powf((float)omega, (float)beta);
171 B += (1.0f / omega) * t * diff_dst_ptr[map_index(diff_dst_d, off(mb, _t, oh, ow))];
// Local term: diff_dst(oc) / omega(oc)^beta. omega(oc) was captured as
// omega_mid in the loop above.
174 A = (1.0f / powf((float)omega_mid, (float)beta))
175 * diff_dst_ptr[map_index(diff_dst_d, off(mb, oc, oh, ow))];
// Cross term: scale B by src(oc) * 2 * alpha * beta / kernel_size.
176 B *= src_ptr[map_index(src_d, off(mb, oc, oh, ow))];
177 B *= (2.0f * alpha * beta) / kernel_size;
// Compute and compare every gradient in parallel (OpenMP, four loops
// collapsed). Only real channels (c < C) are visited.
181 # pragma omp parallel for collapse(4) schedule(static)
182 for (int mb = 0; mb < MB; ++mb) {
183 for (int c = 0; c < C; ++c) {
184 for (int h = 0; h < H; ++h) {
185 for (int w = 0; w < W; ++w) {
186 ker(&ref_diff_src_ptr[map_index(diff_src_d, off(mb, c, h, w))],
// Here A is the reference value and B is the primitive's value. The
// tolerance grows with local_size and is relative to max(|A|, |B|) unless
// that magnitude is below eps.
188 auto A = ref_diff_src_ptr[map_index(diff_src_d, off(mb, c, h, w))];
189 auto B = diff_src_ptr[map_index(diff_src_d, off(mb, c, h, w))];
190 data_t eps = static_cast<data_t>( 1.e-6*((2*(2*local_size + 3) + 6)*local_size
191 + (2*local_size + 3) + 9) );
192 data_t norm_max = std::max(fabs(A), fabs(B));
193 if (norm_max < eps) norm_max = 1.;
194 EXPECT_NEAR(A, B, eps*norm_max);
200 delete [] ref_diff_src_ptr;
// Value-parameterized fixture for LRN.
// It runs the forward primitive and checks it with check_lrn_fwd. It also
// runs the backward primitive and checks it with check_lrn_bwd. Only the CPU
// engine and f32 data are accepted.
203 template <typename data_t>
204 class lrn_test : public ::testing::TestWithParam<lrn_test_params> {
206 std::shared_ptr<test_memory> src;
207 std::shared_ptr<test_memory> dst;
208 std::shared_ptr<test_memory> diff_src;
209 std::shared_ptr<test_memory> diff_dst;
// Workspace filled by the forward pass and consumed by lrn_backward.
210 std::shared_ptr<memory> workspace;
211 std::shared_ptr<memory::desc> src_desc;
212 std::shared_ptr<memory::desc> dst_desc;
213 std::shared_ptr<memory::desc> diff_src_desc;
214 std::shared_ptr<memory::desc> diff_dst_desc;
215 std::shared_ptr<lrn_forward::primitive_desc> lrn_fwd_prim_desc;
// NOTE(review): despite its bwd name, this member is declared as an
// lrn_forward::primitive_desc, and it is not referenced in this class (the
// backward code uses a local lrn_prim_desc). Consider removing it or
// changing its type to lrn_backward::primitive_desc.
216 std::shared_ptr<lrn_forward::primitive_desc> lrn_bwd_prim_desc;
219 std::shared_ptr<engine> eng;
220 memory::data_type data_type;
// Runs the whole test from SetUp(). Test() is wrapped in
// catch_expected_failures together with p.expect_to_fail, presumably so
// that cases marked as expected to fail are tolerated there.
224 virtual void SetUp() {
225 p = ::testing::TestWithParam<decltype(p)>::GetParam();
226 catch_expected_failures([=](){Test();}, p.expect_to_fail,
// Re-read the parameters and validate the environment. Only the CPU engine
// and f32 data are supported here.
231 p = ::testing::TestWithParam<decltype(p)>::GetParam();
233 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
234 eng.reset(new engine(p.engine_kind, 0));
235 data_type = data_traits<data_t>::data_type;
236 ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
// All four tensors share the {mb, c, h, w} shape. src/dst use data_format;
// diff_src/diff_dst use diff_data_format.
238 test_lrn_desc_t ld = p.test_ld;
240 src_desc.reset(new memory::desc({ ld.mb, ld.c, ld.h, ld.w },
241 data_type, p.data_format));
242 dst_desc.reset(new memory::desc({ ld.mb, ld.c, ld.h, ld.w },
243 data_type, p.data_format));
244 diff_src_desc.reset(new memory::desc({ ld.mb, ld.c, ld.h, ld.w },
245 data_type, p.diff_data_format));
246 diff_dst_desc.reset(new memory::desc({ ld.mb, ld.c, ld.h, ld.w },
247 data_type, p.diff_data_format));
249 is_training = p.aprop_kind == prop_kind::forward_training;
// Forward pass: build the lrn_forward descriptor and primitive descriptor,
// allocate and fill src/dst, then execute.
257 auto lrn_desc = lrn_forward::desc(p.aprop_kind, p.aalgorithm, *src_desc,
258 p.test_ld.local_size, p.test_ld.alpha, p.test_ld.beta,
260 lrn_fwd_prim_desc.reset(new lrn_forward::primitive_desc(lrn_desc, *eng));
262 src.reset(new test_memory(*src_desc, *eng));
263 dst.reset(new test_memory(*dst_desc, *eng));
265 fill_data<data_t>(src->get_size() / sizeof(data_t),
266 (data_t *)src->get().get_data_handle());
267 fill_data<data_t>(dst->get_size() / sizeof(data_t),
268 (data_t *)dst->get().get_data_handle());
// check_zero_tail(1, ...) runs before execution and check_zero_tail(0, ...)
// after it. Presumably this prepares and then verifies the padded tail of
// blocked layouts; confirm in mkldnn_test_common.hpp.
269 check_zero_tail<data_t>(1, src->get());
270 check_zero_tail<data_t>(1, dst->get());
273 std::vector<primitive> pipeline;
274 auto s = stream(stream::kind::lazy);
// NOTE(review): `workspace` is allocated only in this branch. The
// alternative branch below runs lrn_forward without one.
276 auto workspace_primitive_desc =
277 lrn_fwd_prim_desc->workspace_primitive_desc();
278 workspace.reset(new memory(workspace_primitive_desc));
279 auto l = lrn_forward(*lrn_fwd_prim_desc, src->get(), *workspace,
281 pipeline.push_back(l);
282 s.submit(pipeline).wait();
284 auto l = lrn_forward(*lrn_fwd_prim_desc, src->get(),
286 pipeline.push_back(l);
287 s.submit(pipeline).wait();
290 check_zero_tail<data_t>(0, dst->get());
292 check_lrn_fwd<data_t>(p, src->get(), dst->get());
// Backward pass: build an lrn_backward descriptor from src_desc and
// diff_dst_desc with the same local_size/alpha/beta/k. Then fill the
// inputs, run the primitive with the forward workspace, and check diff_src.
// NOTE(review): *workspace is dereferenced unconditionally below. It is
// only allocated by the workspace branch of the forward pass, so confirm
// this path runs only after that branch (e.g. guarded by is_training).
297 auto lrn_desc = lrn_backward::desc(p.aalgorithm,
298 *src_desc, *diff_dst_desc, p.test_ld.local_size,
299 p.test_ld.alpha, p.test_ld.beta, p.test_ld.k);
// NOTE(review): src is re-allocated and re-filled here, while the workspace
// was produced from the forward src buffer. This relies on fill_data
// producing the same values again; confirm.
301 src.reset(new test_memory(*src_desc, *eng));
302 diff_src.reset(new test_memory(*diff_src_desc, *eng));
303 diff_dst.reset(new test_memory(*diff_dst_desc, *eng));
305 auto lrn_prim_desc = lrn_backward::primitive_desc(lrn_desc, *eng,
308 fill_data<data_t>(src->get_size() / sizeof(data_t),
309 (data_t *)src->get().get_data_handle());
311 fill_data<data_t>(diff_dst->get_size() / sizeof(data_t),
312 (data_t *)diff_dst->get().get_data_handle());
314 fill_data<data_t>(diff_src->get_size() / sizeof(data_t),
315 (data_t *)diff_src->get().get_data_handle());
316 check_zero_tail<data_t>(1, src->get());
317 check_zero_tail<data_t>(1, diff_dst->get());
318 check_zero_tail<data_t>(1, diff_src->get());
321 std::vector<primitive> pipeline;
322 auto s = stream(stream::kind::lazy);
323 auto l = lrn_backward(lrn_prim_desc, src->get(), diff_dst->get(),
324 *workspace, diff_src->get());
325 pipeline.push_back(l);
326 s.submit(pipeline).wait();
328 check_zero_tail<data_t>(0, diff_src->get());
330 check_lrn_bwd<data_t>(p, src->get(), diff_dst->get(), diff_src->get());
334 using lrn_test_float = lrn_test<float>;
335 using lrn_test_params_float = lrn_test_params;
// Registers the parameterized test. The forward/backward work is driven from
// the fixture's SetUp(), which calls Test() through catch_expected_failures.
337 TEST_P(lrn_test_float, TestsLRN)
// nChw16c cases whose channel counts (17, 19, 26, 12) are not multiples of
// 16. Presumably this exercises the padded channel tail of the blocked
// layout.
341 INSTANTIATE_TEST_CASE_P(TestLRNBackward_nChw16c_padded, lrn_test_float,
343 lrn_test_params_float{ prop_kind::forward_training,
344 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
345 memory::format::nChw16c, { 2, 17, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
346 , lrn_test_params_float{ prop_kind::forward_scoring,
347 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
348 memory::format::nChw16c, { 2, 19, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
349 , lrn_test_params_float{ prop_kind::forward_training,
350 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
351 memory::format::nChw16c, { 2, 26, 4, 4, 1.0e-4f, 0.75f, 5.7f, 5, ACROSS } }
352 , lrn_test_params_float{ prop_kind::forward_scoring,
353 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
354 memory::format::nChw16c, { 2, 12, 4, 4, 1.0e-4f, 0.75f, 5.7f, 5, ACROSS } }
// nchw cases with k = 1 and k = 3 for both propagation kinds.
// NOTE(review): despite the "EF" suffix (expected failure?), no entry here
// passes anything after the LRN descriptor. Any lrn_test_params fields
// declared after test_ld therefore keep value-initialized defaults: this
// includes expected_status, and presumably expect_to_fail. Confirm these
// cases are meant to succeed.
357 INSTANTIATE_TEST_CASE_P(TestLRNForwardEF, lrn_test_float,
359 lrn_test_params_float{ prop_kind::forward_training,
360 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
361 memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
362 , lrn_test_params_float{ prop_kind::forward_scoring,
363 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
364 memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
365 , lrn_test_params_float{ prop_kind::forward_training,
366 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
367 memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 3.0f, 5, ACROSS } }
368 , lrn_test_params_float{ prop_kind::forward_scoring,
369 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
370 memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 3.0f, 5, ACROSS } }
// Basic nchw cases: small shapes with local_size 5, and larger batches with
// local_size 3.
373 INSTANTIATE_TEST_CASE_P(TestLRN, lrn_test_float,
375 lrn_test_params_float{ prop_kind::forward_training,
376 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
377 memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
378 , lrn_test_params_float{ prop_kind::forward_scoring,
379 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
380 memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
381 , lrn_test_params_float{ prop_kind::forward_training,
382 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
383 memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 4.0f, 5, ACROSS } }
384 , lrn_test_params_float{ prop_kind::forward_scoring,
385 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
386 memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 4.0f, 5, ACROSS } }
387 , lrn_test_params_float{ prop_kind::forward_scoring,
388 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
389 memory::format::nchw, { 20, 12, 7, 7, 1.0e-2f, 0.5f, 1.0f, 3, ACROSS } }
// NOTE(review): exact duplicate of the entry above. Other groups pair
// forward_training with forward_scoring, so one of the two was probably
// meant to be forward_training.
390 , lrn_test_params_float{ prop_kind::forward_scoring,
391 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
392 memory::format::nchw, { 20, 12, 7, 7, 1.0e-2f, 0.5f, 1.0f, 3, ACROSS } }
393 , lrn_test_params_float{ prop_kind::forward_scoring,
394 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
395 memory::format::nchw, { 20, 12, 7, 7, 1.0e-2f, 0.5f, 6.5f, 3, ACROSS } }
// NOTE(review): exact duplicate of the entry above (see the previous note).
396 , lrn_test_params_float{ prop_kind::forward_scoring,
397 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
398 memory::format::nchw, { 20, 12, 7, 7, 1.0e-2f, 0.5f, 6.5f, 3, ACROSS } }
// nhwc data and diff layouts, with k = 1 and k = 3 for both propagation
// kinds.
401 INSTANTIATE_TEST_CASE_P(TestLRNNHWC, lrn_test_float,
403 lrn_test_params_float{ prop_kind::forward_training,
404 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
405 memory::format::nhwc, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
406 , lrn_test_params_float{ prop_kind::forward_scoring,
407 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
408 memory::format::nhwc, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
409 , lrn_test_params_float{ prop_kind::forward_training,
410 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
411 memory::format::nhwc, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 3.0f, 5, ACROSS } }
412 , lrn_test_params_float{ prop_kind::forward_scoring,
413 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
414 memory::format::nhwc, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 3.0f, 5, ACROSS } }
// nChw8c cases: 16-channel shapes, single-pixel 8-channel shapes, and
// 32-channel 5x5 shapes with local_size 3.
// NOTE(review): from the fifth entry on, entries come in identical
// forward_scoring pairs (marked below). One of each pair was probably
// meant to be forward_training.
417 INSTANTIATE_TEST_CASE_P(TestLRN_nChw8c, lrn_test_float,
419 lrn_test_params_float{ prop_kind::forward_training,
420 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
421 memory::format::nChw8c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
422 , lrn_test_params_float{ prop_kind::forward_scoring,
423 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
424 memory::format::nChw8c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
425 , lrn_test_params_float{ prop_kind::forward_training,
426 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
427 memory::format::nChw8c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 5.0f, 5, ACROSS } }
428 , lrn_test_params_float{ prop_kind::forward_scoring,
429 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
430 memory::format::nChw8c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 5.0f, 5, ACROSS } }
431 , lrn_test_params_float{ prop_kind::forward_scoring,
432 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
433 memory::format::nChw8c, { 1, 8, 1, 1, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
// NOTE(review): duplicate of the entry above.
434 , lrn_test_params_float{ prop_kind::forward_scoring,
435 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
436 memory::format::nChw8c, { 1, 8, 1, 1, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
437 , lrn_test_params_float{ prop_kind::forward_scoring,
438 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
439 memory::format::nChw8c, { 1, 8, 1, 1, 1.0e-4f, 0.75f, 2.2f, 5, ACROSS } }
// NOTE(review): duplicate of the entry above.
440 , lrn_test_params_float{ prop_kind::forward_scoring,
441 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
442 memory::format::nChw8c, { 1, 8, 1, 1, 1.0e-4f, 0.75f, 2.2f, 5, ACROSS } }
443 , lrn_test_params_float{ prop_kind::forward_scoring,
444 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
445 memory::format::nChw8c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 1.0f, 3, ACROSS } }
// NOTE(review): duplicate of the entry above.
446 , lrn_test_params_float{ prop_kind::forward_scoring,
447 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
448 memory::format::nChw8c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 1.0f, 3, ACROSS } }
449 , lrn_test_params_float{ prop_kind::forward_scoring,
450 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
451 memory::format::nChw8c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 0.1f, 3, ACROSS } }
// NOTE(review): duplicate of the entry above.
452 , lrn_test_params_float{ prop_kind::forward_scoring,
453 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
454 memory::format::nChw8c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 0.1f, 3, ACROSS } }
// nChw16c cases: 16-channel shapes, single-pixel 16-channel shapes, and
// 32-channel 5x5 shapes with local_size 3.
// NOTE(review): from the fifth entry on, entries come in identical
// forward_scoring pairs (marked below). One of each pair was probably
// meant to be forward_training.
457 INSTANTIATE_TEST_CASE_P(TestLRN_nChw16c, lrn_test_float,
459 lrn_test_params_float{ prop_kind::forward_training,
460 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
461 memory::format::nChw16c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
462 , lrn_test_params_float{ prop_kind::forward_scoring,
463 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
464 memory::format::nChw16c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
465 , lrn_test_params_float{ prop_kind::forward_training,
466 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
467 memory::format::nChw16c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 5.0f, 5, ACROSS } }
468 , lrn_test_params_float{ prop_kind::forward_scoring,
469 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
470 memory::format::nChw16c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 5.0f, 5, ACROSS } }
471 , lrn_test_params_float{ prop_kind::forward_scoring,
472 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
473 memory::format::nChw16c, { 1, 16, 1, 1, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
// NOTE(review): duplicate of the entry above.
474 , lrn_test_params_float{ prop_kind::forward_scoring,
475 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
476 memory::format::nChw16c, { 1, 16, 1, 1, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
477 , lrn_test_params_float{ prop_kind::forward_scoring,
478 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
479 memory::format::nChw16c, { 1, 16, 1, 1, 1.0e-4f, 0.75f, 2.2f, 5, ACROSS } }
// NOTE(review): duplicate of the entry above.
480 , lrn_test_params_float{ prop_kind::forward_scoring,
481 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
482 memory::format::nChw16c, { 1, 16, 1, 1, 1.0e-4f, 0.75f, 2.2f, 5, ACROSS } }
483 , lrn_test_params_float{ prop_kind::forward_scoring,
484 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
485 memory::format::nChw16c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 1.0f, 3, ACROSS } }
// NOTE(review): duplicate of the entry above.
486 , lrn_test_params_float{ prop_kind::forward_scoring,
487 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
488 memory::format::nChw16c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 1.0f, 3, ACROSS } }
489 , lrn_test_params_float{ prop_kind::forward_scoring,
490 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
491 memory::format::nChw16c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 0.1f, 3, ACROSS } }
// NOTE(review): duplicate of the entry above.
492 , lrn_test_params_float{ prop_kind::forward_scoring,
493 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
494 memory::format::nChw16c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 0.1f, 3, ACROSS } }
// nchw cases with alpha = 1.0 (the group name suggests Caffe-style
// parameters).
497 INSTANTIATE_TEST_CASE_P(
498 TestLRNCaffeNCHW, lrn_test_float,
500 lrn_test_params_float{ prop_kind::forward_training,
501 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
502 memory::format::nchw, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
503 , lrn_test_params_float{ prop_kind::forward_scoring,
504 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
505 memory::format::nchw, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
// NOTE(review): this entry and the next repeat the two entries above
// exactly. Different parameters (e.g. k) may have been intended.
506 , lrn_test_params_float{ prop_kind::forward_training,
507 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
508 memory::format::nchw, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
509 , lrn_test_params_float{ prop_kind::forward_scoring,
510 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
511 memory::format::nchw, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
// nhwc counterpart of TestLRNCaffeNCHW (alpha = 1.0).
514 INSTANTIATE_TEST_CASE_P(
515 TestLRNCaffeNHWC, lrn_test_float,
517 lrn_test_params_float{ prop_kind::forward_training,
518 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
519 memory::format::nhwc, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
520 , lrn_test_params_float{ prop_kind::forward_scoring,
521 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
522 memory::format::nhwc, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
// NOTE(review): this entry and the next repeat the two entries above
// exactly. Different parameters (e.g. k) may have been intended.
523 , lrn_test_params_float{ prop_kind::forward_training,
524 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
525 memory::format::nhwc, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
526 , lrn_test_params_float{ prop_kind::forward_scoring,
527 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
528 memory::format::nhwc, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
// nChw8c cases, 96x55x55, alpha = 1.0, local_size 3.
531 INSTANTIATE_TEST_CASE_P(
532 TestLRNCaffe_nChw8c, lrn_test_float,
534 lrn_test_params_float{ prop_kind::forward_training,
535 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
536 memory::format::nChw8c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
537 , lrn_test_params_float{ prop_kind::forward_scoring,
538 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
539 memory::format::nChw8c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
// NOTE(review): this entry and the next repeat the two entries above
// exactly. Different parameters may have been intended.
540 , lrn_test_params_float{ prop_kind::forward_training,
541 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
542 memory::format::nChw8c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
543 , lrn_test_params_float{ prop_kind::forward_scoring,
544 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
545 memory::format::nChw8c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
// nChw16c cases, 96x55x55, alpha = 1.0, local_size 3.
548 INSTANTIATE_TEST_CASE_P(
549 TestLRNCaffe_nChw16c, lrn_test_float,
551 lrn_test_params_float{ prop_kind::forward_training,
552 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
553 memory::format::nChw16c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
554 , lrn_test_params_float{ prop_kind::forward_scoring,
555 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
556 memory::format::nChw16c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
// NOTE(review): this entry and the next repeat the two entries above
// exactly. Different parameters may have been intended.
557 , lrn_test_params_float{ prop_kind::forward_training,
558 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
559 memory::format::nChw16c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
560 , lrn_test_params_float{ prop_kind::forward_scoring,
561 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
562 memory::format::nChw16c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
// AlexNet-sized nchw shapes: 96x55x55 and 256x27x27, local_size 5.
565 INSTANTIATE_TEST_CASE_P(
566 TestLRNAlexnetNCHW, lrn_test_float,
568 lrn_test_params_float{ prop_kind::forward_training,
569 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
570 memory::format::nchw, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
571 , lrn_test_params_float{ prop_kind::forward_scoring,
572 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
573 memory::format::nchw, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
574 , lrn_test_params_float{ prop_kind::forward_training,
575 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
576 memory::format::nchw, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
577 , lrn_test_params_float{ prop_kind::forward_scoring,
578 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
579 memory::format::nchw, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
// AlexNet-sized nhwc shapes: 96x55x55 and 256x27x27, local_size 5.
582 INSTANTIATE_TEST_CASE_P(
583 TestLRNAlexnetNHWC, lrn_test_float,
585 lrn_test_params_float{ prop_kind::forward_training,
586 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
587 memory::format::nhwc, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
588 lrn_test_params_float{ prop_kind::forward_scoring,
589 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
590 memory::format::nhwc, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
591 lrn_test_params_float{ prop_kind::forward_training,
592 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
593 memory::format::nhwc, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
594 lrn_test_params_float{ prop_kind::forward_scoring,
595 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
596 memory::format::nhwc, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
// AlexNet-sized nChw8c shapes: 96x55x55 and 256x27x27, local_size 5.
599 INSTANTIATE_TEST_CASE_P(
600 TestLRNAlexnet_nChw8c, lrn_test_float,
602 lrn_test_params_float{ prop_kind::forward_training,
603 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
604 memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
605 lrn_test_params_float{ prop_kind::forward_scoring,
606 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
607 memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
608 lrn_test_params_float{ prop_kind::forward_training,
609 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
610 memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
611 lrn_test_params_float{ prop_kind::forward_scoring,
612 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
613 memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
// AlexNet-sized nChw16c shapes: 96x55x55 and 256x27x27, local_size 5.
616 INSTANTIATE_TEST_CASE_P(
617 TestLRNAlexnet_nChw16c, lrn_test_float,
619 lrn_test_params_float{ prop_kind::forward_training,
620 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
621 memory::format::nChw16c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
622 lrn_test_params_float{ prop_kind::forward_scoring,
623 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
624 memory::format::nChw16c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
625 lrn_test_params_float{ prop_kind::forward_training,
626 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
627 memory::format::nChw16c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
628 lrn_test_params_float{ prop_kind::forward_scoring,
629 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
630 memory::format::nChw16c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
// GoogleNet-v1-sized nchw shapes: 64x56x56 and 192x56x56, local_size 5.
633 INSTANTIATE_TEST_CASE_P(
634 TestLRNGoogleNetV1NCHW, lrn_test_float,
636 lrn_test_params_float{ prop_kind::forward_training,
637 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
638 memory::format::nchw, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
639 lrn_test_params_float{ prop_kind::forward_scoring,
640 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
641 memory::format::nchw, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
642 lrn_test_params_float{ prop_kind::forward_training,
643 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
644 memory::format::nchw, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
645 lrn_test_params_float{ prop_kind::forward_scoring,
646 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
647 memory::format::nchw, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
// GoogleNet-v1-sized nChw8c shapes: 64x56x56 and 192x56x56, local_size 5.
650 INSTANTIATE_TEST_CASE_P(
651 TestLRNGoogleNetV1_nChw8c, lrn_test_float,
653 lrn_test_params_float{ prop_kind::forward_training,
654 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
655 memory::format::nChw8c, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
656 lrn_test_params_float{ prop_kind::forward_scoring,
657 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
658 memory::format::nChw8c, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
659 lrn_test_params_float{ prop_kind::forward_training,
660 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
661 memory::format::nChw8c, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
662 lrn_test_params_float{ prop_kind::forward_scoring,
663 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
664 memory::format::nChw8c, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
// GoogleNet-v1-sized nChw16c shapes: 64x56x56 and 192x56x56, local_size 5.
667 INSTANTIATE_TEST_CASE_P(
668 TestLRNGoogleNetV1_nChw16c, lrn_test_float,
670 lrn_test_params_float{ prop_kind::forward_training,
671 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
672 memory::format::nChw16c, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
673 lrn_test_params_float{ prop_kind::forward_scoring,
674 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
675 memory::format::nChw16c, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
676 lrn_test_params_float{ prop_kind::forward_training,
677 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
678 memory::format::nChw16c, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
679 lrn_test_params_float{ prop_kind::forward_scoring,
680 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
681 memory::format::nChw16c, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
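684 // Backward does not support within-channel (WITHIN) normalization yet.
// NOTE(review): the WITHIN descriptors below hold 8 values, while the
// ACROSS entries above hold 9. The value before local_size, presumably k,
// appears to be missing. Under positional aggregate initialization, 3/5
// would land in k and WITHIN in local_size, and kind would default to 0
// (ACROSS). Fix this before enabling these cases.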
686 INSTANTIATE_TEST_CASE_P(
687 TestLRNRCNNBlocked, lrn_test_float,
689 lrn_test_params_float{ prop_kind::forward_training,
690 engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
691 memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 3, WITHIN } }
692 , lrn_test_params_float{ prop_kind::forward_scoring,
693 engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
694 memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 3, WITHIN } }
695 , lrn_test_params_float{ prop_kind::forward_training,
696 engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
697 memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 3, WITHIN } }
698 , lrn_test_params_float{ prop_kind::forward_scoring,
699 engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
700 memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 3, WITHIN } }
701 , lrn_test_params_float{ prop_kind::forward_training,
702 engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
703 memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 5, WITHIN } }
704 , lrn_test_params_float{ prop_kind::forward_scoring,
705 engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
706 memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 5, WITHIN } }
707 , lrn_test_params_float{ prop_kind::forward_training,
708 engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
709 memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 5, WITHIN } }
710 , lrn_test_params_float{ prop_kind::forward_scoring,
711 engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
712 memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 5, WITHIN } }