Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_lrn_backward.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cmath>
18
19 #include "mkldnn_test_common.hpp"
20 #include "gtest/gtest.h"
21
22 #include "mkldnn.hpp"
23
24 namespace mkldnn {
25
// Normalization-region selector stored in test_lrn_desc_t::kind: ACROSS sums
// over neighbouring channels, WITHIN over a spatial patch inside one channel
// (check_lrn_fwd derives its channel/spatial window sizes from this value).
enum { ACROSS = 0, WITHIN = 1 };
27
// Tensor shape and LRN hyper-parameters of a single test case.
struct test_lrn_desc_t {
    int mb;          // minibatch size (N)
    int c;           // logical channel count (C)
    int h;           // spatial height (H)
    int w;           // spatial width (W)
    float alpha;     // scale applied to the summed squares
    float beta;      // exponent of the normalization term
    float k;         // additive bias inside the normalization term
    int local_size;  // width of the normalization window
    int kind;        // ACROSS (0) or WITHIN (1)
};
35
36 struct lrn_test_params {
37     prop_kind aprop_kind;
38     engine::kind engine_kind;
39     algorithm aalgorithm;
40     memory::format data_format;
41     memory::format diff_data_format;
42     test_lrn_desc_t test_ld;
43     bool expect_to_fail;
44     mkldnn_status_t expected_status;
45 };
46
47 template <typename data_t>
48 void check_lrn_fwd(const lrn_test_params &p, const memory &src, const memory &dst)
49 {
50     data_t *src_ptr = (data_t *)src.get_data_handle();
51     data_t *dst_ptr = (data_t *)dst.get_data_handle();
52
53     const int C = p.test_ld.c;
54     const int H = p.test_ld.h;
55     const int W = p.test_ld.w;
56     const int size = p.test_ld.local_size;
57     const int CSIZE = p.test_ld.kind == ACROSS ? size : 1;
58     const int HWSIZE = size + 1 - CSIZE;
59     const int summands = p.test_ld.kind == ACROSS ? size : size*size;
60     const int padded_c = src.get_primitive_desc().desc().data.layout_desc.blocking.padding_dims[1];
61
62     const memory::desc src_d = src.get_primitive_desc().desc();
63     const memory::desc dst_d = dst.get_primitive_desc().desc();
64
65     auto off = [=](int n, int c, int h, int w)
66     {
67         return ((n * padded_c + c) * p.test_ld.h + h) * p.test_ld.w + w;
68     };
69
70     auto ker = [=](data_t *d, int n, int oc, int oh, int ow)
71     {
72         data_t sum = 0.0;
73         for (int c = oc; c < oc + CSIZE; ++c) {
74             if (c < (CSIZE - 1) / 2) continue;
75             if (c >= C + (CSIZE - 1) / 2) continue;
76             for (int h = oh; h < oh + HWSIZE; ++h) {
77                 if (h < (HWSIZE - 1) / 2) continue;
78                 if (h >= H + (HWSIZE - 1) / 2) continue;
79                 for (int w = ow; w < ow + HWSIZE; ++w) {
80                     if (w < (HWSIZE - 1) / 2) continue;
81                     if (w >= W + (HWSIZE - 1) / 2) continue;
82                     data_t s = src_ptr[map_index(src_d,off(n, c - (CSIZE - 1) / 2, h - (HWSIZE - 1) / 2, w - (HWSIZE - 1) / 2))];
83                     sum += s * s;
84                 }
85             }
86         }
87
88         auto const norm_coef = std::pow(p.test_ld.k + p.test_ld.alpha * sum / summands,
89                     p.test_ld.beta);
90         data_t ref_out = static_cast<data_t>(src_ptr[map_index(src_d, off(n, oc, oh, ow))]/norm_coef);
91         data_t eps = static_cast<data_t>(1.e-7f*(2*summands+5));
92         data_t out = d[0];
93         data_t norm_max = std::max(fabs(out), fabs(ref_out));
94         if (norm_max < eps) norm_max = 1.;
95         EXPECT_NEAR(out, ref_out, eps*norm_max);
96     };
97
98     const int N = p.test_ld.mb;
99 #   pragma omp parallel for collapse(4) schedule(static)
100     for (int n = 0; n < N; ++n) {
101         for (int c = 0; c < padded_c; ++c) {
102             for (int h = 0; h < H; ++h) {
103                 for (int w = 0; w < W; ++w) {
104                     ker(&dst_ptr[map_index(dst_d,off(n, c, h, w))], n, c, h, w);
105                 }
106             }
107         }
108     }
109 }
110
111 template <typename data_t>
112 void check_lrn_bwd(const lrn_test_params &p, const memory &src,
113         const memory &diff_dst, const memory &diff_src)
114 {
115     data_t *src_ptr = (data_t *)src.get_data_handle();
116     data_t *diff_dst_ptr = (data_t *)diff_dst.get_data_handle();
117     data_t *diff_src_ptr = (data_t *)diff_src.get_data_handle();
118
119     const int MB = p.test_ld.mb;
120     const int C = p.test_ld.c;
121     const int H = p.test_ld.h;
122     const int W = p.test_ld.w;
123     const int local_size = p.test_ld.local_size;
124     size_t padded_c = src.get_primitive_desc().desc().data.layout_desc.blocking.padding_dims[1];
125
126     data_t *ref_diff_src_ptr = new data_t[MB*(padded_c)*H*W];
127
128     const memory::desc src_d = src.get_primitive_desc().desc();
129     const memory::desc diff_dst_d = diff_dst.get_primitive_desc().desc();
130     const memory::desc diff_src_d = diff_src.get_primitive_desc().desc();
131
132     auto off = [=](int n, int c, int h, int w)
133     {
134         return ((n * padded_c + c) * H + h) * W + w;
135     };
136
137     auto get_omega = [=](data_t c_k, int kernel_size, float alpha, int C,
138             const data_t *src, int n, int c, int h, int w) {
139         data_t sum = 0.0;
140
141         int half_kernel_size = (kernel_size - 1) / 2;
142         int c_start = (c < half_kernel_size) ? 0 : c - half_kernel_size;
143         int c_end = c + kernel_size - half_kernel_size;
144         c_end = c_end < C ? c_end : C;
145         for (int i = c_start; i < c_end; ++i) {
146             data_t value = src[map_index(src_d, off(n, i, h, w))];
147             sum += value * value;
148         }
149         sum *= alpha / kernel_size;
150         return c_k + sum;
151     };
152
153     auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) {
154         const float alpha = p.test_ld.alpha;
155         const float beta = p.test_ld.beta;
156         const float k = p.test_ld.k;
157         const int kernel_size = p.test_ld.local_size;
158         int ks_start = kernel_size/2 > oc ? kernel_size/2 - oc : 0;
159         int ks_stop = C - oc <= kernel_size/2 ? C - oc + kernel_size/2 : kernel_size;
160
161         data_t A = 0, B = 0, omega_mid = 0;
162
163         for (int ks = ks_start; ks < ks_stop; ks++) {
164             int _t = oc + ks - (kernel_size/2);
165             data_t omega = get_omega(static_cast<data_t>(k), kernel_size, alpha, C,
166                     src_ptr, mb, _t, oh, ow);
167
168             if (ks == kernel_size/2) omega_mid = omega;
169
170             data_t t = src_ptr[map_index(src_d, off(mb, _t, oh, ow))] / powf((float)omega, (float)beta);
171             B +=  (1.0f / omega) * t * diff_dst_ptr[map_index(diff_dst_d, off(mb, _t, oh, ow))];
172         }
173
174         A = (1.0f / powf((float)omega_mid, (float)beta))
175             * diff_dst_ptr[map_index(diff_dst_d, off(mb, oc, oh, ow))];
176         B *= src_ptr[map_index(src_d, off(mb, oc, oh, ow))];
177         B *= (2.0f * alpha * beta) / kernel_size;
178         *d = A - B;
179     };
180
181 #   pragma omp parallel for collapse(4) schedule(static)
182     for (int mb = 0; mb < MB; ++mb) {
183         for (int c = 0; c < C; ++c) {
184             for (int h = 0; h < H; ++h) {
185                 for (int w = 0; w < W; ++w) {
186                     ker(&ref_diff_src_ptr[map_index(diff_src_d, off(mb, c, h, w))],
187                             mb, c, h, w);
188                     auto A = ref_diff_src_ptr[map_index(diff_src_d, off(mb, c, h, w))];
189                     auto B = diff_src_ptr[map_index(diff_src_d, off(mb, c, h, w))];
190                     data_t eps = static_cast<data_t>( 1.e-6*((2*(2*local_size + 3) + 6)*local_size
191                         + (2*local_size + 3) + 9) );
192                     data_t norm_max = std::max(fabs(A), fabs(B));
193                     if (norm_max < eps) norm_max = 1.;
194                     EXPECT_NEAR(A, B, eps*norm_max);
195                 }
196             }
197         }
198     }
199
200     delete [] ref_diff_src_ptr;
201 }
202
203 template <typename data_t>
204 class lrn_test : public ::testing::TestWithParam<lrn_test_params> {
205 private:
206     std::shared_ptr<test_memory> src;
207     std::shared_ptr<test_memory> dst;
208     std::shared_ptr<test_memory> diff_src;
209     std::shared_ptr<test_memory> diff_dst;
210     std::shared_ptr<memory> workspace;
211     std::shared_ptr<memory::desc> src_desc;
212     std::shared_ptr<memory::desc> dst_desc;
213     std::shared_ptr<memory::desc> diff_src_desc;
214     std::shared_ptr<memory::desc> diff_dst_desc;
215     std::shared_ptr<lrn_forward::primitive_desc> lrn_fwd_prim_desc;
216     std::shared_ptr<lrn_forward::primitive_desc> lrn_bwd_prim_desc;
217     lrn_test_params p;
218     memory::dims padR;
219     std::shared_ptr<engine> eng;
220     memory::data_type data_type;
221     bool is_training;
222
223 protected:
224     virtual void SetUp() {
225         p = ::testing::TestWithParam<decltype(p)>::GetParam();
226         catch_expected_failures([=](){Test();}, p.expect_to_fail,
227                     p.expected_status);
228     }
229
230     void Test() {
231         p = ::testing::TestWithParam<decltype(p)>::GetParam();
232
233         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
234         eng.reset(new engine(p.engine_kind, 0));
235         data_type = data_traits<data_t>::data_type;
236         ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
237
238         test_lrn_desc_t ld = p.test_ld;
239
240         src_desc.reset(new memory::desc({ ld.mb, ld.c, ld.h, ld.w },
241                 data_type, p.data_format));
242         dst_desc.reset(new memory::desc({ ld.mb, ld.c, ld.h, ld.w },
243                 data_type, p.data_format));
244         diff_src_desc.reset(new memory::desc({ ld.mb, ld.c, ld.h, ld.w },
245                 data_type, p.diff_data_format));
246         diff_dst_desc.reset(new memory::desc({ ld.mb, ld.c, ld.h, ld.w },
247                 data_type, p.diff_data_format));
248
249         is_training = p.aprop_kind == prop_kind::forward_training;
250
251         Forward();
252         if (is_training)
253             Backward();
254     }
255
256     void Forward() {
257         auto lrn_desc = lrn_forward::desc(p.aprop_kind, p.aalgorithm, *src_desc,
258                 p.test_ld.local_size, p.test_ld.alpha, p.test_ld.beta,
259                 p.test_ld.k);
260         lrn_fwd_prim_desc.reset(new lrn_forward::primitive_desc(lrn_desc, *eng));
261
262         src.reset(new test_memory(*src_desc, *eng));
263         dst.reset(new test_memory(*dst_desc, *eng));
264
265         fill_data<data_t>(src->get_size() / sizeof(data_t),
266                 (data_t *)src->get().get_data_handle());
267         fill_data<data_t>(dst->get_size() / sizeof(data_t),
268                 (data_t *)dst->get().get_data_handle());
269         check_zero_tail<data_t>(1, src->get());
270         check_zero_tail<data_t>(1, dst->get());
271
272         // Execute
273         std::vector<primitive> pipeline;
274         auto s = stream(stream::kind::lazy);
275         if (is_training) {
276             auto workspace_primitive_desc =
277                 lrn_fwd_prim_desc->workspace_primitive_desc();
278             workspace.reset(new memory(workspace_primitive_desc));
279             auto l = lrn_forward(*lrn_fwd_prim_desc, src->get(), *workspace,
280                     dst->get());
281             pipeline.push_back(l);
282             s.submit(pipeline).wait();
283         } else {
284             auto l = lrn_forward(*lrn_fwd_prim_desc, src->get(),
285                     dst->get());
286             pipeline.push_back(l);
287             s.submit(pipeline).wait();
288         }
289
290         check_zero_tail<data_t>(0, dst->get());
291
292         check_lrn_fwd<data_t>(p, src->get(), dst->get());
293     }
294
295     void Backward()
296     {
297         auto lrn_desc = lrn_backward::desc(p.aalgorithm,
298                 *src_desc, *diff_dst_desc, p.test_ld.local_size,
299                 p.test_ld.alpha, p.test_ld.beta, p.test_ld.k);
300
301         src.reset(new test_memory(*src_desc, *eng));
302         diff_src.reset(new test_memory(*diff_src_desc, *eng));
303         diff_dst.reset(new test_memory(*diff_dst_desc, *eng));
304
305         auto lrn_prim_desc = lrn_backward::primitive_desc(lrn_desc, *eng,
306                 *lrn_fwd_prim_desc);
307
308         fill_data<data_t>(src->get_size() / sizeof(data_t),
309                 (data_t *)src->get().get_data_handle());
310
311         fill_data<data_t>(diff_dst->get_size() / sizeof(data_t),
312                 (data_t *)diff_dst->get().get_data_handle());
313
314         fill_data<data_t>(diff_src->get_size() / sizeof(data_t),
315                 (data_t *)diff_src->get().get_data_handle());
316         check_zero_tail<data_t>(1, src->get());
317         check_zero_tail<data_t>(1, diff_dst->get());
318         check_zero_tail<data_t>(1, diff_src->get());
319
320         // Execute
321         std::vector<primitive> pipeline;
322         auto s = stream(stream::kind::lazy);
323         auto l = lrn_backward(lrn_prim_desc, src->get(), diff_dst->get(),
324                 *workspace, diff_src->get());
325         pipeline.push_back(l);
326         s.submit(pipeline).wait();
327
328         check_zero_tail<data_t>(0, diff_src->get());
329
330         check_lrn_bwd<data_t>(p, src->get(), diff_dst->get(), diff_src->get());
331     }
332 };
333
334 using lrn_test_float = lrn_test<float>;
335 using lrn_test_params_float = lrn_test_params;
336
337 TEST_P(lrn_test_float, TestsLRN)
338 {
339 }
340
341 INSTANTIATE_TEST_CASE_P(TestLRNBackward_nChw16c_padded, lrn_test_float,
342         ::testing::Values(
343             lrn_test_params_float{ prop_kind::forward_training,
344             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
345             memory::format::nChw16c, { 2, 17, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
346             , lrn_test_params_float{ prop_kind::forward_scoring,
347             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
348             memory::format::nChw16c, { 2, 19, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
349             , lrn_test_params_float{ prop_kind::forward_training,
350             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
351             memory::format::nChw16c, { 2, 26, 4, 4, 1.0e-4f, 0.75f, 5.7f, 5, ACROSS } }
352             , lrn_test_params_float{ prop_kind::forward_scoring,
353             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
354             memory::format::nChw16c, { 2, 12, 4, 4, 1.0e-4f, 0.75f, 5.7f, 5, ACROSS } }
355             ));
356
357 INSTANTIATE_TEST_CASE_P(TestLRNForwardEF, lrn_test_float,
358         ::testing::Values(
359             lrn_test_params_float{ prop_kind::forward_training,
360             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
361             memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
362             , lrn_test_params_float{ prop_kind::forward_scoring,
363             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
364             memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
365             , lrn_test_params_float{ prop_kind::forward_training,
366             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
367             memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 3.0f, 5, ACROSS } }
368             , lrn_test_params_float{ prop_kind::forward_scoring,
369             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
370             memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 3.0f, 5, ACROSS } }
371             ));
372
373 INSTANTIATE_TEST_CASE_P(TestLRN, lrn_test_float,
374         ::testing::Values(
375             lrn_test_params_float{ prop_kind::forward_training,
376             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
377             memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
378             , lrn_test_params_float{ prop_kind::forward_scoring,
379             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
380             memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
381             , lrn_test_params_float{ prop_kind::forward_training,
382             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
383             memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 4.0f, 5, ACROSS } }
384             , lrn_test_params_float{ prop_kind::forward_scoring,
385             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
386             memory::format::nchw, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 4.0f, 5, ACROSS } }
387             , lrn_test_params_float{ prop_kind::forward_scoring,
388             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
389             memory::format::nchw, { 20, 12, 7, 7, 1.0e-2f, 0.5f, 1.0f, 3, ACROSS } }
390             , lrn_test_params_float{ prop_kind::forward_scoring,
391             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
392             memory::format::nchw, { 20, 12, 7, 7, 1.0e-2f, 0.5f, 1.0f, 3, ACROSS } }
393             , lrn_test_params_float{ prop_kind::forward_scoring,
394             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
395             memory::format::nchw, { 20, 12, 7, 7, 1.0e-2f, 0.5f, 6.5f, 3, ACROSS } }
396             , lrn_test_params_float{ prop_kind::forward_scoring,
397             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
398             memory::format::nchw, { 20, 12, 7, 7, 1.0e-2f, 0.5f, 6.5f, 3, ACROSS } }
399             ));
400
401 INSTANTIATE_TEST_CASE_P(TestLRNNHWC, lrn_test_float,
402         ::testing::Values(
403             lrn_test_params_float{ prop_kind::forward_training,
404             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
405             memory::format::nhwc, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
406             , lrn_test_params_float{ prop_kind::forward_scoring,
407             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
408             memory::format::nhwc, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
409             , lrn_test_params_float{ prop_kind::forward_training,
410             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
411             memory::format::nhwc, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 3.0f, 5, ACROSS } }
412             , lrn_test_params_float{ prop_kind::forward_scoring,
413             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
414             memory::format::nhwc, { 2, 10, 4, 4, 1.0e-4f, 0.75f, 3.0f, 5, ACROSS } }
415             ));
416
417 INSTANTIATE_TEST_CASE_P(TestLRN_nChw8c, lrn_test_float,
418         ::testing::Values(
419             lrn_test_params_float{ prop_kind::forward_training,
420             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
421             memory::format::nChw8c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
422             , lrn_test_params_float{ prop_kind::forward_scoring,
423             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
424             memory::format::nChw8c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
425             , lrn_test_params_float{ prop_kind::forward_training,
426             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
427             memory::format::nChw8c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 5.0f, 5, ACROSS } }
428             , lrn_test_params_float{ prop_kind::forward_scoring,
429             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
430             memory::format::nChw8c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 5.0f, 5, ACROSS } }
431             , lrn_test_params_float{ prop_kind::forward_scoring,
432             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
433             memory::format::nChw8c, { 1, 8, 1, 1, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
434             , lrn_test_params_float{ prop_kind::forward_scoring,
435             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
436             memory::format::nChw8c, { 1, 8, 1, 1, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
437             , lrn_test_params_float{ prop_kind::forward_scoring,
438             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
439             memory::format::nChw8c, { 1, 8, 1, 1, 1.0e-4f, 0.75f, 2.2f, 5, ACROSS } }
440             , lrn_test_params_float{ prop_kind::forward_scoring,
441             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
442             memory::format::nChw8c, { 1, 8, 1, 1, 1.0e-4f, 0.75f, 2.2f, 5, ACROSS } }
443             , lrn_test_params_float{ prop_kind::forward_scoring,
444             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
445             memory::format::nChw8c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 1.0f, 3, ACROSS } }
446             , lrn_test_params_float{ prop_kind::forward_scoring,
447             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
448             memory::format::nChw8c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 1.0f, 3, ACROSS } }
449             , lrn_test_params_float{ prop_kind::forward_scoring,
450             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
451             memory::format::nChw8c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 0.1f, 3, ACROSS } }
452             , lrn_test_params_float{ prop_kind::forward_scoring,
453             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
454             memory::format::nChw8c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 0.1f, 3, ACROSS } }
455             ));
456
457 INSTANTIATE_TEST_CASE_P(TestLRN_nChw16c, lrn_test_float,
458         ::testing::Values(
459             lrn_test_params_float{ prop_kind::forward_training,
460             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
461             memory::format::nChw16c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
462             , lrn_test_params_float{ prop_kind::forward_scoring,
463             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
464             memory::format::nChw16c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
465             , lrn_test_params_float{ prop_kind::forward_training,
466             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
467             memory::format::nChw16c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 5.0f, 5, ACROSS } }
468             , lrn_test_params_float{ prop_kind::forward_scoring,
469             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
470             memory::format::nChw16c, { 2, 16, 4, 4, 1.0e-4f, 0.75f, 5.0f, 5, ACROSS } }
471             , lrn_test_params_float{ prop_kind::forward_scoring,
472             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
473             memory::format::nChw16c, { 1, 16, 1, 1, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
474             , lrn_test_params_float{ prop_kind::forward_scoring,
475             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
476             memory::format::nChw16c, { 1, 16, 1, 1, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
477             , lrn_test_params_float{ prop_kind::forward_scoring,
478             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
479             memory::format::nChw16c, { 1, 16, 1, 1, 1.0e-4f, 0.75f, 2.2f, 5, ACROSS } }
480             , lrn_test_params_float{ prop_kind::forward_scoring,
481             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
482             memory::format::nChw16c, { 1, 16, 1, 1, 1.0e-4f, 0.75f, 2.2f, 5, ACROSS } }
483             , lrn_test_params_float{ prop_kind::forward_scoring,
484             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
485             memory::format::nChw16c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 1.0f, 3, ACROSS } }
486             , lrn_test_params_float{ prop_kind::forward_scoring,
487             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
488             memory::format::nChw16c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 1.0f, 3, ACROSS } }
489             , lrn_test_params_float{ prop_kind::forward_scoring,
490             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
491             memory::format::nChw16c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 0.1f, 3, ACROSS } }
492             , lrn_test_params_float{ prop_kind::forward_scoring,
493             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
494             memory::format::nChw16c, { 1, 32, 5, 5, 1.0e-2f, 0.7f, 0.1f, 3, ACROSS } }
495             ));
496
497 INSTANTIATE_TEST_CASE_P(
498         TestLRNCaffeNCHW, lrn_test_float,
499         ::testing::Values(
500             lrn_test_params_float{ prop_kind::forward_training,
501             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
502             memory::format::nchw, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
503             , lrn_test_params_float{ prop_kind::forward_scoring,
504             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
505             memory::format::nchw, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
506             , lrn_test_params_float{ prop_kind::forward_training,
507             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
508             memory::format::nchw, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
509             , lrn_test_params_float{ prop_kind::forward_scoring,
510             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
511             memory::format::nchw, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
512             ));
513
514 INSTANTIATE_TEST_CASE_P(
515         TestLRNCaffeNHWC, lrn_test_float,
516         ::testing::Values(
517             lrn_test_params_float{ prop_kind::forward_training,
518             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
519             memory::format::nhwc, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
520             , lrn_test_params_float{ prop_kind::forward_scoring,
521             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
522             memory::format::nhwc, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
523             , lrn_test_params_float{ prop_kind::forward_training,
524             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
525             memory::format::nhwc, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
526             , lrn_test_params_float{ prop_kind::forward_scoring,
527             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
528             memory::format::nhwc, { 2, 4, 5, 5, 1.0f, 0.75f, 1.0f, 5, ACROSS } }
529             ));
530
531 INSTANTIATE_TEST_CASE_P(
532         TestLRNCaffe_nChw8c, lrn_test_float,
533         ::testing::Values(
534             lrn_test_params_float{ prop_kind::forward_training,
535             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
536             memory::format::nChw8c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
537             , lrn_test_params_float{ prop_kind::forward_scoring,
538             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
539             memory::format::nChw8c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
540             , lrn_test_params_float{ prop_kind::forward_training,
541             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
542             memory::format::nChw8c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
543             , lrn_test_params_float{ prop_kind::forward_scoring,
544             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
545             memory::format::nChw8c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
546             ));
547
548 INSTANTIATE_TEST_CASE_P(
549         TestLRNCaffe_nChw16c, lrn_test_float,
550         ::testing::Values(
551             lrn_test_params_float{ prop_kind::forward_training,
552             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
553             memory::format::nChw16c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
554             , lrn_test_params_float{ prop_kind::forward_scoring,
555             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
556             memory::format::nChw16c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
557             , lrn_test_params_float{ prop_kind::forward_training,
558             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
559             memory::format::nChw16c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
560             , lrn_test_params_float{ prop_kind::forward_scoring,
561             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
562             memory::format::nChw16c, { 2, 96, 55, 55, 1.0f, 0.75f, 1.0f, 3, ACROSS } }
563             ));
564
565 INSTANTIATE_TEST_CASE_P(
566         TestLRNAlexnetNCHW, lrn_test_float,
567         ::testing::Values(
568             lrn_test_params_float{ prop_kind::forward_training,
569             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
570             memory::format::nchw, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
571             , lrn_test_params_float{ prop_kind::forward_scoring,
572             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
573             memory::format::nchw, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
574             , lrn_test_params_float{ prop_kind::forward_training,
575             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
576             memory::format::nchw, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
577             , lrn_test_params_float{ prop_kind::forward_scoring,
578             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
579             memory::format::nchw, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
580             ));
581
582 INSTANTIATE_TEST_CASE_P(
583         TestLRNAlexnetNHWC, lrn_test_float,
584         ::testing::Values(
585                 lrn_test_params_float{ prop_kind::forward_training,
586                 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
587                 memory::format::nhwc, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
588                 lrn_test_params_float{ prop_kind::forward_scoring,
589                 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
590                 memory::format::nhwc, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
591                 lrn_test_params_float{ prop_kind::forward_training,
592                 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
593                 memory::format::nhwc, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
594                 lrn_test_params_float{ prop_kind::forward_scoring,
595                 engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nhwc,
596                 memory::format::nhwc, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
597             ));
598
599 INSTANTIATE_TEST_CASE_P(
600         TestLRNAlexnet_nChw8c, lrn_test_float,
601         ::testing::Values(
602             lrn_test_params_float{ prop_kind::forward_training,
603             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
604             memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
605             lrn_test_params_float{ prop_kind::forward_scoring,
606             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
607             memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
608             lrn_test_params_float{ prop_kind::forward_training,
609             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
610             memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
611             lrn_test_params_float{ prop_kind::forward_scoring,
612             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
613             memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
614             ));
615
616 INSTANTIATE_TEST_CASE_P(
617         TestLRNAlexnet_nChw16c, lrn_test_float,
618         ::testing::Values(
619             lrn_test_params_float{ prop_kind::forward_training,
620             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
621             memory::format::nChw16c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
622             lrn_test_params_float{ prop_kind::forward_scoring,
623             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
624             memory::format::nChw16c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
625             lrn_test_params_float{ prop_kind::forward_training,
626             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
627             memory::format::nChw16c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
628             lrn_test_params_float{ prop_kind::forward_scoring,
629             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
630             memory::format::nChw16c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
631             ));
632
633 INSTANTIATE_TEST_CASE_P(
634         TestLRNGoogleNetV1NCHW, lrn_test_float,
635         ::testing::Values(
636             lrn_test_params_float{ prop_kind::forward_training,
637             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
638             memory::format::nchw, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
639             lrn_test_params_float{ prop_kind::forward_scoring,
640             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
641             memory::format::nchw, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
642             lrn_test_params_float{ prop_kind::forward_training,
643             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
644             memory::format::nchw, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
645             lrn_test_params_float{ prop_kind::forward_scoring,
646             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nchw,
647             memory::format::nchw, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
648             ));
649
650 INSTANTIATE_TEST_CASE_P(
651         TestLRNGoogleNetV1_nChw8c, lrn_test_float,
652         ::testing::Values(
653             lrn_test_params_float{ prop_kind::forward_training,
654             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
655             memory::format::nChw8c, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
656             lrn_test_params_float{ prop_kind::forward_scoring,
657             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
658             memory::format::nChw8c, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
659             lrn_test_params_float{ prop_kind::forward_training,
660             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
661             memory::format::nChw8c, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
662             lrn_test_params_float{ prop_kind::forward_scoring,
663             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw8c,
664             memory::format::nChw8c, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
665             ));
666
667 INSTANTIATE_TEST_CASE_P(
668         TestLRNGoogleNetV1_nChw16c, lrn_test_float,
669         ::testing::Values(
670             lrn_test_params_float{ prop_kind::forward_training,
671             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
672             memory::format::nChw16c, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
673             lrn_test_params_float{ prop_kind::forward_scoring,
674             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
675             memory::format::nChw16c, { 2, 64, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
676             lrn_test_params_float{ prop_kind::forward_training,
677             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
678             memory::format::nChw16c, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } },
679             lrn_test_params_float{ prop_kind::forward_scoring,
680             engine::kind::cpu, algorithm::lrn_across_channels, memory::format::nChw16c,
681             memory::format::nChw16c, { 2, 192, 56, 56, 1.0e-4f, 0.75f, 1.0f, 5, ACROSS } }
682             ));
683
// Backward does not support WITHIN (lrn_within_channel) yet, so these cases
// are disabled. When re-enabling, note that each descriptor must supply all
// nine test_lrn_desc_t fields { mb, c, h, w, alpha, beta, k, local_size, kind }.
// The k = 1.0f value (matching the active cases) is spelled out below. Without
// it, the local size would land in k, WITHIN would land in local_size, and kind
// would silently default to ACROSS.
/*
INSTANTIATE_TEST_CASE_P(
        TestLRNRCNNBlocked, lrn_test_float,
        ::testing::Values(
            lrn_test_params_float{ prop_kind::forward_training,
            engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
            memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 3, WITHIN } }
            , lrn_test_params_float{ prop_kind::forward_scoring,
            engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
            memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 3, WITHIN } }
            , lrn_test_params_float{ prop_kind::forward_training,
            engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
            memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 3, WITHIN } }
            , lrn_test_params_float{ prop_kind::forward_scoring,
            engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
            memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 3, WITHIN } }
            , lrn_test_params_float{ prop_kind::forward_training,
            engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
            memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, WITHIN } }
            , lrn_test_params_float{ prop_kind::forward_scoring,
            engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
            memory::format::nChw8c, { 2, 96, 55, 55, 1.0e-4f, 0.75f, 1.0f, 5, WITHIN } }
            , lrn_test_params_float{ prop_kind::forward_training,
            engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
            memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, WITHIN } }
            , lrn_test_params_float{ prop_kind::forward_scoring,
            engine::kind::cpu, algorithm::lrn_within_channel, memory::format::nChw8c,
            memory::format::nChw8c, { 2, 256, 27, 27, 1.0e-4f, 0.75f, 1.0f, 5, WITHIN } }
            ));
*/
715 }