1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
19 #include "math_utils.hpp"
21 #include "test_convolution_eltwise_forward_common.hpp"
// Fixture for every parameterized test in this file: convolution_eltwise_test
// with all four template arguments set to float. The template is presumably
// declared in the included test_convolution_eltwise_forward_common.hpp.
// NOTE(review): the meaning of each type parameter (likely src, weights,
// bias and dst data types) is not visible here; confirm in that header.
25 using convolution_test = convolution_eltwise_test<float, float, float, float>;
// Value-parameterized test on the convolution_test fixture. Its parameter
// sets come from the INST_TEST_CASE(...) instantiations further down, which
// expand to INSTANTIATE_TEST_CASE_P(..., convolution_test, ...).
// NOTE(review): no test body follows this line in the current listing.
// Confirm that the braced body is present, since TEST_P requires one.
27 TEST_P(convolution_test, TestConvolutionEltwise)
// Turns four bare format names into a braced list of mkldnn::memory::format
// enumerators, in (src, weights, bias, dst) order. For example,
// EXPAND_FORMATS(nchw, oihw, x, nchw) yields
// { mkldnn::memory::format::nchw, mkldnn::memory::format::oihw,
//   mkldnn::memory::format::x, mkldnn::memory::format::nchw }.
// Used inside PARAMS_CONV. Presumably the order must match the formats member
// of test_convolution_eltwise_params_t; confirm against that struct.
31 #define EXPAND_FORMATS(src, weights, bias, dst) \
32 { mkldnn::memory::format::src, mkldnn::memory::format::weights, \
33 mkldnn::memory::format::bias, mkldnn::memory::format::dst }
// Pastes two identifiers together with an underscore between them:
// CONCAT_WITH_UNDERSCORE(a, b) yields a_b.
// The two-level form matters because operands of ## are not macro-expanded.
// The public macro therefore expands its arguments first and then hands them
// to the pasting helper. This is what lets INST_TEST_CASE nest calls:
// CONCAT_WITH_UNDERSCORE(CONCAT_WITH_UNDERSCORE(Convolution, X), eltwise)
// produces Convolution_X_eltwise.
35 #define CONCAT_WITH_UNDERSCORE_(a,b) a ## _ ## b
36 #define CONCAT_WITH_UNDERSCORE(a,b) CONCAT_WITH_UNDERSCORE_(a,b)
// Registers the convolution_test value-parameterized tests under the
// instantiation name `str`. The remaining arguments become the parameter
// values, wrapped in ::testing::Values.
// NOTE(review): GoogleTest 1.10+ names this INSTANTIATE_TEST_SUITE_P and
// deprecates the _CASE_ spelling. Whether to switch depends on the gtest
// version this project builds against.
38 #define INST_TEST_CASE_(str, ...) INSTANTIATE_TEST_CASE_P( \
39 str, convolution_test, ::testing::Values(__VA_ARGS__))
// Entry point used by the instantiations below. INST_TEST_CASE(Name, ...)
// instantiates the convolution_test tests under the name
// Convolution_Name_eltwise, using the remaining arguments as the parameter
// values.
41 #define INST_TEST_CASE(str, ...) INST_TEST_CASE_( \
42 CONCAT_WITH_UNDERSCORE(CONCAT_WITH_UNDERSCORE(Convolution, \
43 str), eltwise), __VA_ARGS__)
// EXPAND_ARGS is an identity macro. NOTE(review): its purpose is not stated
// in this file. Wrapping a call as EXPAND_ARGS(M(x, __VA_ARGS__)) is the
// common workaround for preprocessors (notably MSVC's traditional one) that
// forward __VA_ARGS__ to another macro as a single argument. Confirm that is
// why it is used here.
//
// The continuation lines after it expand one set of format/shape arguments
// into ten comma-separated PARAMS_CONV entries, one per eltwise algorithm:
// relu, elu, tanh, square, abs, sqrt, linear, bounded_relu, soft_relu and
// logistic. Every entry receives the same __VA_ARGS__.
// NOTE(review): the #define line that opens these continuation lines is not
// present in this listing. Given the PARAMS(...) calls further down, it is
// presumably a variadic PARAMS(...) definition ending in a line
// continuation; confirm it exists.
45 #define EXPAND_ARGS(args) args
48 EXPAND_ARGS(PARAMS_CONV(eltwise_relu, __VA_ARGS__)), \
49 EXPAND_ARGS(PARAMS_CONV(eltwise_elu, __VA_ARGS__)), \
50 EXPAND_ARGS(PARAMS_CONV(eltwise_tanh, __VA_ARGS__)), \
51 EXPAND_ARGS(PARAMS_CONV(eltwise_square, __VA_ARGS__)), \
52 EXPAND_ARGS(PARAMS_CONV(eltwise_abs, __VA_ARGS__)), \
53 EXPAND_ARGS(PARAMS_CONV(eltwise_sqrt, __VA_ARGS__)), \
54 EXPAND_ARGS(PARAMS_CONV(eltwise_linear, __VA_ARGS__)), \
55 EXPAND_ARGS(PARAMS_CONV(eltwise_bounded_relu, __VA_ARGS__)), \
56 EXPAND_ARGS(PARAMS_CONV(eltwise_soft_relu, __VA_ARGS__)), \
57 EXPAND_ARGS(PARAMS_CONV(eltwise_logistic, __VA_ARGS__))
// alpha and beta are passed unchanged to every eltwise algorithm through
// PARAMS_CONV. How each algorithm interprets (or ignores) them is defined by
// the library, not in this file.
59 #define ELTWISE_ALPHA 0.5f
60 #define ELTWISE_BETA 1.5f
// Builds one test_convolution_eltwise_params_t for eltwise algorithm `alg`.
// It sets:
// - the CPU engine and mkldnn::convolution_direct;
// - ELTWISE_ALPHA / ELTWISE_BETA;
// - the four memory formats via EXPAND_FORMATS;
// - an empty attributes initializer.
// The trailing `...` carries the numeric shape arguments passed to the
// PARAMS(...) calls.
// NOTE(review): the last visible line ends with a line continuation, but the
// line that consumes __VA_ARGS__ and closes the braced initializer is not in
// this listing. Confirm that it exists.
62 #define PARAMS_CONV(alg, src, weights, bias, dst, ...) \
63 test_convolution_eltwise_params_t {alg, mkldnn::engine::kind::cpu, \
64 mkldnn::convolution_direct, ELTWISE_ALPHA, ELTWISE_BETA, \
65 EXPAND_FORMATS(src, weights, bias, dst), /* empty attributes */ {}, \
68 INST_TEST_CASE(SimpleSmall,
// Plain nchw activations with oihw (ungrouped) or goihw (grouped) weights.
// NOTE(review): the 14 numbers per row appear to be mb, groups, ic, ih, iw,
// oc, oh, ow, kh, kw, pad_h, pad_w, stride_h, stride_w. This is consistent
// with row 1, where 13 - 3 + 1 = 11. Confirm against
// test_convolution_eltwise_params_t.
// Under that reading, row 3 has ih = 16, kh = 3, pad 0, stride 1 and oh = 16,
// a larger output than symmetric padding would give. Presumably the fixture
// derives the right/bottom padding from oh; confirm.
// NOTE(review): the closing of this invocation is not present in this
// listing.
69 PARAMS(nchw, oihw, x, nchw, 2, 1, 32, 13, 13, 48, 11, 11, 3, 3, 0, 0, 1, 1),
70 PARAMS(nchw, oihw, x, nchw, 2, 1, 16, 13, 13, 48, 13, 13, 1, 1, 0, 0, 1, 1),
71 PARAMS(nchw, goihw, x, nchw, 2, 64, 64, 16, 16, 64, 16, 16, 3, 3, 0, 0, 1, 1),
72 PARAMS(nchw, goihw, x, nchw, 2, 32, 32, 9, 9, 32, 9, 9, 1, 1, 0, 0, 1, 1)
75 INST_TEST_CASE(SimpleSmall_Blocked,
// 8-channel-blocked layouts: nChw8c activations with Goihw8g (grouped) or
// OIhw8i8o weights. The channel counts used here (8 and 48) are multiples of
// the 8 in the format names. SimpleSmall_Blocked_Tail covers 47 channels.
// NOTE(review): the closing of this invocation is not present in this
// listing.
76 PARAMS(nChw8c, Goihw8g, x, nChw8c, 1, 8, 8, 5, 5, 8, 5, 5, 3, 3, 1, 1, 1, 1),
77 PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 48, 20, 20, 48, 20, 20, 1, 1, 0, 0, 1, 1),
78 PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 48, 20, 20, 48, 20, 20, 3, 3, 0, 0, 1, 1)
81 INST_TEST_CASE(SimpleSmall_Blocked_Tail,
// Same 8-channel-blocked layouts as SimpleSmall_Blocked, but with 47
// channels, which is not a multiple of 8. Presumably this exercises a
// partially filled last channel block.
// NOTE(review): the closing of this invocation is not present in this
// listing.
82 PARAMS(nChw8c, Goihw8g, x, nChw8c, 1, 47, 47, 20, 20, 47, 20, 20, 3, 3, 1, 1, 1, 1),
83 PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 47, 20, 20, 47, 20, 20, 1, 1, 0, 0, 1, 1),
84 PARAMS(nChw8c, OIhw8i8o, x, nChw8c, 1, 1, 47, 20, 20, 47, 20, 20, 3, 3, 0, 0, 1, 1)
87 INST_TEST_CASE(SimpleSmall_Blocked16,
// 16-channel-blocked layouts: nChw16c activations with Goihw16g (grouped) or
// OIhw16i16o weights. The channel counts used here (48 and 32) are multiples
// of the 16 in the format names. The last row also uses mb = 2.
// NOTE(review): the closing of this invocation is not present in this
// listing.
88 PARAMS(nChw16c, Goihw16g, x, nChw16c, 1, 48, 48, 20, 20, 48, 20, 20, 3, 3, 1, 1, 1, 1),
89 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 48, 20, 20, 48, 20, 20, 1, 1, 0, 0, 1, 1),
90 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 48, 20, 20, 48, 20, 20, 3, 3, 0, 0, 1, 1),
91 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 2, 1, 32, 32, 32, 32, 32, 32, 3, 3, 0, 0, 1, 1)
94 INST_TEST_CASE(SimpleSmall_Blocked16_Tail,
// Same 16-channel-blocked layouts as SimpleSmall_Blocked16, but with 47
// channels, which is not a multiple of 16. Presumably this exercises a
// partially filled last channel block.
// NOTE(review): the last visible row uses 32 channels, so it is not a tail
// case. It is also identical to the last row of SimpleSmall_Blocked16.
// Consider replacing it with a non-multiple-of-16 shape, or dropping it.
// The end of this invocation is beyond the visible listing.
95 PARAMS(nChw16c, Goihw16g, x, nChw16c, 1, 47, 47, 20, 20, 47, 20, 20, 3, 3, 1, 1, 1, 1),
96 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 47, 20, 20, 47, 20, 20, 1, 1, 0, 0, 1, 1),
97 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 1, 1, 47, 20, 20, 47, 20, 20, 3, 3, 0, 0, 1, 1),
98 PARAMS(nChw16c, OIhw16i16o, x, nChw16c, 2, 1, 32, 32, 32, 32, 32, 32, 3, 3, 0, 0, 1, 1)