Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_batch_normalization.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cmath>
18
19 #include "mkldnn_test_common.hpp"
20 #include "gtest/gtest.h"
21
22 #include "mkldnn.hpp"
23
24 namespace mkldnn {
25
26 struct test_bnrm_sizes_t {
27     int mb, c, d, h, w;
28 };
29
30 struct test_bnrm_formats_t {
31     mkldnn::memory::format data_format;
32     mkldnn::memory::format diff_format;
33 };
34
35 struct test_bnrm_params_t {
36     mkldnn::engine::kind engine_kind;
37     test_bnrm_formats_t formats;
38     test_bnrm_sizes_t sizes;
39     float eps;
40     int ndims;
41     bool expect_to_fail;
42     mkldnn_status_t expected_status;
43 };
44
45 template <typename data_t>
46 void check_bnrm_fwd(const test_bnrm_params_t &p,
47         const memory &src, const memory &mean, const memory &variance,
48         const memory &weights, const memory &dst, unsigned flags, prop_kind pk)
49 {
50     const test_bnrm_sizes_t &bp = p.sizes;
51     if (bp.mb * bp.c * bp.d * bp.h * bp.w == 0) return;
52
53     const bool use_weights = flags & use_scale_shift;
54     const bool calculate_stats = !(flags & use_global_stats);
55     const bool is_training = (pk == prop_kind::forward_training);
56
57     const data_t *src_data = (const data_t *)src.get_data_handle();
58     const data_t *weights_data = use_weights ? (const data_t *)weights.get_data_handle() : nullptr;
59     const data_t *mean_data = (!calculate_stats || is_training) ?
60            (const data_t *)mean.get_data_handle() : nullptr;
61     const data_t *variance_data = (!calculate_stats || is_training) ?
62            (const data_t *)variance.get_data_handle() : nullptr;
63     const data_t *dst_data = (data_t *)dst.get_data_handle();
64
65     const memory::desc src_d = src.get_primitive_desc().desc();
66     const memory::desc dst_d = dst.get_primitive_desc().desc();
67
68     data_t eps = static_cast<data_t>(1.e-4 * bp.mb * bp.d * bp.h * bp.w);
69
70     size_t padded_c = src.get_primitive_desc().desc().data.layout_desc
71         .blocking.padding_dims[1];
72
73     mkldnn::impl::parallel_nd(bp.c, [&](int c) {
74         data_t ref_mean = calculate_stats ? data_t(0) : mean_data[c];
75         data_t ref_variance = calculate_stats ? data_t(0) : variance_data[c];
76         if (calculate_stats) {
77             for (int n = 0; n < bp.mb; n++)
78                 for (int d = 0; d < bp.d; d++)
79                 for (int h = 0; h < bp.h; h++)
80                 for (int w = 0; w < bp.w; w++) {
81                     size_t sidx = n * padded_c * bp.d * bp.h * bp.w
82                         + c * bp.d * bp.h * bp.w
83                         + d * bp.h * bp.w + h * bp.w + w;
84                 ref_mean += src_data[map_index(src_d, sidx)];
85             }
86             ref_mean /= bp.mb * bp.d * bp.h * bp.w;
87             if (is_training) {
88                 data_t mean_norm_max = std::max(fabs(mean_data[c]), fabs(ref_mean));
89                 if (mean_norm_max < eps) mean_norm_max = data_t(1);
90                 EXPECT_NEAR((mean_data[c] - ref_mean) / mean_norm_max, 0., eps);
91             }
92
93             for (int n = 0; n < bp.mb; n++)
94             for (int d = 0; d < bp.d; d++)
95             for (int h = 0; h < bp.h; h++)
96                 for (int w = 0; w < bp.w; w++) {
97                     size_t sidx = n * padded_c * bp.d * bp.h * bp.w
98                     + c * bp.d * bp.h * bp.w + d * bp.h * bp.w + h * bp.w + w;
99                     data_t tmp = src_data[map_index(src_d, sidx)] - ref_mean;
100                     ref_variance += tmp * tmp;
101                 }
102             ref_variance /= bp.mb * bp.d * bp.h * bp.w;
103             if (is_training) {
104                 data_t variance_norm_max = std::max(fabs(variance_data[c]), fabs(ref_variance));
105                 if (variance_norm_max < eps) variance_norm_max = data_t(1);
106                 EXPECT_NEAR((variance_data[c] - ref_variance) / variance_norm_max, 0., eps);
107             }
108         }
109         data_t ref_sqrt_variance = static_cast<data_t>(sqrt(ref_variance + p.eps));
110         data_t ref_rsqrt_variance = data_t(1) / (ref_sqrt_variance);
111
112         if (use_weights) {
113             memory::desc weights_d = weights.get_primitive_desc().desc();
114             for (int n = 0; n < bp.mb; n++)
115             for (int d = 0; d < bp.d; d++)
116             for (int h = 0; h < bp.h; h++)
117                 for (int w = 0; w < bp.w; w++) {
118                     size_t sdidx = n * padded_c * bp.d * bp.h * bp.w
119                     + c * bp.d * bp.h * bp.w + d * bp.h * bp.w + h * bp.w + w;
120                     data_t ref_dst = weights_data[map_index(weights_d, c)]
121                             * (src_data[map_index(src_d, sdidx)]
122                             - ref_mean) * ref_rsqrt_variance
123                             + weights_data[map_index(weights_d, bp.c + c)];
124                     data_t out = dst_data[map_index(dst_d, sdidx)];
125                     data_t norm_max = std::max(fabs(out), fabs(ref_dst));
126                     if (norm_max < 10e-3) norm_max = data_t(1);
127                     EXPECT_NEAR((out - ref_dst) / norm_max, 0., eps);
128                 }
129         } else {
130             for (int n = 0; n < bp.mb; n++)
131             for (int d = 0; d < bp.d; d++)
132             for (int h = 0; h < bp.h; h++)
133                 for (int w = 0; w < bp.w; w++) {
134                     size_t sdidx = n * padded_c * bp.d * bp.h * bp.w
135                     + c * bp.d * bp.h * bp.w + d * bp.h * bp.w + h * bp.w + w;
136                     data_t ref_dst = (src_data[map_index(src_d, sdidx)]
137                             - ref_mean) * ref_rsqrt_variance;
138                     data_t out = dst_data[map_index(dst_d, sdidx)];
139                     data_t norm_max = std::max(fabs(out), fabs(ref_dst));
140                     if (norm_max < 10e-3) norm_max = data_t(1);
141                     EXPECT_NEAR((out - ref_dst) / norm_max, 0., eps);
142                 }
143         }
144     });
145 }
146
147 template <typename data_t>
148 void check_bnrm_bwd(const test_bnrm_params_t &p,
149         const memory &src, const memory &diff_dst, const memory &mean,
150         const memory &variance, const memory &weights, const memory &diff_src,
151         const memory &diff_weights, unsigned flags, prop_kind pk)
152 {
153     const test_bnrm_sizes_t &bp = p.sizes;
154     const bool use_weights = flags & use_scale_shift;
155     const bool calculate_diff_stats = !(flags & use_global_stats);
156
157     const data_t *src_data = (const data_t *)src.get_data_handle();
158     const data_t *weights_data = use_weights ? (const data_t *)weights.get_data_handle() : nullptr;
159     const data_t *diff_dst_data = (const data_t *)diff_dst.get_data_handle();
160     const data_t *mean_data = (const data_t *)mean.get_data_handle();
161     const data_t *variance_data = (const data_t *)variance.get_data_handle();
162     const data_t *diff_src_data = (data_t *)diff_src.get_data_handle();
163     const data_t *diff_weights_data = (pk == prop_kind::backward) ?
164             (data_t *)diff_weights.get_data_handle() : nullptr;
165
166     const memory::desc src_d = src.get_primitive_desc().desc();
167     const memory::desc diff_dst_d = diff_dst.get_primitive_desc().desc();
168     const memory::desc weights_d = weights.get_primitive_desc().desc();
169     const memory::desc diff_src_d = diff_src.get_primitive_desc().desc();
170     const memory::desc diff_weights_d = diff_weights.get_primitive_desc().desc();
171
172     if (bp.mb * bp.c * bp.d * bp.h * bp.w == 0) {
173         if (pk == backward) {
174             for (int c = 0; c < bp.c; ++c) {
175                auto dg = diff_weights_data[map_index(diff_weights_d, c)];
176                auto db = diff_weights_data[map_index(diff_weights_d, bp.c + c)];
177                EXPECT_NEAR(dg, 0., 1e-7);
178                EXPECT_NEAR(db, 0., 1e-7);
179             }
180         }
181         return;
182     }
183
184     const data_t eps = static_cast<data_t>(1.e-4 * bp.mb * bp.d * bp.h * bp.w);
185
186     size_t padded_c = src.get_primitive_desc().desc().data.layout_desc.blocking.padding_dims[1];
187     mkldnn::impl::parallel_nd(bp.c, [&](int c) {
188         data_t ref_diff_gamma = data_t(0);
189         data_t ref_diff_beta = data_t(0);
190
191         auto v_mean = mean_data[c];
192         auto v_variance = variance_data[c];
193         const data_t sqrt_variance = data_t(1.0 / sqrt(v_variance + p.eps));
194
195         auto gamma = use_weights ? weights_data[map_index(weights_d, c)] : 1;
196
197         for (int n = 0; n < bp.mb; n++)
198         for (int d = 0; d < bp.d; d++)
199         for (int h = 0; h < bp.h; h++)
200         for (int w = 0; w < bp.w; w++) {
201             size_t sidx = n * padded_c * bp.d * bp.h * bp.w + c * bp.d * bp.h * bp.w
202                     + d * bp.h * bp.w + h * bp.w + w;
203             ref_diff_gamma += (src_data[map_index(src_d, sidx)] - v_mean)
204                 * diff_dst_data[map_index(diff_dst_d, sidx)];
205             ref_diff_beta += diff_dst_data[map_index(diff_dst_d, sidx)];
206         }
207         ref_diff_gamma *= sqrt_variance;
208
209         if (pk == backward) {
210             auto diff_gamma = diff_weights_data[map_index(diff_weights_d, c)];
211             data_t norm_max = std::max(fabs(diff_gamma), fabs(ref_diff_gamma));
212             if (norm_max < 10e-3) norm_max = data_t(1);
213             EXPECT_NEAR((diff_gamma - ref_diff_gamma) / norm_max, 0., eps);
214
215             auto diff_beta = diff_weights_data[map_index(diff_weights_d, bp.c + c)];
216             norm_max = std::max(fabs(diff_beta), fabs(ref_diff_beta));
217             if (norm_max < 10e-3) norm_max = data_t(1);
218             EXPECT_NEAR((diff_beta - ref_diff_beta) / norm_max, 0., eps);
219         }
220
221         for (int n = 0; n < bp.mb; n++)
222         for (int d = 0; d < bp.d; d++)
223         for (int h = 0; h < bp.h; h++)
224             for (int w = 0; w < bp.w; w++) {
225                 size_t sidx = n * padded_c * bp.d * bp.h * bp.w
226                     + c * bp.d * bp.h * bp.w + d * bp.h * bp.w + h * bp.w + w;
227                 data_t ref_diff_src = diff_dst_data[map_index(diff_dst_d, sidx)];
228                 if (calculate_diff_stats) {
229                         ref_diff_src -= ref_diff_beta/(bp.mb*bp.d*bp.h*bp.w)
230                         + (src_data[map_index(src_d, sidx)] - v_mean)
231                         *ref_diff_gamma*sqrt_variance/(bp.mb*bp.d*bp.h*bp.w);
232                 }
233                 ref_diff_src *= gamma*sqrt_variance;
234                 data_t out_diff_src = diff_src_data[map_index(diff_src_d, sidx)];
235                 data_t norm_max = std::max(fabs(out_diff_src), fabs(ref_diff_src));
236                 if (norm_max < eps) norm_max = data_t(1);
237                 EXPECT_NEAR((out_diff_src - ref_diff_src) / norm_max, 0., eps);
238             }
239     });
240 }
241
242 template <typename data_t>
243 class bnrm_test : public ::testing::TestWithParam<test_bnrm_params_t> {
244 private:
245     std::shared_ptr<test_memory> src;
246     std::shared_ptr<test_memory> dst;
247     std::shared_ptr<test_memory> diff_src;
248     std::shared_ptr<test_memory> diff_dst;
249     std::shared_ptr<memory> weights;
250     std::shared_ptr<memory> diff_weights;
251     std::shared_ptr<memory> mean;
252     std::shared_ptr<memory> variance;
253     std::shared_ptr<memory::desc> data_desc;
254     std::shared_ptr<memory::desc> diff_desc;
255     std::shared_ptr<batch_normalization_forward::primitive_desc> bnrm_prim_desc;
256     std::shared_ptr<batch_normalization_backward::primitive_desc>
257         bnrm_bwd_prim_desc;
258     test_bnrm_params_t p;
259     std::shared_ptr<engine> eng;
260     memory::data_type data_type;
261
262 protected:
263     virtual void SetUp() {
264         p = ::testing::TestWithParam<decltype(p)>::GetParam();
265         catch_expected_failures([=](){Test();}, p.expect_to_fail,
266                     p.expected_status);
267     }
268
269     void Test() {
270         p = ::testing::TestWithParam<decltype(p)>::GetParam();
271
272         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
273         eng.reset(new engine(p.engine_kind, 0));
274         memory::data_type data_type = data_traits<data_t>::data_type;
275         ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
276
277         test_bnrm_sizes_t bs = p.sizes;
278         bool has_spatial = (p.formats.data_format != mkldnn_nc);
279         if (has_spatial)
280         {
281             if (p.ndims == 5)
282             {
283                 data_desc.reset(new memory::desc({ bs.mb, bs.c, bs.d, bs.h, bs.w },
284                     data_type, p.formats.data_format));
285                 diff_desc.reset(new memory::desc({ bs.mb, bs.c, bs.d, bs.h, bs.w },
286                     data_type, p.formats.diff_format));
287             } else {
288                 data_desc.reset(new memory::desc({ bs.mb, bs.c, bs.h, bs.w },
289                     data_type, p.formats.data_format));
290                 diff_desc.reset(new memory::desc({ bs.mb, bs.c, bs.h, bs.w },
291                     data_type, p.formats.diff_format));
292             }
293         }
294         else {
295             data_desc.reset(new memory::desc({ bs.mb, bs.c },
296                 data_type, p.formats.data_format));
297             diff_desc.reset(new memory::desc({ bs.mb, bs.c },
298                 data_type, p.formats.diff_format));
299         }
300
301         src.reset(new test_memory(*data_desc, *eng));
302         dst.reset(new test_memory(*data_desc, *eng));
303         diff_src.reset(new test_memory(*diff_desc, *eng));
304         diff_dst.reset(new test_memory(*diff_desc, *eng));
305
306         auto training = prop_kind::forward_training;
307         auto scoring = prop_kind::forward_scoring;
308
309
310         Forward(0u, scoring);
311         Forward(0u, training);
312         Forward(use_global_stats, training);
313         Forward(use_global_stats, scoring);
314         Forward(use_scale_shift, scoring);
315         Forward(use_scale_shift, training);
316         Forward(use_scale_shift | use_global_stats, training);
317
318         Backward(0u, backward_data);
319         Backward(use_global_stats, backward_data);
320         Backward(use_scale_shift, backward);
321         Backward(use_scale_shift, backward_data);
322         Backward(use_scale_shift | use_global_stats, backward);
323         Backward(use_scale_shift | use_global_stats, backward_data);
324
325     }
326
327     void Forward(unsigned flags, prop_kind pk) {
328         bool useScaleShift = flags & use_scale_shift;
329         bool useGlobalStats = flags & use_global_stats;
330         bool isTraining = pk == prop_kind::forward_training;
331
332         auto bnrm_desc = batch_normalization_forward::desc(pk,
333                     *data_desc, p.eps, flags);
334
335         bnrm_prim_desc.reset(new batch_normalization_forward::primitive_desc(
336                     bnrm_desc, *eng));
337
338         weights.reset(new memory(bnrm_prim_desc->weights_primitive_desc()));
339         if (isTraining || useGlobalStats) {
340             mean.reset(new memory(bnrm_prim_desc->mean_primitive_desc()));
341             variance.reset(
342                     new memory(bnrm_prim_desc->variance_primitive_desc()));
343         }
344
345         fill(src->get());
346         fill(dst->get());
347         if (useScaleShift) fill(*weights);
348         if (useGlobalStats) {
349             fill(*mean);
350             fill(*variance);
351         }
352         check_zero_tail<data_t>(1, src->get());
353         check_zero_tail<data_t>(1, dst->get());
354
355         auto bn = createBnrmFwd(isTraining, useGlobalStats, useScaleShift);
356
357         std::vector<primitive> pipeline;
358         pipeline.push_back(bn);
359         stream(stream::kind::lazy).submit(pipeline).wait();
360
361         check_zero_tail<data_t>(0, dst->get());
362
363         check_bnrm_fwd<data_t>(p, src->get(), *mean, *variance, *weights,
364                 dst->get(), flags, pk);
365
366     }
367
368     void Backward(unsigned flags, prop_kind pk) {
369         bool useScaleShift = flags & use_scale_shift;
370
371         auto bnrm_bwd_desc = batch_normalization_backward::desc(
372                 pk, *diff_desc, *data_desc, p.eps, flags);
373
374         bnrm_bwd_prim_desc.reset(
375                 new batch_normalization_backward::primitive_desc(
376                 bnrm_bwd_desc, *eng, *bnrm_prim_desc));
377
378         if (useScaleShift) weights.reset(new memory(
379                     bnrm_bwd_prim_desc->weights_primitive_desc()));
380         diff_weights.reset(new memory(bnrm_bwd_prim_desc->diff_weights_primitive_desc()));
381         mean.reset(new memory(bnrm_bwd_prim_desc->mean_primitive_desc()));
382         variance.reset(new memory(
383                     bnrm_bwd_prim_desc->variance_primitive_desc()));
384
385         if (useScaleShift) fill(*weights);
386         fill(diff_src->get());
387         fill(diff_dst->get());
388         fill(*mean);
389         fill(*variance);
390         check_zero_tail<data_t>(1, diff_src->get());
391         check_zero_tail<data_t>(1, diff_dst->get());
392
393         auto bnrm_bwd = createBnrmBwd(useScaleShift, pk);
394
395         std::vector<primitive> pipeline;
396         pipeline.push_back(bnrm_bwd);
397         stream(stream::kind::lazy).submit(pipeline).wait();
398
399         check_bnrm_bwd<data_t>(p,
400                 src->get(), diff_dst->get(), *mean, *variance, *weights,
401                 diff_src->get(), *diff_weights, flags, pk);
402         check_zero_tail<data_t>(0, diff_src->get());
403     }
404
405     void fill(memory &m, data_t mean = 1.) {
406         fill_data<data_t>(m.get_primitive_desc().get_size() / sizeof(data_t),
407                 reinterpret_cast<data_t *>(m.get_data_handle()));
408     }
409
410     primitive createBnrmFwd(bool isTraining, bool useGlobalStats,
411             bool useScaleShift)
412     {
413         if (!isTraining && !useGlobalStats) {
414             return useScaleShift
415                 ? batch_normalization_forward(*bnrm_prim_desc,
416                     src->get(), *weights, dst->get())
417                 : batch_normalization_forward(*bnrm_prim_desc, src->get(),
418                         dst->get());
419         } else {
420             if (useGlobalStats) {
421                 return useScaleShift
422                     ? batch_normalization_forward(*bnrm_prim_desc,
423                         src->get(), (const primitive::at)*mean,
424                         (const primitive::at)*variance, *weights, dst->get())
425                     : batch_normalization_forward(*bnrm_prim_desc,
426                         src->get(), (const primitive::at)*mean,
427                         (const primitive::at)*variance, dst->get());
428             } else {
429                 return useScaleShift
430                     ? batch_normalization_forward(*bnrm_prim_desc,
431                         src->get(), *weights, dst->get(), *mean, *variance)
432                     : batch_normalization_forward(*bnrm_prim_desc,
433                         src->get(), dst->get(), *mean, *variance);
434             }
435         }
436     }
437
438     primitive createBnrmBwd(bool useScaleShift, prop_kind pk)
439     {
440         if (useScaleShift) {
441             return pk == prop_kind::backward_data
442                 ? batch_normalization_backward(*bnrm_bwd_prim_desc,
443                     src->get(), *mean, *variance, diff_dst->get(), *weights,
444                     diff_src->get())
445                 : batch_normalization_backward(*bnrm_bwd_prim_desc,
446                     src->get(), *mean, *variance, diff_dst->get(), *weights,
447                     diff_src->get(), *diff_weights);
448         } else {
449             return batch_normalization_backward(*bnrm_bwd_prim_desc, src->get(),
450                     *mean, *variance, diff_dst->get(), diff_src->get());
451         }
452     }
453 };
454
455 using bnrm_test_float = bnrm_test<float>;
456
457 #define EXPAND_ARGS(args) args
458 TEST_P(bnrm_test_float, TestsBnrm)
459 {
460 }
461
462 #define EXPAND_SIZES_3D(...) { __VA_ARGS__ }
463 #define EXPAND_SIZES_2D(mb, c, h, w) { mb, c, 1, h, w }
464 #define EXPAND_FORMATS(data, diff) \
465     { memory::format::data, memory::format::diff }
466
467 #define ENGINE engine::kind::cpu
468 #define EPS 1e-5f
469
470 #define PARAMS(data, diff, mb, c, h, w, eps, ef, st) \
471     test_bnrm_params_t { ENGINE, EXPAND_FORMATS(data, diff), \
472         EXPAND_SIZES_2D(mb, c, h, w), eps, 4, ef, st }
473
474 #define PARAMS_3D(data, diff, mb, c, d, h, w, eps, ef, st) \
475     test_bnrm_params_t { ENGINE, EXPAND_FORMATS(data, diff), \
476         EXPAND_SIZES_3D(mb, c, d, h, w), eps, 5, ef, st }
477
478 #define PARAMS_N_3D(...) EXPAND_ARGS(PARAMS_3D(ncdhw, ncdhw, __VA_ARGS__, false, mkldnn_success))
479 #define PARAMS_B8_3D(...) EXPAND_ARGS(PARAMS_3D(nCdhw8c, nCdhw8c, __VA_ARGS__, false, mkldnn_success))
480 #define PARAMS_B16_3D(...) EXPAND_ARGS(PARAMS_3D(nCdhw16c, nCdhw16c, __VA_ARGS__, false, mkldnn_success))
481 #define PARAMS_N(...) EXPAND_ARGS(PARAMS(nchw, nchw, __VA_ARGS__, false, mkldnn_success))
482 #define PARAMS_NHWC(...) EXPAND_ARGS(PARAMS(nhwc, nhwc, __VA_ARGS__, false, mkldnn_success))
483 #define PARAMS_NC(...) EXPAND_ARGS(PARAMS(nc, nc, __VA_ARGS__, false, mkldnn_success))
484 #define PARAMS_B8(...) EXPAND_ARGS(PARAMS(nChw8c, nChw8c, __VA_ARGS__, false, mkldnn_success))
485 #define PARAMS_B16(...) EXPAND_ARGS(PARAMS(nChw16c, nChw16c, __VA_ARGS__, false, mkldnn_success))
486 #define PARAMS_EF(...) EXPAND_ARGS(PARAMS(nchw, nchw, __VA_ARGS__))
487
488 #define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
489         str, bnrm_test_float, ::testing::Values(__VA_ARGS__))
490
491 INST_TEST_CASE(SimpleZeroDim,
492     PARAMS_N(0, 27, 9, 10, EPS),
493     PARAMS_N(1, 0, 10, 9, EPS),
494     PARAMS_N(4, 20, 0, 12, EPS)
495 );
496
497 INST_TEST_CASE(SimpleExpectedFails,
498     PARAMS_EF(-1, 27, 9, 10, EPS, true, mkldnn_invalid_arguments),
499     PARAMS_EF(1, -12, 10, 9, EPS, true, mkldnn_invalid_arguments),
500     PARAMS_EF(4, 20, -12, 12, EPS, true, mkldnn_invalid_arguments)
501 );
502
503 INST_TEST_CASE(Simple_nChw16c_padded,
504     PARAMS_B16(1, 27, 9, 10, EPS),
505     PARAMS_B16(1, 12, 10, 9, EPS),
506     PARAMS_B16(4, 20, 12, 12, EPS),
507     PARAMS_B16(4, 9, 16, 16, EPS)
508 );
509
510 INST_TEST_CASE(Simple_nCdhw16c_padded,
511     PARAMS_B16_3D(2, 12, 16, 8, 20, EPS),
512     PARAMS_B16_3D(2, 9, 16, 8, 20, EPS),
513     PARAMS_B16_3D(2, 23, 10, 8, 4, EPS),
514     PARAMS_B16_3D(2, 27, 10, 8, 4, EPS)
515 );
516
517 INST_TEST_CASE(Simple_nChw8c_padded,
518     PARAMS_B8(1, 27, 9, 10, EPS),
519     PARAMS_B8(1, 12, 10, 9, EPS),
520     PARAMS_B8(4, 20, 12, 12, EPS),
521     PARAMS_B8(4, 7, 16, 16, EPS)
522 );
523
524
525 INST_TEST_CASE(Simple_nCdhw16c,
526     PARAMS_B16_3D(2, 32, 4, 4, 4, EPS),
527     PARAMS_B16_3D(2, 32, 4, 4, 4, EPS),
528     PARAMS_B16_3D(2, 32, 8, 8, 8, EPS),
529     PARAMS_B16_3D(2, 32, 8, 8, 8, EPS),
530     PARAMS_B16_3D(2, 32, 16, 8, 20, EPS),
531     PARAMS_B16_3D(2, 32, 16, 8, 20, EPS),
532     PARAMS_B16_3D(2, 32, 10, 8, 4, EPS),
533     PARAMS_B16_3D(2, 32, 10, 8, 4, EPS)
534 );
535
536 INST_TEST_CASE(Simple_nCdhw8c,
537     PARAMS_B8_3D(2, 32, 4, 4, 4, EPS),
538     PARAMS_B8_3D(2, 32, 4, 4, 4, EPS),
539     PARAMS_B8_3D(2, 32, 8, 8, 8, EPS),
540     PARAMS_B8_3D(2, 32, 8, 8, 8, EPS),
541     PARAMS_B8_3D(2, 32, 16, 8, 20, EPS),
542     PARAMS_B8_3D(2, 32, 16, 8, 20, EPS),
543     PARAMS_B8_3D(2, 32, 10, 8, 4, EPS),
544     PARAMS_B8_3D(2, 32, 10, 8, 4, EPS)
545 );
546
547 INST_TEST_CASE(Simple_NC,
548     PARAMS_NC(2, 8, 1, 1, EPS),
549     PARAMS_NC(2, 10, 1, 1, EPS),
550     PARAMS_NC(2, 8, 1, 1, EPS),
551     PARAMS_NC(2, 10, 1, 1, EPS)
552 );
553
554 INST_TEST_CASE(Simple_NCDHW,
555     PARAMS_N_3D(2, 8, 1, 1, 1, EPS),
556     PARAMS_N_3D(2, 10, 1, 1, 1, EPS),
557     PARAMS_N_3D(2, 8, 4, 4, 4, EPS),
558     PARAMS_N_3D(2, 10, 4, 4, 4, EPS)
559 );
560
561 INST_TEST_CASE(Simple_NCHW,
562     PARAMS_N(2, 8, 1, 1, EPS),
563     PARAMS_N(2, 10, 1, 1, EPS),
564     PARAMS_N(2, 8, 4, 4, EPS),
565     PARAMS_N(2, 10, 4, 4, EPS)
566 );
567
568 INST_TEST_CASE(Simple_NHWC,
569     PARAMS_NHWC(2, 8, 1, 1, EPS),
570     PARAMS_NHWC(2, 10, 1, 1, EPS),
571     PARAMS_NHWC(2, 8, 4, 4, EPS),
572     PARAMS_NHWC(2, 10, 4, 4, EPS)
573 );
574
575 INST_TEST_CASE(Simple_Blocked,
576     PARAMS_B8(2, 8, 1, 1, EPS),
577     PARAMS_B8(2, 8, 4, 4, EPS),
578     PARAMS_B8(2, 8, 6, 6, EPS),
579     PARAMS_B8(2, 16, 4, 4, EPS),
580     PARAMS_B8(2, 16, 4, 4, EPS),
581     PARAMS_B8(2, 16, 8, 8, EPS),
582     PARAMS_B8(2, 16, 8, 8, EPS),
583     PARAMS_B8(2, 16, 16, 8, EPS),
584     PARAMS_B8(2, 16, 16, 8, EPS),
585     PARAMS_B8(2, 16, 10, 8, EPS),
586     PARAMS_B8(2, 16, 10, 8, EPS),
587     PARAMS_B16(2, 16, 4, 4, EPS),
588     PARAMS_B16(2, 16, 4, 4, EPS),
589     PARAMS_B16(2, 16, 8, 8, EPS),
590     PARAMS_B16(2, 16, 8, 8, EPS),
591     PARAMS_B16(2, 16, 16, 8, EPS),
592     PARAMS_B16(2, 16, 16, 8, EPS),
593     PARAMS_B16(2, 16, 10, 8, EPS),
594     PARAMS_B16(2, 16, 10, 8, EPS)
595 );
596
597 INST_TEST_CASE(GoogleNet_NCHW,
598     PARAMS_N(2, 64, 112, 112, EPS),
599     PARAMS_N(2, 64, 56, 56, EPS),
600     PARAMS_N(2, 192, 56, 56, EPS),
601     PARAMS_N(2, 96, 28, 28, EPS),
602     PARAMS_N(2, 16, 28, 28, EPS),
603     PARAMS_N(2, 64, 28, 28, EPS),
604     PARAMS_N(2, 128, 28, 28, EPS),
605     PARAMS_N(2, 32, 28, 28, EPS),
606     PARAMS_N(2, 96, 28, 28, EPS),
607     PARAMS_N(2, 96, 14, 14, EPS),
608     PARAMS_N(2, 16, 14, 14, EPS),
609     PARAMS_N(2, 192, 14, 14, EPS),
610     PARAMS_N(2, 208, 14, 14, EPS),
611     PARAMS_N(2, 48, 14, 14, EPS),
612     PARAMS_N(2, 64, 14, 14, EPS),
613     PARAMS_N(2, 112, 14, 14, EPS),
614     PARAMS_N(2, 24, 14, 14, EPS),
615     PARAMS_N(2, 160, 14, 14, EPS),
616     PARAMS_N(2, 224, 14, 14, EPS),
617     PARAMS_N(2, 128, 4, 4, EPS),
618     PARAMS_N(2, 128, 14, 14, EPS),
619     PARAMS_N(2, 512, 14, 14, EPS),
620     PARAMS_N(2, 256, 14, 14, EPS),
621     PARAMS_N(2, 144, 14, 14, EPS),
622     PARAMS_N(2, 32, 14, 14, EPS),
623     PARAMS_N(2, 228, 14, 14, EPS),
624     PARAMS_N(2, 528, 14, 14, EPS),
625     PARAMS_N(2, 320, 14, 14, EPS),
626     PARAMS_N(2, 160, 7, 7, EPS),
627     PARAMS_N(2, 32, 7, 7, EPS),
628     PARAMS_N(2, 256, 7, 7, EPS),
629     PARAMS_N(2, 320, 7, 7, EPS),
630     PARAMS_N(2, 128, 7, 7, EPS),
631     PARAMS_N(2, 192, 7, 7, EPS),
632     PARAMS_N(2, 48, 7, 7, EPS),
633     PARAMS_N(2, 384, 7, 7, EPS)
634 );
635
636 INST_TEST_CASE(GoogleNet_Blocked_8,
637     PARAMS_B8(2, 64, 112, 112, EPS),
638     PARAMS_B8(2, 64, 56, 56, EPS),
639     PARAMS_B8(2, 192, 56, 56, EPS),
640     PARAMS_B8(2, 96, 28, 28, EPS),
641     PARAMS_B8(2, 16, 28, 28, EPS),
642     PARAMS_B8(2, 64, 28, 28, EPS),
643     PARAMS_B8(2, 128, 28, 28, EPS),
644     PARAMS_B8(2, 32, 28, 28, EPS),
645     PARAMS_B8(2, 96, 28, 28, EPS),
646     PARAMS_B8(2, 96, 14, 14, EPS),
647     PARAMS_B8(2, 16, 14, 14, EPS),
648     PARAMS_B8(2, 192, 14, 14, EPS),
649     PARAMS_B8(2, 208, 14, 14, EPS),
650     PARAMS_B8(2, 48, 14, 14, EPS),
651     PARAMS_B8(2, 64, 14, 14, EPS),
652     PARAMS_B8(2, 112, 14, 14, EPS),
653     PARAMS_B8(2, 24, 14, 14, EPS),
654     PARAMS_B8(2, 160, 14, 14, EPS),
655     PARAMS_B8(2, 224, 14, 14, EPS),
656     PARAMS_B8(2, 128, 4, 4, EPS),
657     PARAMS_B8(2, 128, 14, 14, EPS),
658     PARAMS_B8(2, 512, 14, 14, EPS),
659     PARAMS_B8(2, 256, 14, 14, EPS),
660     PARAMS_B8(2, 144, 14, 14, EPS),
661     PARAMS_B8(2, 32, 14, 14, EPS),
662     PARAMS_B8(2, 528, 14, 14, EPS),
663     PARAMS_B8(2, 320, 14, 14, EPS),
664     PARAMS_B8(2, 160, 7, 7, EPS),
665     PARAMS_B8(2, 32, 7, 7, EPS),
666     PARAMS_B8(2, 256, 7, 7, EPS),
667     PARAMS_B8(2, 320, 7, 7, EPS),
668     PARAMS_B8(2, 128, 7, 7, EPS),
669     PARAMS_B8(2, 192, 7, 7, EPS),
670     PARAMS_B8(2, 48, 7, 7, EPS),
671     PARAMS_B8(2, 384, 7, 7, EPS)
672 );
673
674 INST_TEST_CASE(GoogleNet_Blocked_16,
675     PARAMS_B16(2, 64, 112, 112, EPS),
676     PARAMS_B16(2, 64, 56, 56, EPS),
677     PARAMS_B16(2, 192, 56, 56, EPS),
678     PARAMS_B16(2, 96, 28, 28, EPS),
679     PARAMS_B16(2, 16, 28, 28, EPS),
680     PARAMS_B16(2, 64, 28, 28, EPS),
681     PARAMS_B16(2, 128, 28, 28, EPS),
682     PARAMS_B16(2, 32, 28, 28, EPS),
683     PARAMS_B16(2, 96, 28, 28, EPS),
684     PARAMS_B16(2, 96, 14, 14, EPS),
685     PARAMS_B16(2, 16, 14, 14, EPS),
686     PARAMS_B16(2, 192, 14, 14, EPS),
687     PARAMS_B16(2, 208, 14, 14, EPS),
688     PARAMS_B16(2, 48, 14, 14, EPS),
689     PARAMS_B16(2, 64, 14, 14, EPS),
690     PARAMS_B16(2, 112, 14, 14, EPS),
691     //PARAMS_B16(2, 24, 14, 14, EPS),
692     PARAMS_B16(2, 160, 14, 14, EPS),
693     PARAMS_B16(2, 224, 14, 14, EPS),
694     PARAMS_B16(2, 128, 4, 4, EPS),
695     PARAMS_B16(2, 128, 14, 14, EPS),
696     PARAMS_B16(2, 512, 14, 14, EPS),
697     PARAMS_B16(2, 256, 14, 14, EPS),
698     PARAMS_B16(2, 144, 14, 14, EPS),
699     PARAMS_B16(2, 32, 14, 14, EPS),
700     PARAMS_B16(2, 528, 14, 14, EPS),
701     PARAMS_B16(2, 320, 14, 14, EPS),
702     PARAMS_B16(2, 160, 7, 7, EPS),
703     PARAMS_B16(2, 32, 7, 7, EPS),
704     PARAMS_B16(2, 256, 7, 7, EPS),
705     PARAMS_B16(2, 320, 7, 7, EPS),
706     PARAMS_B16(2, 128, 7, 7, EPS),
707     PARAMS_B16(2, 192, 7, 7, EPS),
708     PARAMS_B16(2, 48, 7, 7, EPS),
709     PARAMS_B16(2, 384, 7, 7, EPS)
710 );
711
712 }