Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_convolution_depthwise_forward_f32.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
19 #include "mkldnn.hpp"
20 #include "test_convolution_depthwise_forward_common.hpp"
21
22 namespace mkldnn {
23
24 using convolution_test = convolution_depthwise_test<float, float, float, float>;
25
26 TEST_P(convolution_test, TestConvolution)
27 {
28 }
29
30 #define EXPAND_FORMATS(src, weights, bias, dst) \
31     { mkldnn::memory::format::src, mkldnn::memory::format::weights, \
32     mkldnn::memory::format::bias, mkldnn::memory::format::dst }
33
34 #define FMT_WEIGHTS_BLOCKED8 OIhw8i8o
35 #define FMT_WEIGHTS_BLOCKED8_DW Goihw8g
36 #define FMT_WEIGHTS_BLOCKED16 OIhw16i16o
37 #define FMT_WEIGHTS_BLOCKED16_DW Goihw16g
38
39 #define ENGINE mkldnn::engine::kind::cpu
40 #define ALGORITHM mkldnn::convolution_direct
41
42 #define CONCAT_WITH_UNDERSCORE_(a,b) a ## _ ## b
43 #define CONCAT_WITH_UNDERSCORE(a,b) CONCAT_WITH_UNDERSCORE_(a,b)
44
45 #define INST_TEST_CASE_(str, ...) INSTANTIATE_TEST_CASE_P( \
46         str, convolution_test, ::testing::Values(__VA_ARGS__))
47
48 #define INST_TEST_CASE(str, ...) INST_TEST_CASE_( \
49         CONCAT_WITH_UNDERSCORE(CONCAT_WITH_UNDERSCORE(Convolution, \
50         str), depthwise),  __VA_ARGS__)
51
52 #define EXPAND_ARGS(args) args
53
54 #define PARAMS(...) \
55     EXPAND_ARGS(PARAMS_CONV(depthwise_scale_shift, __VA_ARGS__)), \
56     EXPAND_ARGS(PARAMS_CONV(depthwise_prelu, __VA_ARGS__))
57
58 #define PARAMS_CONV(alg, src, weights, bias, dst, ...) \
59     test_convolution_depthwise_params_t {alg,  ENGINE, ALGORITHM, \
60     EXPAND_FORMATS(src, weights, bias, dst), /* empty attributes */ {}, \
61     {__VA_ARGS__} }
62
63     INST_TEST_CASE(SimpleSmall,
64         PARAMS(nchw, oihw, x, nchw,
65                2, 1, 32, 13, 13, 48, 11, 11, 3, 3, 0, 0, 1, 1),
66         PARAMS(nchw, oihw, x, nchw,
67                2, 1, 16, 13, 13, 48, 13, 13, 1, 1, 0, 0, 1, 1),
68         PARAMS(nchw, goihw, x, nchw,
69                2, 64, 64, 16, 16, 64, 16, 16, 3, 3, 0, 0, 1, 1),
70         PARAMS(nchw, goihw, x, nchw,
71                2, 32, 32, 9, 9, 32, 9, 9, 1, 1, 0, 0, 1, 1)
72     );
73
74     INST_TEST_CASE(SimpleSmall_Blocked8,
75         PARAMS(nChw8c, FMT_WEIGHTS_BLOCKED8, x, nChw8c,
76                2, 1, 32, 13, 13, 48, 11, 11, 3, 3, 0, 0, 1, 1),
77         PARAMS(nChw8c, FMT_WEIGHTS_BLOCKED8, x, nChw8c,
78                2, 1, 16, 13, 13, 48, 13, 13, 1, 1, 0, 0, 1, 1),
79         PARAMS(nChw8c, FMT_WEIGHTS_BLOCKED8_DW, x, nChw8c,
80                2, 64, 64, 16, 16, 64, 16, 16, 3, 3, 0, 0, 1, 1),
81         PARAMS(nChw8c, FMT_WEIGHTS_BLOCKED8_DW, x, nChw8c,
82                2, 32, 32, 9, 9, 32, 9, 9, 1, 1, 0, 0, 1, 1)
83     );
84
85     INST_TEST_CASE(SimpleSmall_Blocked16,
86         PARAMS(nChw16c, FMT_WEIGHTS_BLOCKED16, x, nChw16c,
87                2, 1, 32, 13, 13, 48, 11, 11, 3, 3, 0, 0, 1, 1),
88         PARAMS(nChw16c, FMT_WEIGHTS_BLOCKED16, x, nChw16c,
89                2, 1, 16, 13, 13, 48, 13, 13, 1, 1, 0, 0, 1, 1),
90         PARAMS(nChw16c, FMT_WEIGHTS_BLOCKED16_DW, x, nChw16c,
91                2, 64, 64, 16, 16, 64, 16, 16, 3, 3, 0, 0, 1, 1),
92         PARAMS(nChw16c, FMT_WEIGHTS_BLOCKED16_DW, x, nChw16c,
93                2, 32, 32, 9, 9, 32, 9, 9, 1, 1, 0, 0, 1, 1)
94     );
95 }