1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
21 #include "test_convolution_dw_conv_common.hpp"
24 using convolution_test = convolution_dw_conv_test<uint8_t, int8_t, int32_t, uint8_t>;
26 TEST_P(convolution_test, TestConvolutionDwConv)
// Memory-format tags shared by the PARAMS(...) rows.
// Activation layout (source and destination).
#define FMT_DATA_BLOCKED nhwc
// Weight layout of the first (non-depthwise) convolution.
#define FMT_WEIGHTS_BLOCKED OhIw8o4i
// Weight layout of the second, depthwise convolution.
#define FMT_WEIGHTS_DW_BLOCKED Goihw8g

// Every case uses the direct convolution algorithm on the CPU engine.
#define ENGINE mkldnn::engine::kind::cpu
#define ALGORITHM mkldnn::convolution_direct

// Maps six bare format tags onto a braced list of mkldnn::memory::format
// enumerators, ordered {src, conv1 weights, conv1 bias, conv2 weights,
// conv2 bias, dst}. The arguments are macro-expanded before substitution,
// so the FMT_* aliases above can be passed directly.
#define EXPAND_FORMATS(s, w1, b1, w2, b2, d) \
    { mkldnn::memory::format::s, mkldnn::memory::format::w1, \
      mkldnn::memory::format::b1, mkldnn::memory::format::w2, \
      mkldnn::memory::format::b2, mkldnn::memory::format::d }

// Builds one test_convolution_dw_conv_params_t aggregate from the engine,
// the algorithm, the six formats, and then every remaining argument as a
// nested braced list.
// NOTE(review): the meaning and order of the trailing numeric arguments
// come from test_convolution_dw_conv_params_t in
// test_convolution_dw_conv_common.hpp — confirm there.
#define PARAMS(s, w1, b1, w2, b2, d, ...) \
    test_convolution_dw_conv_params_t { ENGINE, ALGORITHM, \
        EXPAND_FORMATS(s, w1, b1, w2, b2, d), { __VA_ARGS__ } }

// Pastes `a`, an underscore and `b` into a single identifier. The outer
// macro adds one level of indirection so that macro arguments (for example
// TEST_CASE_NAME_PREFIX) are expanded before the `##` paste happens.
#define CONCAT_WITH_UNDERSCORE_(a,b) a ## _ ## b
#define CONCAT_WITH_UNDERSCORE(a,b) CONCAT_WITH_UNDERSCORE_(a,b)

// Instantiates convolution_test over the given parameter sets under the
// instantiation name `name`.
#define INST_TEST_CASE_(name, ...) INSTANTIATE_TEST_CASE_P( \
    name, convolution_test, ::testing::Values(__VA_ARGS__))

// Like INST_TEST_CASE_, but the instantiation name is assembled as
// <TEST_CASE_NAME_PREFIX>_<str>_dw_conv. TEST_CASE_NAME_PREFIX is not
// defined in this section and is expected to be provided elsewhere.
#define INST_TEST_CASE(str, ...) INST_TEST_CASE_( \
    CONCAT_WITH_UNDERSCORE(CONCAT_WITH_UNDERSCORE(TEST_CASE_NAME_PREFIX, \
        str), dw_conv), __VA_ARGS__)

// Identity helper: expands to its argument unchanged.
#define EXPAND_ARGS(args) args
60 INST_TEST_CASE(Mobilenet_Blocked,
61 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
62 2, 8, 19, 33, 56, 3, 3, 1, 1, 2, 2, 56, 3, 3, 1, 1, 1, 1), // 1_1
63 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
64 2, 32, 19, 33, 56, 1, 1, 0, 0, 1, 1, 56, 3, 3, 1, 1, 2, 2), // 2_1
65 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
66 2, 56, 9, 16, 112, 1, 1, 0, 0, 1, 1, 112, 3, 3, 1, 1, 1, 1), // 2_2
67 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
68 2, 112, 9, 16, 112, 1, 1, 0, 0, 1, 1, 112, 3, 3, 1, 1, 2, 2), // 3_1
69 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
70 2, 112, 4, 8, 208, 1, 1, 0, 0, 1, 1, 208, 3, 3, 1, 1, 1, 1), // 3_2
71 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
72 2, 208, 4, 8, 216, 1, 1, 0, 0, 1, 1, 216, 3, 3, 1, 1, 2, 2), // 4_1
73 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
74 2, 216, 2, 4, 328, 1, 1, 0, 0, 1, 1, 328, 3, 3, 1, 1, 1, 1), // 4_2
75 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
76 2, 328, 2, 4, 288, 1, 1, 0, 0, 1, 1, 288, 3, 3, 1, 1, 1, 1), // 5_1
77 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
78 2, 288, 2, 4, 288, 1, 1, 0, 0, 1, 1, 288, 3, 3, 1, 1, 1, 1), // 5_2
79 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
80 2, 288, 2, 4, 240, 1, 1, 0, 0, 1, 1, 240, 3, 3, 1, 1, 1, 1), // 5_3
81 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
82 2, 240, 2, 4, 264, 1, 1, 0, 0, 1, 1, 264, 3, 3, 1, 1, 1, 1), // 5_4
83 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
84 2, 48, 75, 75, 48, 1, 1, 0, 0, 1, 1, 48, 3, 3, 1, 1, 2, 2),
85 PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_WEIGHTS_DW_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
86 2, 48, 75, 75, 48, 3, 3, 1, 1, 1, 1, 48, 3, 3, 1, 1, 2, 2)