1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
// Short alias for the memory format type used throughout this test.
23 using fmt = memory::format;
// Number of acceptable ("expected") formats stored per format-compare entry;
// it sizes the exp[] array below and bounds the loop in FmtIsExp.
25 #define EXP_VALS_NUM 3
// NOTE(review): the struct enclosing this member is not present in this
// listing; uses such as p.weights_fmt.exp[0] / p.weights_fmt.in suggest it is
// the fmt_compare type holding a requested `in` format plus this list of
// expected formats - confirm against the full file.
28 fmt exp[EXP_VALS_NUM];
// Parameters for one "format any" convolution test case: engine kind,
// requested/expected formats for each tensor, and the convolution sizes.
// NOTE(review): only some members appear in this listing. The fixture code
// also reads aprop_kind, aalgorithm, src_fmt, bias_fmt and dst_fmt, and the
// positional initializers in the instantiation suggest the order
// {aprop_kind, engine_kind, aalgorithm, src_fmt, weights_fmt, bias_fmt,
// dst_fmt, test_cd}. The closing brace is also missing - confirm.
30 struct conv_any_fmt_test_params {
// Engine to run on; the fixture asserts it is engine::kind::cpu.
32 const engine::kind engine_kind;
// Requested (`in`) and acceptable (`exp`) weights formats.
35 fmt_compare weights_fmt;
// Convolution geometry: batch, groups, channels, spatial sizes, kernel,
// padding and strides (fields mb, ng, ic, ih, iw, oc, oh, ow, kh, kw,
// padh, padw, strh, strw are read by the fixture).
38 test_convolution_sizes_t test_cd;
// Parameterized gtest fixture. It builds a forward convolution whose memory
// descriptors may request fmt::any, then checks that the format reported for
// each tensor by the primitive descriptor is one of the expected ones.
// data_t selects the element type; only f32 is accepted.
41 template <typename data_t>
42 class convolution_any_fmt_test
43 : public ::testing::TestWithParam<conv_any_fmt_test_params> {
// Presumably reports whether `in` matches one of the EXP_VALS_NUM entries of
// `exp`.
// NOTE(review): the comparison in the loop body, the return statements and
// the closing brace are not present in this listing - confirm.
45 virtual bool FmtIsExp(const mkldnn_memory_format_t in, fmt *exp ) {
46 for (int i = 0; i < EXP_VALS_NUM; i++)
// NOTE(review): the header of the method containing these statements is not
// present in this listing. The code reads the current test parameter via
// GetParam().
53 conv_any_fmt_test_params p = ::testing::
54 TestWithParam<conv_any_fmt_test_params>::GetParam();
// Only forward, direct convolution on the CPU engine is supported here.
56 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
57 ASSERT_EQ(p.aprop_kind, prop_kind::forward);
58 ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
59 auto eng = engine(p.engine_kind, 0);
// The element type comes from data_t; only f32 is accepted.
60 memory::data_type data_type = data_traits<data_t>::data_type;
61 ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
// Sanity-check the parameters themselves. The first expected format must be
// concrete (not fmt::any). A requested format must be either fmt::any or
// exactly that first expected format.
63 // Some format checkers
64 ASSERT_NE(p.src_fmt.exp[0], fmt::any);
65 ASSERT_NE(p.weights_fmt.exp[0], fmt::any);
66 ASSERT_NE(p.bias_fmt.exp[0], fmt::any);
67 ASSERT_NE(p.dst_fmt.exp[0], fmt::any);
// NOTE(review): the ASSERT_TRUE( opener for the next condition is missing
// from this listing.
69 p.src_fmt.in == fmt::any || p.src_fmt.in == p.src_fmt.exp[0]);
70 ASSERT_TRUE(p.weights_fmt.in == fmt::any
71 || p.weights_fmt.in == p.weights_fmt.exp[0]);
72 ASSERT_TRUE(p.bias_fmt.in == fmt::any
73 || p.bias_fmt.in == p.bias_fmt.exp[0]);
// NOTE(review): the ASSERT_TRUE( opener for the next condition is missing
// from this listing.
75 p.dst_fmt.in == fmt::any || p.dst_fmt.in == p.dst_fmt.exp[0]);
77 test_convolution_sizes_t cd = p.test_cd;
// Source descriptor: {mb, ic, ih, iw} in the requested source format.
79 auto c_src_desc = create_md(
80 { cd.mb, cd.ic, cd.ih, cd.iw }, data_type, p.src_fmt.in);
// Grouped convolution (ng > 1) uses 5-D weights
// {ng, oc/ng, ic/ng, kh, kw}; otherwise 4-D {oc, ic, kh, kw}.
81 auto c_weights_desc = cd.ng > 1 ?
82 create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
83 data_type, p.weights_fmt.in) :
84 create_md({ cd.oc, cd.ic, cd.kh, cd.kw }, data_type,
// NOTE(review): the final argument and closing of the create_md call above
// are missing from this listing.
// Destination descriptor: {mb, oc, oh, ow} in the requested destination format.
86 auto c_dst_desc = create_md(
87 { cd.mb, cd.oc, cd.oh, cd.ow }, data_type, p.dst_fmt.in);
// A bias format of fmt::format_undef means "no bias". An empty descriptor is
// still created so that c_bias_desc has a value in both branches.
89 bool with_bias = p.bias_fmt.in != fmt::format_undef;
90 auto c_bias_desc = with_bias ?
91 create_md({ cd.oc }, data_type, p.bias_fmt.in) :
92 create_md({}, data_type, p.bias_fmt.in);
94 auto conv_desc = with_bias ?
95 convolution_forward::desc(p.aprop_kind, p.aalgorithm, c_src_desc,
96 c_weights_desc, c_bias_desc, c_dst_desc,
97 { cd.strh, cd.strw }, { cd.padh, cd.padw }, { cd.padh, cd.padw },
99 convolution_forward::desc(p.aprop_kind, p.aalgorithm, c_src_desc,
100 c_weights_desc, c_dst_desc, { cd.strh, cd.strw }, { cd.strh, cd.strw },
101 { cd.padh, cd.padw }, padding_kind::zero);
// Create the primitive descriptor for the CPU engine, then read back the
// format recorded for each tensor and check it with FmtIsExp against that
// tensor's expected list.
103 auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, eng);
// NOTE(review): several lines around these checks are missing from this
// listing - e.g. the ASSERT_TRUE( opener for the src/dst checks, the
// expected-format arguments passed to FmtIsExp, and the closing parentheses
// and braces - confirm against the full file.
106 FmtIsExp(conv_prim_desc.src_primitive_desc().desc().data.format,
108 ASSERT_TRUE(FmtIsExp(
109 conv_prim_desc.weights_primitive_desc().desc().data.format,
112 ASSERT_TRUE(FmtIsExp(
113 conv_prim_desc.bias_primitive_desc().desc().data.format,
117 FmtIsExp(conv_prim_desc.dst_primitive_desc().desc().data.format,
122 using conv_any_fmt_test_float = convolution_any_fmt_test<float>;
123 using conv_any_fmt_test_params_float = conv_any_fmt_test_params;
// Parameterized test entry point for the float fixture.
// NOTE(review): the body of this TEST_P is not present in this listing. The
// visible checks live in the fixture class above - confirm whether the body
// is intentionally empty.
125 TEST_P(conv_any_fmt_test_float, TestsConvolutionAnyFmt)
// Settings shared by every test case below: CPU engine, direct convolution
// algorithm, forward propagation.
#define ENGINE engine::kind::cpu
#define ALG algorithm::convolution_direct
#define PROP_KIND prop_kind::forward

// Format-compare initializers: `in` is always fmt::any, followed by the list
// of acceptable formats, padded with fmt::format_undef.

// Plain layouts.
#define ANY_X \
    { fmt::any, { fmt::x, fmt::format_undef, fmt::format_undef } }
#define ANY_NCHW \
    { fmt::any, { fmt::nchw, fmt::format_undef, fmt::format_undef } }
#define ANY_OIHW \
    { fmt::any, { fmt::oihw, fmt::format_undef, fmt::format_undef } }

// Blocked layouts.
#define ANY_OHWIxO \
    { fmt::any, { fmt::Ohwi8o, fmt::Ohwi16o, fmt::Oihw16o } }
#define ANY_NCHWxC \
    { fmt::any, { fmt::nChw8c, fmt::nChw16c, fmt::format_undef } }
#define ANY_OIHWxIxO \
    { fmt::any, { fmt::OIhw8i8o, fmt::OIhw16i16o, fmt::format_undef } }
#define ANY_GOIHWxIxO \
    { fmt::any, { fmt::gOIhw8i8o, fmt::gOIhw16i16o, fmt::format_undef } }
148 //INSTANTIATE_TEST_CASE_P(TestConvolutionAnyFmtForward, conv_any_fmt_test_float,
149 // ::testing::Values(conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
150 // ANY_NCHW, ANY_OIHW, ANY_X, ANY_NCHW,
151 // { 2, 1, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1 } }));
153 INSTANTIATE_TEST_CASE_P(
154 TestConvolutionAlexnetAnyFmtForwardxlocked, conv_any_fmt_test_float,
156 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
157 ANY_NCHW, ANY_OHWIxO, ANY_X, ANY_NCHWxC,
158 { 2, 1, 3, 227, 227, 96, 55, 55, 11, 11, 0, 0, 4, 4 } },
159 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
160 ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
161 { 2, 2, 96, 27, 27, 256, 27, 27, 5, 5, 2, 2, 1, 1 } },
162 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
163 ANY_NCHWxC, ANY_OIHWxIxO, ANY_X, ANY_NCHWxC,
164 { 2, 1, 256, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1 } },
165 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
166 ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
167 { 2, 2, 384, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1 } },
168 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
169 ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
170 { 2, 2, 384, 13, 13, 256, 13, 13, 3, 3, 1, 1, 1, 1 } }));