1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "mkldnn_test_common.hpp"
20 #include "gtest/gtest.h"
// Problem sizes for one batch-normalization test case.
// The member declarations are not visible in this excerpt; elsewhere in the
// file the struct is read through the fields mb, c, d, h and w.
26 struct test_bnrm_sizes_t {
// Memory formats used by one batch-normalization test case.
30 struct test_bnrm_formats_t {
// Format used by bnrm_test to build the data memory descriptor (src/dst).
31 mkldnn::memory::format data_format;
// Format used by bnrm_test to build the diff memory descriptor
// (diff_src/diff_dst).
32 mkldnn::memory::format diff_format;
// Full parameter set of one value-parameterized batch-normalization test.
// NOTE(review): the PARAMS/PARAMS_3D macros brace-initialize this struct with
// seven values, so apparently three members are declared on lines not visible
// here; the code reads two of them as p.eps and p.expect_to_fail.
35 struct test_bnrm_params_t {
// Engine kind to run on; bnrm_test asserts that it is engine::kind::cpu.
36 mkldnn::engine::kind engine_kind;
// Data/diff memory formats.
37 test_bnrm_formats_t formats;
// Tensor dimensions (mb, c, d, h, w).
38 test_bnrm_sizes_t sizes;
// Presumably the status the case is expected to end with: the PARAMS* macros
// pass mkldnn_success, or mkldnn_invalid_arguments for the expected-failure
// cases, as their last initializer -- TODO confirm it maps to this member.
42 mkldnn_status_t expected_status;
// Checks the output of a forward batch-normalization primitive against a
// scalar reference that is computed independently for every channel.
//
// Parameters:
//   p              - test parameters; p.sizes gives mb/c/d/h/w and p.eps is
//                    added to the variance before taking the square root.
//   src, dst       - input and output memories of the primitive.
//   mean, variance - statistics memories; their data handles are read only
//                    when use_global_stats is set or pk is forward_training.
//   weights        - scale/shift memory; read only when use_scale_shift is
//                    set. Element c is used as the scale and element
//                    bp.c + c as the shift of channel c.
//   flags          - batch-normalization flags (use_scale_shift,
//                    use_global_stats).
//   pk             - propagation kind the primitive was created with.
//
// Mismatches are reported through EXPECT_NEAR; nothing is returned.
45 template <typename data_t>
46 void check_bnrm_fwd(const test_bnrm_params_t &p,
47 const memory &src, const memory &mean, const memory &variance,
48 const memory &weights, const memory &dst, unsigned flags, prop_kind pk)
50 const test_bnrm_sizes_t &bp = p.sizes;
// Nothing to verify when any dimension is zero.
51 if (bp.mb * bp.c * bp.d * bp.h * bp.w == 0) return;
53 const bool use_weights = flags & use_scale_shift;
// Statistics are recomputed from src unless global statistics were supplied.
54 const bool calculate_stats = !(flags & use_global_stats);
55 const bool is_training = (pk == prop_kind::forward_training);
57 const data_t *src_data = (const data_t *)src.get_data_handle();
58 const data_t *weights_data = use_weights ? (const data_t *)weights.get_data_handle() : nullptr;
// mean/variance pointers stay null when statistics are computed internally
// and the primitive is not in training mode.
59 const data_t *mean_data = (!calculate_stats || is_training) ?
60 (const data_t *)mean.get_data_handle() : nullptr;
61 const data_t *variance_data = (!calculate_stats || is_training) ?
62 (const data_t *)variance.get_data_handle() : nullptr;
63 const data_t *dst_data = (data_t *)dst.get_data_handle();
65 const memory::desc src_d = src.get_primitive_desc().desc();
66 const memory::desc dst_d = dst.get_primitive_desc().desc();
// Comparison tolerance, scaled by the number of elements reduced per channel.
68 data_t eps = static_cast<data_t>(1.e-4 * bp.mb * bp.d * bp.h * bp.w);
// Channel count taken from the src descriptor's blocking.padding_dims[1];
// it is used (instead of bp.c) as the per-image channel count when forming
// the logical offsets that are passed to map_index.
70 size_t padded_c = src.get_primitive_desc().desc().data.layout_desc
71 .blocking.padding_dims[1];
// Each channel is verified independently; channels are distributed by
// mkldnn::impl::parallel_nd.
73 mkldnn::impl::parallel_nd(bp.c, [&](int c) {
// With global statistics the reference mean/variance are the supplied ones;
// otherwise they start at zero and are accumulated below.
74 data_t ref_mean = calculate_stats ? data_t(0) : mean_data[c];
75 data_t ref_variance = calculate_stats ? data_t(0) : variance_data[c];
76 if (calculate_stats) {
// Reference mean: average of all src elements of channel c.
77 for (int n = 0; n < bp.mb; n++)
78 for (int d = 0; d < bp.d; d++)
79 for (int h = 0; h < bp.h; h++)
80 for (int w = 0; w < bp.w; w++) {
81 size_t sidx = n * padded_c * bp.d * bp.h * bp.w
82 + c * bp.d * bp.h * bp.w
83 + d * bp.h * bp.w + h * bp.w + w;
84 ref_mean += src_data[map_index(src_d, sidx)];
86 ref_mean /= bp.mb * bp.d * bp.h * bp.w;
// Relative comparison of the primitive's mean with the reference; falls back
// to an absolute comparison (divisor 1) when both values are below eps.
// NOTE(review): mean_data is null unless is_training here, so this check is
// presumably guarded by an is_training test on a line not visible -- confirm.
88 data_t mean_norm_max = std::max(fabs(mean_data[c]), fabs(ref_mean));
89 if (mean_norm_max < eps) mean_norm_max = data_t(1);
90 EXPECT_NEAR((mean_data[c] - ref_mean) / mean_norm_max, 0., eps);
// Reference variance: mean of squared deviations from ref_mean (divides by
// the full element count mb*d*h*w).
93 for (int n = 0; n < bp.mb; n++)
94 for (int d = 0; d < bp.d; d++)
95 for (int h = 0; h < bp.h; h++)
96 for (int w = 0; w < bp.w; w++) {
97 size_t sidx = n * padded_c * bp.d * bp.h * bp.w
98 + c * bp.d * bp.h * bp.w + d * bp.h * bp.w + h * bp.w + w;
99 data_t tmp = src_data[map_index(src_d, sidx)] - ref_mean;
100 ref_variance += tmp * tmp;
102 ref_variance /= bp.mb * bp.d * bp.h * bp.w;
// Same relative/absolute comparison as for the mean (presumably under the
// same guard -- confirm).
104 data_t variance_norm_max = std::max(fabs(variance_data[c]), fabs(ref_variance));
105 if (variance_norm_max < eps) variance_norm_max = data_t(1);
106 EXPECT_NEAR((variance_data[c] - ref_variance) / variance_norm_max, 0., eps);
// 1 / sqrt(variance + eps) used to normalize every element of the channel.
109 data_t ref_sqrt_variance = static_cast<data_t>(sqrt(ref_variance + p.eps));
110 data_t ref_rsqrt_variance = data_t(1) / (ref_sqrt_variance);
// Scale/shift variant: dst = scale * (src - mean) * rsqrt + shift.
// NOTE(review): weights_data is null without use_scale_shift, so this loop is
// presumably guarded by use_weights on a line not visible -- confirm.
113 memory::desc weights_d = weights.get_primitive_desc().desc();
114 for (int n = 0; n < bp.mb; n++)
115 for (int d = 0; d < bp.d; d++)
116 for (int h = 0; h < bp.h; h++)
117 for (int w = 0; w < bp.w; w++) {
118 size_t sdidx = n * padded_c * bp.d * bp.h * bp.w
119 + c * bp.d * bp.h * bp.w + d * bp.h * bp.w + h * bp.w + w;
120 data_t ref_dst = weights_data[map_index(weights_d, c)]
121 * (src_data[map_index(src_d, sdidx)]
122 - ref_mean) * ref_rsqrt_variance
123 + weights_data[map_index(weights_d, bp.c + c)];
124 data_t out = dst_data[map_index(dst_d, sdidx)];
// Relative error; absolute error when both values are below 10e-3 (= 0.01).
// NOTE(review): the statistics checks above use eps as this threshold.
125 data_t norm_max = std::max(fabs(out), fabs(ref_dst));
126 if (norm_max < 10e-3) norm_max = data_t(1);
127 EXPECT_NEAR((out - ref_dst) / norm_max, 0., eps);
// Variant without scale/shift: dst = (src - mean) * rsqrt.
130 for (int n = 0; n < bp.mb; n++)
131 for (int d = 0; d < bp.d; d++)
132 for (int h = 0; h < bp.h; h++)
133 for (int w = 0; w < bp.w; w++) {
134 size_t sdidx = n * padded_c * bp.d * bp.h * bp.w
135 + c * bp.d * bp.h * bp.w + d * bp.h * bp.w + h * bp.w + w;
136 data_t ref_dst = (src_data[map_index(src_d, sdidx)]
137 - ref_mean) * ref_rsqrt_variance;
138 data_t out = dst_data[map_index(dst_d, sdidx)];
139 data_t norm_max = std::max(fabs(out), fabs(ref_dst));
140 if (norm_max < 10e-3) norm_max = data_t(1);
141 EXPECT_NEAR((out - ref_dst) / norm_max, 0., eps);
// Checks the output of a backward batch-normalization primitive against a
// scalar reference that is computed independently for every channel.
//
// Parameters:
//   p                 - test parameters (p.sizes, p.eps).
//   src, diff_dst     - inputs of the backward primitive.
//   mean, variance    - per-channel statistics used by the reference.
//   weights           - scale/shift memory; its data handle is read only when
//                       use_scale_shift is set (element c is the scale).
//   diff_src          - computed input gradient that is verified.
//   diff_weights      - computed scale/shift gradients; read only when
//                       pk == prop_kind::backward. Element c is compared with
//                       the reference diff_gamma, element bp.c + c with the
//                       reference diff_beta.
//   flags             - batch-normalization flags.
//   pk                - propagation kind (backward or backward_data).
//
// Mismatches are reported through EXPECT_NEAR; nothing is returned.
147 template <typename data_t>
148 void check_bnrm_bwd(const test_bnrm_params_t &p,
149 const memory &src, const memory &diff_dst, const memory &mean,
150 const memory &variance, const memory &weights, const memory &diff_src,
151 const memory &diff_weights, unsigned flags, prop_kind pk)
153 const test_bnrm_sizes_t &bp = p.sizes;
154 const bool use_weights = flags & use_scale_shift;
// Without global statistics the mean/variance depend on src, which adds the
// correction terms applied to ref_diff_src below.
155 const bool calculate_diff_stats = !(flags & use_global_stats);
157 const data_t *src_data = (const data_t *)src.get_data_handle();
158 const data_t *weights_data = use_weights ? (const data_t *)weights.get_data_handle() : nullptr;
159 const data_t *diff_dst_data = (const data_t *)diff_dst.get_data_handle();
160 const data_t *mean_data = (const data_t *)mean.get_data_handle();
161 const data_t *variance_data = (const data_t *)variance.get_data_handle();
162 const data_t *diff_src_data = (data_t *)diff_src.get_data_handle();
// diff_weights is only inspected for the full backward pass.
163 const data_t *diff_weights_data = (pk == prop_kind::backward) ?
164 (data_t *)diff_weights.get_data_handle() : nullptr;
166 const memory::desc src_d = src.get_primitive_desc().desc();
167 const memory::desc diff_dst_d = diff_dst.get_primitive_desc().desc();
168 const memory::desc weights_d = weights.get_primitive_desc().desc();
169 const memory::desc diff_src_d = diff_src.get_primitive_desc().desc();
170 const memory::desc diff_weights_d = diff_weights.get_primitive_desc().desc();
// Zero-sized problem: for the full backward pass every diff_gamma/diff_beta
// entry must be (numerically) zero. Presumably the function returns after
// this block on a line not visible -- confirm.
172 if (bp.mb * bp.c * bp.d * bp.h * bp.w == 0) {
173 if (pk == backward) {
174 for (int c = 0; c < bp.c; ++c) {
175 auto dg = diff_weights_data[map_index(diff_weights_d, c)];
176 auto db = diff_weights_data[map_index(diff_weights_d, bp.c + c)];
177 EXPECT_NEAR(dg, 0., 1e-7);
178 EXPECT_NEAR(db, 0., 1e-7);
// Comparison tolerance, scaled by the number of elements reduced per channel.
184 const data_t eps = static_cast<data_t>(1.e-4 * bp.mb * bp.d * bp.h * bp.w);
// Channel count from the src descriptor's blocking.padding_dims[1]; used as
// the per-image channel count of the logical offsets passed to map_index.
186 size_t padded_c = src.get_primitive_desc().desc().data.layout_desc.blocking.padding_dims[1];
// Each channel is verified independently via mkldnn::impl::parallel_nd.
187 mkldnn::impl::parallel_nd(bp.c, [&](int c) {
188 data_t ref_diff_gamma = data_t(0);
189 data_t ref_diff_beta = data_t(0);
191 auto v_mean = mean_data[c];
192 auto v_variance = variance_data[c];
// Despite its name this holds the reciprocal: 1 / sqrt(variance + eps).
193 const data_t sqrt_variance = data_t(1.0 / sqrt(v_variance + p.eps));
// Scale of channel c, or 1 when scale/shift is not used.
195 auto gamma = use_weights ? weights_data[map_index(weights_d, c)] : 1;
// Accumulate the reference gradients:
//   diff_gamma = sum((src - mean) * diff_dst) / sqrt(variance + eps)
//   diff_beta  = sum(diff_dst)
197 for (int n = 0; n < bp.mb; n++)
198 for (int d = 0; d < bp.d; d++)
199 for (int h = 0; h < bp.h; h++)
200 for (int w = 0; w < bp.w; w++) {
201 size_t sidx = n * padded_c * bp.d * bp.h * bp.w + c * bp.d * bp.h * bp.w
202 + d * bp.h * bp.w + h * bp.w + w;
203 ref_diff_gamma += (src_data[map_index(src_d, sidx)] - v_mean)
204 * diff_dst_data[map_index(diff_dst_d, sidx)];
205 ref_diff_beta += diff_dst_data[map_index(diff_dst_d, sidx)];
207 ref_diff_gamma *= sqrt_variance;
// Full backward pass: compare diff_gamma/diff_beta with the references.
// Relative error; absolute error when both values are below 10e-3 (= 0.01).
209 if (pk == backward) {
210 auto diff_gamma = diff_weights_data[map_index(diff_weights_d, c)];
211 data_t norm_max = std::max(fabs(diff_gamma), fabs(ref_diff_gamma));
212 if (norm_max < 10e-3) norm_max = data_t(1);
213 EXPECT_NEAR((diff_gamma - ref_diff_gamma) / norm_max, 0., eps);
215 auto diff_beta = diff_weights_data[map_index(diff_weights_d, bp.c + c)];
216 norm_max = std::max(fabs(diff_beta), fabs(ref_diff_beta));
217 if (norm_max < 10e-3) norm_max = data_t(1);
218 EXPECT_NEAR((diff_beta - ref_diff_beta) / norm_max, 0., eps);
// Verify every diff_src element of the channel.
221 for (int n = 0; n < bp.mb; n++)
222 for (int d = 0; d < bp.d; d++)
223 for (int h = 0; h < bp.h; h++)
224 for (int w = 0; w < bp.w; w++) {
225 size_t sidx = n * padded_c * bp.d * bp.h * bp.w
226 + c * bp.d * bp.h * bp.w + d * bp.h * bp.w + h * bp.w + w;
227 data_t ref_diff_src = diff_dst_data[map_index(diff_dst_d, sidx)];
// When statistics were computed from src, subtract the contributions that
// flow through the mean (diff_beta / N) and through the variance
// ((src - mean) * diff_gamma * rsqrt / N), with N = mb*d*h*w.
228 if (calculate_diff_stats) {
229 ref_diff_src -= ref_diff_beta/(bp.mb*bp.d*bp.h*bp.w)
230 + (src_data[map_index(src_d, sidx)] - v_mean)
231 *ref_diff_gamma*sqrt_variance/(bp.mb*bp.d*bp.h*bp.w);
233 ref_diff_src *= gamma*sqrt_variance;
234 data_t out_diff_src = diff_src_data[map_index(diff_src_d, sidx)];
// Relative error; absolute error when both values are below eps.
// NOTE(review): the diff_weights checks above use 10e-3 as this threshold.
235 data_t norm_max = std::max(fabs(out_diff_src), fabs(ref_diff_src));
236 if (norm_max < eps) norm_max = data_t(1);
237 EXPECT_NEAR((out_diff_src - ref_diff_src) / norm_max, 0., eps);
// Value-parameterized gtest fixture that runs forward and backward batch
// normalization for one test_bnrm_params_t and verifies the results with
// check_bnrm_fwd / check_bnrm_bwd.
242 template <typename data_t>
243 class bnrm_test : public ::testing::TestWithParam<test_bnrm_params_t> {
// Data tensors (built from data_desc) and gradient tensors (built from
// diff_desc).
245 std::shared_ptr<test_memory> src;
246 std::shared_ptr<test_memory> dst;
247 std::shared_ptr<test_memory> diff_src;
248 std::shared_ptr<test_memory> diff_dst;
// Scale/shift, their gradients and the statistics; re-created by Forward()
// and Backward() from the current primitive descriptor.
249 std::shared_ptr<memory> weights;
250 std::shared_ptr<memory> diff_weights;
251 std::shared_ptr<memory> mean;
252 std::shared_ptr<memory> variance;
253 std::shared_ptr<memory::desc> data_desc;
254 std::shared_ptr<memory::desc> diff_desc;
// Forward primitive descriptor; also passed to the backward primitive
// descriptor constructor in Backward().
255 std::shared_ptr<batch_normalization_forward::primitive_desc> bnrm_prim_desc;
// Backward primitive descriptor (the declarator, used below as
// bnrm_bwd_prim_desc, is on a line not visible here).
256 std::shared_ptr<batch_normalization_backward::primitive_desc>
258 test_bnrm_params_t p;
259 std::shared_ptr<engine> eng;
// NOTE(review): this member is not assigned in the visible lines; the test
// body declares a local with the same name that shadows it.
260 memory::data_type data_type;
// gtest hook: fetches the parameter and runs Test() through
// catch_expected_failures, passing p.expect_to_fail (the remaining arguments
// are on a line not visible here).
263 virtual void SetUp() {
264 p = ::testing::TestWithParam<decltype(p)>::GetParam();
265 catch_expected_failures([=](){Test();}, p.expect_to_fail,
// The following lines presumably form the body of Test(), whose signature
// line is not visible: create the engine, the memory descriptors and the
// data memories, then run every forward/backward configuration.
270 p = ::testing::TestWithParam<decltype(p)>::GetParam();
// Only the CPU engine and f32 data are supported by this test.
272 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
273 eng.reset(new engine(p.engine_kind, 0));
274 memory::data_type data_type = data_traits<data_t>::data_type;
275 ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
277 test_bnrm_sizes_t bs = p.sizes;
278 bool has_spatial = (p.formats.data_format != mkldnn_nc);
// Three descriptor shapes are built below: 5D {mb, c, d, h, w},
// 4D {mb, c, h, w} and 2D {mb, c}. The conditions selecting between them are
// on lines not visible (presumably the number of dimensions and has_spatial
// -- confirm).
283 data_desc.reset(new memory::desc({ bs.mb, bs.c, bs.d, bs.h, bs.w },
284 data_type, p.formats.data_format));
285 diff_desc.reset(new memory::desc({ bs.mb, bs.c, bs.d, bs.h, bs.w },
286 data_type, p.formats.diff_format));
288 data_desc.reset(new memory::desc({ bs.mb, bs.c, bs.h, bs.w },
289 data_type, p.formats.data_format));
290 diff_desc.reset(new memory::desc({ bs.mb, bs.c, bs.h, bs.w },
291 data_type, p.formats.diff_format));
295 data_desc.reset(new memory::desc({ bs.mb, bs.c },
296 data_type, p.formats.data_format));
297 diff_desc.reset(new memory::desc({ bs.mb, bs.c },
298 data_type, p.formats.diff_format));
301 src.reset(new test_memory(*data_desc, *eng));
302 dst.reset(new test_memory(*data_desc, *eng));
303 diff_src.reset(new test_memory(*diff_desc, *eng));
304 diff_dst.reset(new test_memory(*diff_desc, *eng));
306 auto training = prop_kind::forward_training;
307 auto scoring = prop_kind::forward_scoring;
// Forward pass for each flag / propagation-kind combination listed here.
310 Forward(0u, scoring);
311 Forward(0u, training);
312 Forward(use_global_stats, training);
313 Forward(use_global_stats, scoring);
314 Forward(use_scale_shift, scoring);
315 Forward(use_scale_shift, training);
316 Forward(use_scale_shift | use_global_stats, training);
// Backward pass; it runs after the forward calls because Backward() uses the
// forward primitive descriptor stored in bnrm_prim_desc.
318 Backward(0u, backward_data);
319 Backward(use_global_stats, backward_data);
320 Backward(use_scale_shift, backward);
321 Backward(use_scale_shift, backward_data);
322 Backward(use_scale_shift | use_global_stats, backward);
323 Backward(use_scale_shift | use_global_stats, backward_data);
// Creates, runs and verifies one forward batch-normalization primitive.
//   flags - batch-normalization flags (use_scale_shift, use_global_stats)
//   pk    - forward_training or forward_scoring
327 void Forward(unsigned flags, prop_kind pk) {
328 bool useScaleShift = flags & use_scale_shift;
329 bool useGlobalStats = flags & use_global_stats;
330 bool isTraining = pk == prop_kind::forward_training;
332 auto bnrm_desc = batch_normalization_forward::desc(pk,
333 *data_desc, p.eps, flags);
335 bnrm_prim_desc.reset(new batch_normalization_forward::primitive_desc(
338 weights.reset(new memory(bnrm_prim_desc->weights_primitive_desc()));
// Statistics memories are (re)allocated only when they are an input (global
// stats) or an output (training) of the primitive.
339 if (isTraining || useGlobalStats) {
340 mean.reset(new memory(bnrm_prim_desc->mean_primitive_desc()));
342 new memory(bnrm_prim_desc->variance_primitive_desc()));
// Populate the inputs (part of the filling code is on lines not visible).
347 if (useScaleShift) fill(*weights);
348 if (useGlobalStats) {
// NOTE(review): presumably the first argument 1 prepares the padded tail and
// 0 (used after execution) verifies it -- confirm against check_zero_tail.
352 check_zero_tail<data_t>(1, src->get());
353 check_zero_tail<data_t>(1, dst->get());
355 auto bn = createBnrmFwd(isTraining, useGlobalStats, useScaleShift);
// Execute the single primitive on a lazy stream and wait for it.
357 std::vector<primitive> pipeline;
358 pipeline.push_back(bn);
359 stream(stream::kind::lazy).submit(pipeline).wait();
361 check_zero_tail<data_t>(0, dst->get());
363 check_bnrm_fwd<data_t>(p, src->get(), *mean, *variance, *weights,
364 dst->get(), flags, pk);
// Creates, runs and verifies one backward batch-normalization primitive.
//   flags - batch-normalization flags (use_scale_shift, use_global_stats)
//   pk    - backward or backward_data
368 void Backward(unsigned flags, prop_kind pk) {
369 bool useScaleShift = flags & use_scale_shift;
371 auto bnrm_bwd_desc = batch_normalization_backward::desc(
372 pk, *diff_desc, *data_desc, p.eps, flags);
// The forward primitive descriptor left by the most recent Forward() call is
// passed as the third constructor argument.
374 bnrm_bwd_prim_desc.reset(
375 new batch_normalization_backward::primitive_desc(
376 bnrm_bwd_desc, *eng, *bnrm_prim_desc));
378 if (useScaleShift) weights.reset(new memory(
379 bnrm_bwd_prim_desc->weights_primitive_desc()));
380 diff_weights.reset(new memory(bnrm_bwd_prim_desc->diff_weights_primitive_desc()));
381 mean.reset(new memory(bnrm_bwd_prim_desc->mean_primitive_desc()));
382 variance.reset(new memory(
383 bnrm_bwd_prim_desc->variance_primitive_desc()));
// Populate the inputs (further fill calls may be on lines not visible).
385 if (useScaleShift) fill(*weights);
386 fill(diff_src->get());
387 fill(diff_dst->get());
390 check_zero_tail<data_t>(1, diff_src->get());
391 check_zero_tail<data_t>(1, diff_dst->get());
393 auto bnrm_bwd = createBnrmBwd(useScaleShift, pk);
// Execute the single primitive on a lazy stream and wait for it.
395 std::vector<primitive> pipeline;
396 pipeline.push_back(bnrm_bwd);
397 stream(stream::kind::lazy).submit(pipeline).wait();
399 check_bnrm_bwd<data_t>(p,
400 src->get(), diff_dst->get(), *mean, *variance, *weights,
401 diff_src->get(), *diff_weights, flags, pk);
402 check_zero_tail<data_t>(0, diff_src->get());
// Fills the whole buffer of m (size in bytes / sizeof(data_t) elements) via
// fill_data.
// NOTE(review): the `mean` parameter is not used in the visible call.
405 void fill(memory &m, data_t mean = 1.) {
406 fill_data<data_t>(m.get_primitive_desc().get_size() / sizeof(data_t),
407 reinterpret_cast<data_t *>(m.get_data_handle()));
// Builds the forward primitive with the constructor overload matching the
// configuration. The condition selecting the with-/without-weights variant
// of each pair is on lines not visible (presumably useScaleShift).
410 primitive createBnrmFwd(bool isTraining, bool useGlobalStats,
// Inference without global stats: no mean/variance arguments.
413 if (!isTraining && !useGlobalStats) {
415 ? batch_normalization_forward(*bnrm_prim_desc,
416 src->get(), *weights, dst->get())
417 : batch_normalization_forward(*bnrm_prim_desc, src->get(),
// Global stats: mean and variance are passed as inputs (primitive::at),
// before dst.
420 if (useGlobalStats) {
422 ? batch_normalization_forward(*bnrm_prim_desc,
423 src->get(), (const primitive::at)*mean,
424 (const primitive::at)*variance, *weights, dst->get())
425 : batch_normalization_forward(*bnrm_prim_desc,
426 src->get(), (const primitive::at)*mean,
427 (const primitive::at)*variance, dst->get());
// Remaining case: mean and variance are passed after dst, i.e. as outputs.
430 ? batch_normalization_forward(*bnrm_prim_desc,
431 src->get(), *weights, dst->get(), *mean, *variance)
432 : batch_normalization_forward(*bnrm_prim_desc,
433 src->get(), dst->get(), *mean, *variance);
// Builds the backward primitive. The first return passes *weights and picks
// the overload with or without *diff_weights depending on pk; the last
// return passes neither. The condition separating the two returns is on
// lines not visible (presumably useScaleShift).
438 primitive createBnrmBwd(bool useScaleShift, prop_kind pk)
441 return pk == prop_kind::backward_data
442 ? batch_normalization_backward(*bnrm_bwd_prim_desc,
443 src->get(), *mean, *variance, diff_dst->get(), *weights,
445 : batch_normalization_backward(*bnrm_bwd_prim_desc,
446 src->get(), *mean, *variance, diff_dst->get(), *weights,
447 diff_src->get(), *diff_weights);
449 return batch_normalization_backward(*bnrm_bwd_prim_desc, src->get(),
450 *mean, *variance, diff_dst->get(), diff_src->get());
// float instantiation of the fixture; all test cases below are registered
// against it.
455 using bnrm_test_float = bnrm_test<float>;
// Identity macro wrapped around the PARAMS* expansions below.
457 #define EXPAND_ARGS(args) args
// Parameterized test declaration (its body is on lines not visible here).
458 TEST_P(bnrm_test_float, TestsBnrm)
// Helper macros that build test_bnrm_params_t initializers.

// 3D sizes are forwarded unchanged as { mb, c, d, h, w }.
462 #define EXPAND_SIZES_3D(...) { __VA_ARGS__ }
// 2D sizes get a depth of 1 inserted: { mb, c, 1, h, w }.
463 #define EXPAND_SIZES_2D(mb, c, h, w) { mb, c, 1, h, w }
// Qualifies the two format names with memory::format.
464 #define EXPAND_FORMATS(data, diff) \
465 { memory::format::data, memory::format::diff }
// All cases run on the CPU engine.
467 #define ENGINE engine::kind::cpu
// 2D case: data/diff formats, sizes, epsilon, the literal 4, then the
// expect-to-fail flag (ef) and the status (st).
// NOTE(review): the literal 4 (5 in PARAMS_3D) presumably initializes a
// number-of-dimensions member that is not visible in this excerpt -- confirm.
470 #define PARAMS(data, diff, mb, c, h, w, eps, ef, st) \
471 test_bnrm_params_t { ENGINE, EXPAND_FORMATS(data, diff), \
472 EXPAND_SIZES_2D(mb, c, h, w), eps, 4, ef, st }
// 3D case: same as PARAMS but with an explicit depth d and the literal 5.
474 #define PARAMS_3D(data, diff, mb, c, d, h, w, eps, ef, st) \
475 test_bnrm_params_t { ENGINE, EXPAND_FORMATS(data, diff), \
476 EXPAND_SIZES_3D(mb, c, d, h, w), eps, 5, ef, st }
// Per-format shortcuts. Each takes (mb, c, [d,] h, w, eps) and appends
// ef = false, st = mkldnn_success; data and diff use the same format.
478 #define PARAMS_N_3D(...) EXPAND_ARGS(PARAMS_3D(ncdhw, ncdhw, __VA_ARGS__, false, mkldnn_success))
479 #define PARAMS_B8_3D(...) EXPAND_ARGS(PARAMS_3D(nCdhw8c, nCdhw8c, __VA_ARGS__, false, mkldnn_success))
480 #define PARAMS_B16_3D(...) EXPAND_ARGS(PARAMS_3D(nCdhw16c, nCdhw16c, __VA_ARGS__, false, mkldnn_success))
481 #define PARAMS_N(...) EXPAND_ARGS(PARAMS(nchw, nchw, __VA_ARGS__, false, mkldnn_success))
482 #define PARAMS_NHWC(...) EXPAND_ARGS(PARAMS(nhwc, nhwc, __VA_ARGS__, false, mkldnn_success))
483 #define PARAMS_NC(...) EXPAND_ARGS(PARAMS(nc, nc, __VA_ARGS__, false, mkldnn_success))
484 #define PARAMS_B8(...) EXPAND_ARGS(PARAMS(nChw8c, nChw8c, __VA_ARGS__, false, mkldnn_success))
485 #define PARAMS_B16(...) EXPAND_ARGS(PARAMS(nChw16c, nChw16c, __VA_ARGS__, false, mkldnn_success))
// nchw case where the caller supplies ef and st explicitly (used for the
// expected-failure cases).
486 #define PARAMS_EF(...) EXPAND_ARGS(PARAMS(nchw, nchw, __VA_ARGS__))
// Registers a named set of parameter values for bnrm_test_float.
// NOTE(review): the EPS token used by the cases below is defined on a line
// not visible in this excerpt.
488 #define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
489 str, bnrm_test_float, ::testing::Values(__VA_ARGS__))
// nchw cases in which exactly one of mb, c or h is zero.
491 INST_TEST_CASE(SimpleZeroDim,
492 PARAMS_N(0, 27, 9, 10, EPS),
493 PARAMS_N(1, 0, 10, 9, EPS),
494 PARAMS_N(4, 20, 0, 12, EPS)
// nchw cases with one negative dimension (mb, c or h); each is marked as
// expected to fail with mkldnn_invalid_arguments.
497 INST_TEST_CASE(SimpleExpectedFails,
498 PARAMS_EF(-1, 27, 9, 10, EPS, true, mkldnn_invalid_arguments),
499 PARAMS_EF(1, -12, 10, 9, EPS, true, mkldnn_invalid_arguments),
500 PARAMS_EF(4, 20, -12, 12, EPS, true, mkldnn_invalid_arguments)
// Blocked-by-16 (2D) cases whose channel counts (27, 12, 20, 9) are not
// multiples of 16.
503 INST_TEST_CASE(Simple_nChw16c_padded,
504 PARAMS_B16(1, 27, 9, 10, EPS),
505 PARAMS_B16(1, 12, 10, 9, EPS),
506 PARAMS_B16(4, 20, 12, 12, EPS),
507 PARAMS_B16(4, 9, 16, 16, EPS)
// Blocked-by-16 (3D) cases whose channel counts (12, 9, 23, 27) are not
// multiples of 16.
510 INST_TEST_CASE(Simple_nCdhw16c_padded,
511 PARAMS_B16_3D(2, 12, 16, 8, 20, EPS),
512 PARAMS_B16_3D(2, 9, 16, 8, 20, EPS),
513 PARAMS_B16_3D(2, 23, 10, 8, 4, EPS),
514 PARAMS_B16_3D(2, 27, 10, 8, 4, EPS)
// Blocked-by-8 (2D) cases whose channel counts (27, 12, 20, 7) are not
// multiples of 8.
517 INST_TEST_CASE(Simple_nChw8c_padded,
518 PARAMS_B8(1, 27, 9, 10, EPS),
519 PARAMS_B8(1, 12, 10, 9, EPS),
520 PARAMS_B8(4, 20, 12, 12, EPS),
521 PARAMS_B8(4, 7, 16, 16, EPS)
// 3D blocked-by-16 cases with mb = 2 and c = 32.
// NOTE(review): every entry appears twice; the duplicates look redundant --
// confirm whether they are intentional.
525 INST_TEST_CASE(Simple_nCdhw16c,
526 PARAMS_B16_3D(2, 32, 4, 4, 4, EPS),
527 PARAMS_B16_3D(2, 32, 4, 4, 4, EPS),
528 PARAMS_B16_3D(2, 32, 8, 8, 8, EPS),
529 PARAMS_B16_3D(2, 32, 8, 8, 8, EPS),
530 PARAMS_B16_3D(2, 32, 16, 8, 20, EPS),
531 PARAMS_B16_3D(2, 32, 16, 8, 20, EPS),
532 PARAMS_B16_3D(2, 32, 10, 8, 4, EPS),
533 PARAMS_B16_3D(2, 32, 10, 8, 4, EPS)
// 3D blocked-by-8 cases with the same sizes as Simple_nCdhw16c (again each
// entry appears twice).
536 INST_TEST_CASE(Simple_nCdhw8c,
537 PARAMS_B8_3D(2, 32, 4, 4, 4, EPS),
538 PARAMS_B8_3D(2, 32, 4, 4, 4, EPS),
539 PARAMS_B8_3D(2, 32, 8, 8, 8, EPS),
540 PARAMS_B8_3D(2, 32, 8, 8, 8, EPS),
541 PARAMS_B8_3D(2, 32, 16, 8, 20, EPS),
542 PARAMS_B8_3D(2, 32, 16, 8, 20, EPS),
543 PARAMS_B8_3D(2, 32, 10, 8, 4, EPS),
544 PARAMS_B8_3D(2, 32, 10, 8, 4, EPS)
// Small cases (mb = 2, c = 8 or 10) for the plain, non-blocked formats.

// nc format, h = w = 1. NOTE(review): the two entries are each listed twice.
547 INST_TEST_CASE(Simple_NC,
548 PARAMS_NC(2, 8, 1, 1, EPS),
549 PARAMS_NC(2, 10, 1, 1, EPS),
550 PARAMS_NC(2, 8, 1, 1, EPS),
551 PARAMS_NC(2, 10, 1, 1, EPS)
// ncdhw format, spatial sizes 1x1x1 and 4x4x4.
554 INST_TEST_CASE(Simple_NCDHW,
555 PARAMS_N_3D(2, 8, 1, 1, 1, EPS),
556 PARAMS_N_3D(2, 10, 1, 1, 1, EPS),
557 PARAMS_N_3D(2, 8, 4, 4, 4, EPS),
558 PARAMS_N_3D(2, 10, 4, 4, 4, EPS)
// nchw format, spatial sizes 1x1 and 4x4.
561 INST_TEST_CASE(Simple_NCHW,
562 PARAMS_N(2, 8, 1, 1, EPS),
563 PARAMS_N(2, 10, 1, 1, EPS),
564 PARAMS_N(2, 8, 4, 4, EPS),
565 PARAMS_N(2, 10, 4, 4, EPS)
// nhwc format, spatial sizes 1x1 and 4x4.
568 INST_TEST_CASE(Simple_NHWC,
569 PARAMS_NHWC(2, 8, 1, 1, EPS),
570 PARAMS_NHWC(2, 10, 1, 1, EPS),
571 PARAMS_NHWC(2, 8, 4, 4, EPS),
572 PARAMS_NHWC(2, 10, 4, 4, EPS)
// 2D blocked cases with channel counts that are exact multiples of the block
// size: nChw8c with c = 8 or 16, nChw16c with c = 16.
// NOTE(review): most entries appear twice; the duplicates look redundant --
// confirm whether they are intentional.
575 INST_TEST_CASE(Simple_Blocked,
576 PARAMS_B8(2, 8, 1, 1, EPS),
577 PARAMS_B8(2, 8, 4, 4, EPS),
578 PARAMS_B8(2, 8, 6, 6, EPS),
579 PARAMS_B8(2, 16, 4, 4, EPS),
580 PARAMS_B8(2, 16, 4, 4, EPS),
581 PARAMS_B8(2, 16, 8, 8, EPS),
582 PARAMS_B8(2, 16, 8, 8, EPS),
583 PARAMS_B8(2, 16, 16, 8, EPS),
584 PARAMS_B8(2, 16, 16, 8, EPS),
585 PARAMS_B8(2, 16, 10, 8, EPS),
586 PARAMS_B8(2, 16, 10, 8, EPS),
587 PARAMS_B16(2, 16, 4, 4, EPS),
588 PARAMS_B16(2, 16, 4, 4, EPS),
589 PARAMS_B16(2, 16, 8, 8, EPS),
590 PARAMS_B16(2, 16, 8, 8, EPS),
591 PARAMS_B16(2, 16, 16, 8, EPS),
592 PARAMS_B16(2, 16, 16, 8, EPS),
593 PARAMS_B16(2, 16, 10, 8, EPS),
594 PARAMS_B16(2, 16, 10, 8, EPS)
// nchw cases with mb = 2 and a range of channel counts and spatial sizes
// (112x112 down to 4x4); judging by the suite name these are taken from
// GoogleNet layers.
597 INST_TEST_CASE(GoogleNet_NCHW,
598 PARAMS_N(2, 64, 112, 112, EPS),
599 PARAMS_N(2, 64, 56, 56, EPS),
600 PARAMS_N(2, 192, 56, 56, EPS),
601 PARAMS_N(2, 96, 28, 28, EPS),
602 PARAMS_N(2, 16, 28, 28, EPS),
603 PARAMS_N(2, 64, 28, 28, EPS),
604 PARAMS_N(2, 128, 28, 28, EPS),
605 PARAMS_N(2, 32, 28, 28, EPS),
606 PARAMS_N(2, 96, 28, 28, EPS),
607 PARAMS_N(2, 96, 14, 14, EPS),
608 PARAMS_N(2, 16, 14, 14, EPS),
609 PARAMS_N(2, 192, 14, 14, EPS),
610 PARAMS_N(2, 208, 14, 14, EPS),
611 PARAMS_N(2, 48, 14, 14, EPS),
612 PARAMS_N(2, 64, 14, 14, EPS),
613 PARAMS_N(2, 112, 14, 14, EPS),
614 PARAMS_N(2, 24, 14, 14, EPS),
615 PARAMS_N(2, 160, 14, 14, EPS),
616 PARAMS_N(2, 224, 14, 14, EPS),
617 PARAMS_N(2, 128, 4, 4, EPS),
618 PARAMS_N(2, 128, 14, 14, EPS),
619 PARAMS_N(2, 512, 14, 14, EPS),
620 PARAMS_N(2, 256, 14, 14, EPS),
621 PARAMS_N(2, 144, 14, 14, EPS),
622 PARAMS_N(2, 32, 14, 14, EPS),
623 PARAMS_N(2, 228, 14, 14, EPS),
624 PARAMS_N(2, 528, 14, 14, EPS),
625 PARAMS_N(2, 320, 14, 14, EPS),
626 PARAMS_N(2, 160, 7, 7, EPS),
627 PARAMS_N(2, 32, 7, 7, EPS),
628 PARAMS_N(2, 256, 7, 7, EPS),
629 PARAMS_N(2, 320, 7, 7, EPS),
630 PARAMS_N(2, 128, 7, 7, EPS),
631 PARAMS_N(2, 192, 7, 7, EPS),
632 PARAMS_N(2, 48, 7, 7, EPS),
633 PARAMS_N(2, 384, 7, 7, EPS)
// The GoogleNet_NCHW size list in the nChw8c format.
// NOTE(review): the (2, 228, 14, 14) entry of the nchw list has no
// counterpart here; the reason is not visible in this file.
636 INST_TEST_CASE(GoogleNet_Blocked_8,
637 PARAMS_B8(2, 64, 112, 112, EPS),
638 PARAMS_B8(2, 64, 56, 56, EPS),
639 PARAMS_B8(2, 192, 56, 56, EPS),
640 PARAMS_B8(2, 96, 28, 28, EPS),
641 PARAMS_B8(2, 16, 28, 28, EPS),
642 PARAMS_B8(2, 64, 28, 28, EPS),
643 PARAMS_B8(2, 128, 28, 28, EPS),
644 PARAMS_B8(2, 32, 28, 28, EPS),
645 PARAMS_B8(2, 96, 28, 28, EPS),
646 PARAMS_B8(2, 96, 14, 14, EPS),
647 PARAMS_B8(2, 16, 14, 14, EPS),
648 PARAMS_B8(2, 192, 14, 14, EPS),
649 PARAMS_B8(2, 208, 14, 14, EPS),
650 PARAMS_B8(2, 48, 14, 14, EPS),
651 PARAMS_B8(2, 64, 14, 14, EPS),
652 PARAMS_B8(2, 112, 14, 14, EPS),
653 PARAMS_B8(2, 24, 14, 14, EPS),
654 PARAMS_B8(2, 160, 14, 14, EPS),
655 PARAMS_B8(2, 224, 14, 14, EPS),
656 PARAMS_B8(2, 128, 4, 4, EPS),
657 PARAMS_B8(2, 128, 14, 14, EPS),
658 PARAMS_B8(2, 512, 14, 14, EPS),
659 PARAMS_B8(2, 256, 14, 14, EPS),
660 PARAMS_B8(2, 144, 14, 14, EPS),
661 PARAMS_B8(2, 32, 14, 14, EPS),
662 PARAMS_B8(2, 528, 14, 14, EPS),
663 PARAMS_B8(2, 320, 14, 14, EPS),
664 PARAMS_B8(2, 160, 7, 7, EPS),
665 PARAMS_B8(2, 32, 7, 7, EPS),
666 PARAMS_B8(2, 256, 7, 7, EPS),
667 PARAMS_B8(2, 320, 7, 7, EPS),
668 PARAMS_B8(2, 128, 7, 7, EPS),
669 PARAMS_B8(2, 192, 7, 7, EPS),
670 PARAMS_B8(2, 48, 7, 7, EPS),
671 PARAMS_B8(2, 384, 7, 7, EPS)
// The GoogleNet_NCHW size list in the nChw16c format.
// NOTE(review): the (2, 228, 14, 14) entry of the nchw list has no
// counterpart here, and the c = 24 entry is commented out below although the
// nChw8c list keeps it; the reasons are not visible in this file.
674 INST_TEST_CASE(GoogleNet_Blocked_16,
675 PARAMS_B16(2, 64, 112, 112, EPS),
676 PARAMS_B16(2, 64, 56, 56, EPS),
677 PARAMS_B16(2, 192, 56, 56, EPS),
678 PARAMS_B16(2, 96, 28, 28, EPS),
679 PARAMS_B16(2, 16, 28, 28, EPS),
680 PARAMS_B16(2, 64, 28, 28, EPS),
681 PARAMS_B16(2, 128, 28, 28, EPS),
682 PARAMS_B16(2, 32, 28, 28, EPS),
683 PARAMS_B16(2, 96, 28, 28, EPS),
684 PARAMS_B16(2, 96, 14, 14, EPS),
685 PARAMS_B16(2, 16, 14, 14, EPS),
686 PARAMS_B16(2, 192, 14, 14, EPS),
687 PARAMS_B16(2, 208, 14, 14, EPS),
688 PARAMS_B16(2, 48, 14, 14, EPS),
689 PARAMS_B16(2, 64, 14, 14, EPS),
690 PARAMS_B16(2, 112, 14, 14, EPS),
691 //PARAMS_B16(2, 24, 14, 14, EPS),
692 PARAMS_B16(2, 160, 14, 14, EPS),
693 PARAMS_B16(2, 224, 14, 14, EPS),
694 PARAMS_B16(2, 128, 4, 4, EPS),
695 PARAMS_B16(2, 128, 14, 14, EPS),
696 PARAMS_B16(2, 512, 14, 14, EPS),
697 PARAMS_B16(2, 256, 14, 14, EPS),
698 PARAMS_B16(2, 144, 14, 14, EPS),
699 PARAMS_B16(2, 32, 14, 14, EPS),
700 PARAMS_B16(2, 528, 14, 14, EPS),
701 PARAMS_B16(2, 320, 14, 14, EPS),
702 PARAMS_B16(2, 160, 7, 7, EPS),
703 PARAMS_B16(2, 32, 7, 7, EPS),
704 PARAMS_B16(2, 256, 7, 7, EPS),
705 PARAMS_B16(2, 320, 7, 7, EPS),
706 PARAMS_B16(2, 128, 7, 7, EPS),
707 PARAMS_B16(2, 192, 7, 7, EPS),
708 PARAMS_B16(2, 48, 7, 7, EPS),
709 PARAMS_B16(2, 384, 7, 7, EPS)