Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_convolution_format_any.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
19
20 #include "mkldnn.hpp"
21
22 namespace mkldnn {
23 using fmt = memory::format;
24
25 #define EXP_VALS_NUM 3
26 struct fmt_compare {
27     fmt in;
28     fmt exp[EXP_VALS_NUM];
29 };
30 struct conv_any_fmt_test_params {
31     prop_kind aprop_kind;
32     const engine::kind engine_kind;
33     algorithm aalgorithm;
34     fmt_compare src_fmt;
35     fmt_compare weights_fmt;
36     fmt_compare bias_fmt;
37     fmt_compare dst_fmt;
38     test_convolution_sizes_t test_cd;
39 };
40
41 template <typename data_t>
42 class convolution_any_fmt_test
43         : public ::testing::TestWithParam<conv_any_fmt_test_params> {
44 protected:
45     virtual bool FmtIsExp(const mkldnn_memory_format_t in, fmt *exp ) {
46         for (int i = 0; i < EXP_VALS_NUM; i++)
47             if (in == exp[i])
48                 return true;
49         return false;
50     }
51     virtual void SetUp()
52     {
53         conv_any_fmt_test_params p = ::testing::
54                 TestWithParam<conv_any_fmt_test_params>::GetParam();
55
56         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
57         ASSERT_EQ(p.aprop_kind, prop_kind::forward);
58         ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
59         auto eng = engine(p.engine_kind, 0);
60         memory::data_type data_type = data_traits<data_t>::data_type;
61         ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
62
63         // Some format chekers
64         ASSERT_NE(p.src_fmt.exp[0], fmt::any);
65         ASSERT_NE(p.weights_fmt.exp[0], fmt::any);
66         ASSERT_NE(p.bias_fmt.exp[0], fmt::any);
67         ASSERT_NE(p.dst_fmt.exp[0], fmt::any);
68         ASSERT_TRUE(
69                 p.src_fmt.in == fmt::any || p.src_fmt.in == p.src_fmt.exp[0]);
70         ASSERT_TRUE(p.weights_fmt.in == fmt::any
71                 || p.weights_fmt.in == p.weights_fmt.exp[0]);
72         ASSERT_TRUE(p.bias_fmt.in == fmt::any
73                 || p.bias_fmt.in == p.bias_fmt.exp[0]);
74         ASSERT_TRUE(
75                 p.dst_fmt.in == fmt::any || p.dst_fmt.in == p.dst_fmt.exp[0]);
76
77         test_convolution_sizes_t cd = p.test_cd;
78
79         auto c_src_desc = create_md(
80                 { cd.mb, cd.ic, cd.ih, cd.iw }, data_type, p.src_fmt.in);
81         auto c_weights_desc = cd.ng > 1 ?
82                 create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
83                         data_type, p.weights_fmt.in) :
84                 create_md({ cd.oc, cd.ic, cd.kh, cd.kw }, data_type,
85                         p.weights_fmt.in);
86         auto c_dst_desc = create_md(
87                 { cd.mb, cd.oc, cd.oh, cd.ow }, data_type, p.dst_fmt.in);
88
89         bool with_bias = p.bias_fmt.in != fmt::format_undef;
90         auto c_bias_desc = with_bias ?
91                 create_md({ cd.oc }, data_type, p.bias_fmt.in) :
92                 create_md({}, data_type, p.bias_fmt.in);
93
94         auto conv_desc = with_bias ?
95                 convolution_forward::desc(p.aprop_kind, p.aalgorithm, c_src_desc,
96                         c_weights_desc, c_bias_desc, c_dst_desc,
97                         { cd.strh, cd.strw }, { cd.padh, cd.padw }, { cd.padh, cd.padw },
98                         padding_kind::zero) :
99                 convolution_forward::desc(p.aprop_kind, p.aalgorithm, c_src_desc,
100                         c_weights_desc, c_dst_desc, { cd.strh, cd.strw }, { cd.strh, cd.strw },
101                         { cd.padh, cd.padw }, padding_kind::zero);
102
103         auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, eng);
104
105         ASSERT_TRUE(
106                 FmtIsExp(conv_prim_desc.src_primitive_desc().desc().data.format,
107                         p.src_fmt.exp));
108         ASSERT_TRUE(FmtIsExp(
109                 conv_prim_desc.weights_primitive_desc().desc().data.format,
110                 p.weights_fmt.exp));
111         if (with_bias) {
112             ASSERT_TRUE(FmtIsExp(
113                     conv_prim_desc.bias_primitive_desc().desc().data.format,
114                     p.bias_fmt.exp));
115         }
116         ASSERT_TRUE(
117                 FmtIsExp(conv_prim_desc.dst_primitive_desc().desc().data.format,
118                         p.dst_fmt.exp));
119     }
120 };
121
122 using conv_any_fmt_test_float = convolution_any_fmt_test<float>;
123 using conv_any_fmt_test_params_float = conv_any_fmt_test_params;
124
125 TEST_P(conv_any_fmt_test_float, TestsConvolutionAnyFmt)
126 {
127 }
128 #define ENGINE engine::kind::cpu
129 #define ALG algorithm::convolution_direct
130 #define PROP_KIND prop_kind::forward
131
132 #define ANY_X { fmt::any,    \
133               { fmt::x, fmt::format_undef, fmt::format_undef } }
134 #define ANY_NCHW { fmt::any, \
135               { fmt::nchw, fmt::format_undef, fmt::format_undef } }
136 #define ANY_OIHW { fmt::any, \
137                  { fmt::oihw, fmt::format_undef, fmt::format_undef } }
138
139 #define ANY_OHWIxO { fmt::any,   \
140                    { fmt::Ohwi8o, fmt::Ohwi16o, fmt::Oihw16o } }
141 #define ANY_NCHWxC { fmt::any,   \
142                    { fmt::nChw8c, fmt::nChw16c, fmt::format_undef } }
143 #define ANY_OIHWxIxO { fmt::any, \
144                      { fmt::OIhw8i8o, fmt::OIhw16i16o, fmt::format_undef } }
145 #define ANY_GOIHWxIxO { fmt::any,\
146                       { fmt::gOIhw8i8o, fmt::gOIhw16i16o, fmt::format_undef } }
147
148 //INSTANTIATE_TEST_CASE_P(TestConvolutionAnyFmtForward, conv_any_fmt_test_float,
149 //    ::testing::Values(conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
150 //    ANY_NCHW, ANY_OIHW, ANY_X, ANY_NCHW,
151 //    { 2, 1, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1 } }));
152
153 INSTANTIATE_TEST_CASE_P(
154         TestConvolutionAlexnetAnyFmtForwardxlocked, conv_any_fmt_test_float,
155         ::testing::Values(
156                 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
157                         ANY_NCHW, ANY_OHWIxO, ANY_X, ANY_NCHWxC,
158                         { 2, 1, 3, 227, 227, 96, 55, 55, 11, 11, 0, 0, 4, 4 } },
159                 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
160                         ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
161                         { 2, 2, 96, 27, 27, 256, 27, 27, 5, 5, 2, 2, 1, 1 } },
162                 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
163                         ANY_NCHWxC, ANY_OIHWxIxO, ANY_X, ANY_NCHWxC,
164                         { 2, 1, 256, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1 } },
165                 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
166                         ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
167                         { 2, 2, 384, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1 } },
168                 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
169                     ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
170                     { 2, 2, 384, 13, 13, 256, 13, 13, 3, 3, 1, 1, 1, 1 } }));
171 }