1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
19 #include "cpu_isa_traits.hpp"
24 using fmt = memory::format;
// Maximum number of alternative expected formats a single test case accepts.
#define EXP_VALS_NUM 3
// NOTE(review): this declaration appears mid-struct here — presumably a member
// of an elided fmt_compare-style helper (`in` format + expected formats);
// confirm against the full source.
fmt exp[EXP_VALS_NUM];
// Parameters for one format-propagation convolution test case.
struct conv_any_fmt_test_params {
    // Engine to run on; the test asserts engine::kind::cpu below.
    const engine::kind engine_kind;
    // Input weights format plus the set of formats the library is allowed
    // to select when the input format is fmt::any.
    fmt_compare weights_fmt;
    // Convolution problem sizes (presumably mb, g, ic, ih, iw, oc, oh, ow,
    // kh, kw, padh, padw, strh, strw — TODO confirm field order).
    test_convolution_sizes_t test_cd;
// Test fixture: builds a forward convolution whose memory descriptors may use
// fmt::any and checks that the formats the library selects are among the
// expected ones.
template <typename data_t>
class convolution_any_fmt_test
    : public ::testing::TestWithParam<conv_any_fmt_test_params> {
    // Returns true when `in` matches one of the EXP_VALS_NUM formats in `exp`.
    virtual bool FmtIsExp(const mkldnn_memory_format_t in, fmt *exp ) {
        for (int i = 0; i < EXP_VALS_NUM; i++)

        // Skip this test if the library cannot select blocked format a priori.
        // Currently blocking is supported only for sse42 and later CPUs.
        bool implementation_supports_blocking
                = impl::cpu::mayiuse(impl::cpu::sse42);
        if (!implementation_supports_blocking) return;

        conv_any_fmt_test_params p = ::testing::
                TestWithParam<conv_any_fmt_test_params>::GetParam();

        // Only f32 forward direct convolution on CPU is exercised here.
        ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
        ASSERT_EQ(p.aprop_kind, prop_kind::forward);
        ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
        auto eng = engine(p.engine_kind, 0);
        memory::data_type data_type = data_traits<data_t>::data_type;
        ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);

        // Some format checkers:
        // an expected format must always be concrete (never fmt::any)...
        ASSERT_NE(p.src_fmt.exp[0], fmt::any);
        ASSERT_NE(p.weights_fmt.exp[0], fmt::any);
        ASSERT_NE(p.bias_fmt.exp[0], fmt::any);
        ASSERT_NE(p.dst_fmt.exp[0], fmt::any);
        // ...and an explicit (non-any) input format must equal the first
        // expected format.
            p.src_fmt.in == fmt::any || p.src_fmt.in == p.src_fmt.exp[0]);
        ASSERT_TRUE(p.weights_fmt.in == fmt::any
                || p.weights_fmt.in == p.weights_fmt.exp[0]);
        ASSERT_TRUE(p.bias_fmt.in == fmt::any
                || p.bias_fmt.in == p.bias_fmt.exp[0]);
            p.dst_fmt.in == fmt::any || p.dst_fmt.in == p.dst_fmt.exp[0]);

        test_convolution_sizes_t cd = p.test_cd;

        // Memory descriptors; grouped (ng > 1) weights get a 5D descriptor.
        auto c_src_desc = create_md(
                { cd.mb, cd.ic, cd.ih, cd.iw }, data_type, p.src_fmt.in);
        auto c_weights_desc = cd.ng > 1 ?
                create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
                        data_type, p.weights_fmt.in) :
                create_md({ cd.oc, cd.ic, cd.kh, cd.kw }, data_type,
        auto c_dst_desc = create_md(
                { cd.mb, cd.oc, cd.oh, cd.ow }, data_type, p.dst_fmt.in);

        // A bias input format of format_undef means "no bias".
        bool with_bias = p.bias_fmt.in != fmt::format_undef;
        auto c_bias_desc = with_bias ?
                create_md({ cd.oc }, data_type, p.bias_fmt.in) :
                create_md({}, data_type, p.bias_fmt.in);
101 auto conv_desc = with_bias ?
102 convolution_forward::desc(p.aprop_kind, p.aalgorithm, c_src_desc,
103 c_weights_desc, c_bias_desc, c_dst_desc,
104 { cd.strh, cd.strw }, { cd.padh, cd.padw }, { cd.padh, cd.padw },
105 padding_kind::zero) :
106 convolution_forward::desc(p.aprop_kind, p.aalgorithm, c_src_desc,
107 c_weights_desc, c_dst_desc, { cd.strh, cd.strw }, { cd.strh, cd.strw },
108 { cd.padh, cd.padw }, padding_kind::zero);
        // Let the library pick the implementation (and thus concrete formats).
        auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, eng);

        // Each format the implementation chose must be one of the expected set.
            FmtIsExp(conv_prim_desc.src_primitive_desc().desc().data.format,
        ASSERT_TRUE(FmtIsExp(
                conv_prim_desc.weights_primitive_desc().desc().data.format,
        ASSERT_TRUE(FmtIsExp(
                conv_prim_desc.bias_primitive_desc().desc().data.format,
            FmtIsExp(conv_prim_desc.dst_primitive_desc().desc().data.format,
// f32 instantiation of the fixture and its parameter type.
using conv_any_fmt_test_float = convolution_any_fmt_test<float>;
using conv_any_fmt_test_params_float = conv_any_fmt_test_params;

// Parametrized test entry point; the checking logic lives in the fixture.
TEST_P(conv_any_fmt_test_float, TestsConvolutionAnyFmt)
// Shorthands for the fixed fields of every test case below.
#define ENGINE engine::kind::cpu
#define ALG algorithm::convolution_direct
#define PROP_KIND prop_kind::forward

// fmt_compare initializers: input format fmt::any plus up to EXP_VALS_NUM
// formats the library may legitimately select (format_undef = unused slot).
#define ANY_X { fmt::any, \
    { fmt::x, fmt::format_undef, fmt::format_undef } }
#define ANY_NCHW { fmt::any, \
    { fmt::nchw, fmt::format_undef, fmt::format_undef } }
#define ANY_OIHW { fmt::any, \
    { fmt::oihw, fmt::format_undef, fmt::format_undef } }
// Blocked variants: 8- or 16-wide channel blocking depending on the ISA.
#define ANY_OHWIxO { fmt::any, \
    { fmt::Ohwi8o, fmt::Ohwi16o, fmt::Oihw16o } }
#define ANY_NCHWxC { fmt::any, \
    { fmt::nChw8c, fmt::nChw16c, fmt::format_undef } }
#define ANY_OIHWxIxO { fmt::any, \
    { fmt::OIhw8i8o, fmt::OIhw16i16o, fmt::format_undef } }
#define ANY_GOIHWxIxO { fmt::any,\
    { fmt::gOIhw8i8o, fmt::gOIhw16i16o, fmt::format_undef } }
155 //INSTANTIATE_TEST_CASE_P(TestConvolutionAnyFmtForward, conv_any_fmt_test_float,
156 // ::testing::Values(conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
157 // ANY_NCHW, ANY_OIHW, ANY_X, ANY_NCHW,
158 // { 2, 1, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1 } }));
160 INSTANTIATE_TEST_CASE_P(
161 TestConvolutionAlexnetAnyFmtForwardxlocked, conv_any_fmt_test_float,
163 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
164 ANY_NCHW, ANY_OHWIxO, ANY_X, ANY_NCHWxC,
165 { 2, 1, 3, 227, 227, 96, 55, 55, 11, 11, 0, 0, 4, 4 } },
166 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
167 ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
168 { 2, 2, 96, 27, 27, 256, 27, 27, 5, 5, 2, 2, 1, 1 } },
169 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
170 ANY_NCHWxC, ANY_OIHWxIxO, ANY_X, ANY_NCHWxC,
171 { 2, 1, 256, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1 } },
172 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
173 ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
174 { 2, 2, 384, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1 } },
175 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
176 ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
177 { 2, 2, 384, 13, 13, 256, 13, 13, 3, 3, 1, 1, 1, 1 } }));