Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_convolution_dw_conv_f32.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
19
20 #include "mkldnn.hpp"
21 #include "test_convolution_dw_conv_common.hpp"
22 namespace mkldnn {
23
24 using convolution_test = convolution_dw_conv_test<float, float, float, float>;
25
26 TEST_P(convolution_test, TestConvolutionDwConv)
27 {
28 }
29
30 #define FMT_BIAS x
31 #define FMT_DATA_BLOCKED nChw8c
32 #define FMT_DATA_BLOCKED16 nChw16c
33
34 #define EXPAND_FORMATS(src, conv1_weights, conv1_bias, conv2_weights, conv2_bias, dst) \
35     { mkldnn::memory::format::src, mkldnn::memory::format::conv1_weights, mkldnn::memory::format::conv1_bias, \
36     mkldnn::memory::format::conv2_weights, mkldnn::memory::format::conv2_bias, mkldnn::memory::format::dst }
37
38 #define FMT_WEIGHTS_BLOCKED OIhw8i8o
39 #define FMT_WEIGHTS_BLOCKED16 OIhw16i16o
40
41 #define FMT_WEIGHTS_DW_BLOCKED Goihw8g
42 #define FMT_WEIGHTS_DW_BLOCKED16 Goihw16g
43
44 #define ENGINE mkldnn::engine::kind::cpu
45 #define ALGORITHM mkldnn::convolution_direct
46
47 #define CONCAT_WITH_UNDERSCORE_(a,b) a ## _ ## b
48 #define CONCAT_WITH_UNDERSCORE(a,b) CONCAT_WITH_UNDERSCORE_(a,b)
49
50 #define INST_TEST_CASE_(str, ...) INSTANTIATE_TEST_CASE_P( \
51         str, convolution_test, ::testing::Values(__VA_ARGS__))
52
53 #define INST_TEST_CASE(str, ...) INST_TEST_CASE_( \
54         CONCAT_WITH_UNDERSCORE(CONCAT_WITH_UNDERSCORE(TEST_CASE_NAME_PREFIX, \
55         str), dw_conv),  __VA_ARGS__)
56
57 #define EXPAND_ARGS(args) args
58
59 #define PARAMS(src, conv1_weights, conv1_bias, conv2_weights, conv2_bias, dst, ...) \
60     test_convolution_dw_conv_params_t {ENGINE, ALGORITHM, \
61     EXPAND_FORMATS(src, conv1_weights, conv1_bias, conv2_weights, conv2_bias, dst), {__VA_ARGS__} }
62
63 INST_TEST_CASE(Mobilenet_Blocked,
64     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
65            2, 8, 19, 33,  56, 3, 3, 1, 1, 2, 2,  56, 3, 3, 1, 1, 1, 1), // 1_1
66     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
67            2, 32, 19, 33,  56, 1, 1, 0, 0, 1, 1,  56, 3, 3, 1, 1, 2, 2), // 2_1
68     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
69            2, 56, 9, 16,  112, 1, 1, 0, 0, 1, 1,  112, 3, 3, 1, 1, 1, 1), // 2_2
70     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
71            2, 112, 9, 16,  112, 1, 1, 0, 0, 1, 1,  112, 3, 3, 1, 1, 2, 2), // 3_1
72     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
73            2, 112, 4, 8,  208, 1, 1, 0, 0, 1, 1,  208, 3, 3, 1, 1, 1, 1),  // 3_2
74     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
75            2, 208, 4, 8,  216, 1, 1, 0, 0, 1, 1,  216, 3, 3, 1, 1, 2, 2),  // 4_1
76     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
77            2, 216, 2, 4,  328, 1, 1, 0, 0, 1, 1,  328, 3, 3, 1, 1, 1, 1),  // 4_2
78     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
79            2, 328, 2, 4,  288, 1, 1, 0, 0, 1, 1,  288, 3, 3, 1, 1, 1, 1),  // 5_1
80     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
81            2, 288, 2, 4,  288, 1, 1, 0, 0, 1, 1,  288, 3, 3, 1, 1, 1, 1),  // 5_2
82     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
83            2, 288, 2, 4,  240, 1, 1, 0, 0, 1, 1,  240, 3, 3, 1, 1, 1, 1),  // 5_3
84     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
85            2, 240, 2, 4,  264, 1, 1, 0, 0, 1, 1,  264, 3, 3, 1, 1, 1, 1),  // 5_4
86     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
87            2, 48, 75, 75,  48, 1, 1, 0, 0, 1, 1,  48, 3, 3, 1, 1, 2, 2),
88     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
89            2, 48, 75, 75,  48, 3, 3, 1, 1, 1, 1,  48, 3, 3, 1, 1, 2, 2)
90
91 );
92
93 }