updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_convolution_format_any.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
19 #include "cpu_isa_traits.hpp"
20
21 #include "mkldnn.hpp"
22 namespace mkldnn {
23
24 using fmt = memory::format;
25
26 #define EXP_VALS_NUM 3
27 struct fmt_compare {
28     fmt in;
29     fmt exp[EXP_VALS_NUM];
30 };
31 struct conv_any_fmt_test_params {
32     prop_kind aprop_kind;
33     const engine::kind engine_kind;
34     algorithm aalgorithm;
35     fmt_compare src_fmt;
36     fmt_compare weights_fmt;
37     fmt_compare bias_fmt;
38     fmt_compare dst_fmt;
39     test_convolution_sizes_t test_cd;
40 };
41
42 template <typename data_t>
43 class convolution_any_fmt_test
44         : public ::testing::TestWithParam<conv_any_fmt_test_params> {
45 protected:
46     virtual bool FmtIsExp(const mkldnn_memory_format_t in, fmt *exp ) {
47         for (int i = 0; i < EXP_VALS_NUM; i++)
48             if (in == exp[i])
49                 return true;
50         return false;
51     }
52     virtual void SetUp()
53     {
54         // Skip this test if the library cannot select blocked format a priori.
55         // Currently blocking is supported only for sse42 and later CPUs.
56         bool implementation_supports_blocking
57             = impl::cpu::mayiuse(impl::cpu::sse42);
58         if (!implementation_supports_blocking) return;
59
60         conv_any_fmt_test_params p = ::testing::
61                 TestWithParam<conv_any_fmt_test_params>::GetParam();
62
63         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
64         ASSERT_EQ(p.aprop_kind, prop_kind::forward);
65         ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
66         auto eng = engine(p.engine_kind, 0);
67         memory::data_type data_type = data_traits<data_t>::data_type;
68         ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
69
70         // Some format chekers
71         ASSERT_NE(p.src_fmt.exp[0], fmt::any);
72         ASSERT_NE(p.weights_fmt.exp[0], fmt::any);
73         ASSERT_NE(p.bias_fmt.exp[0], fmt::any);
74         ASSERT_NE(p.dst_fmt.exp[0], fmt::any);
75         ASSERT_TRUE(
76                 p.src_fmt.in == fmt::any || p.src_fmt.in == p.src_fmt.exp[0]);
77         ASSERT_TRUE(p.weights_fmt.in == fmt::any
78                 || p.weights_fmt.in == p.weights_fmt.exp[0]);
79         ASSERT_TRUE(p.bias_fmt.in == fmt::any
80                 || p.bias_fmt.in == p.bias_fmt.exp[0]);
81         ASSERT_TRUE(
82                 p.dst_fmt.in == fmt::any || p.dst_fmt.in == p.dst_fmt.exp[0]);
83
84         test_convolution_sizes_t cd = p.test_cd;
85
86         auto c_src_desc = create_md(
87                 { cd.mb, cd.ic, cd.ih, cd.iw }, data_type, p.src_fmt.in);
88         auto c_weights_desc = cd.ng > 1 ?
89                 create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
90                         data_type, p.weights_fmt.in) :
91                 create_md({ cd.oc, cd.ic, cd.kh, cd.kw }, data_type,
92                         p.weights_fmt.in);
93         auto c_dst_desc = create_md(
94                 { cd.mb, cd.oc, cd.oh, cd.ow }, data_type, p.dst_fmt.in);
95
96         bool with_bias = p.bias_fmt.in != fmt::format_undef;
97         auto c_bias_desc = with_bias ?
98                 create_md({ cd.oc }, data_type, p.bias_fmt.in) :
99                 create_md({}, data_type, p.bias_fmt.in);
100
101         auto conv_desc = with_bias ?
102                 convolution_forward::desc(p.aprop_kind, p.aalgorithm, c_src_desc,
103                         c_weights_desc, c_bias_desc, c_dst_desc,
104                         { cd.strh, cd.strw }, { cd.padh, cd.padw }, { cd.padh, cd.padw },
105                         padding_kind::zero) :
106                 convolution_forward::desc(p.aprop_kind, p.aalgorithm, c_src_desc,
107                         c_weights_desc, c_dst_desc, { cd.strh, cd.strw }, { cd.strh, cd.strw },
108                         { cd.padh, cd.padw }, padding_kind::zero);
109
110         auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, eng);
111
112         ASSERT_TRUE(
113                 FmtIsExp(conv_prim_desc.src_primitive_desc().desc().data.format,
114                         p.src_fmt.exp));
115         ASSERT_TRUE(FmtIsExp(
116                 conv_prim_desc.weights_primitive_desc().desc().data.format,
117                 p.weights_fmt.exp));
118         if (with_bias) {
119             ASSERT_TRUE(FmtIsExp(
120                     conv_prim_desc.bias_primitive_desc().desc().data.format,
121                     p.bias_fmt.exp));
122         }
123         ASSERT_TRUE(
124                 FmtIsExp(conv_prim_desc.dst_primitive_desc().desc().data.format,
125                         p.dst_fmt.exp));
126     }
127 };
128
129 using conv_any_fmt_test_float = convolution_any_fmt_test<float>;
130 using conv_any_fmt_test_params_float = conv_any_fmt_test_params;
131
132 TEST_P(conv_any_fmt_test_float, TestsConvolutionAnyFmt)
133 {
134 }
135 #define ENGINE engine::kind::cpu
136 #define ALG algorithm::convolution_direct
137 #define PROP_KIND prop_kind::forward
138
139 #define ANY_X { fmt::any,    \
140               { fmt::x, fmt::format_undef, fmt::format_undef } }
141 #define ANY_NCHW { fmt::any, \
142               { fmt::nchw, fmt::format_undef, fmt::format_undef } }
143 #define ANY_OIHW { fmt::any, \
144                  { fmt::oihw, fmt::format_undef, fmt::format_undef } }
145
146 #define ANY_OHWIxO { fmt::any,   \
147                    { fmt::Ohwi8o, fmt::Ohwi16o, fmt::Oihw16o } }
148 #define ANY_NCHWxC { fmt::any,   \
149                    { fmt::nChw8c, fmt::nChw16c, fmt::format_undef } }
150 #define ANY_OIHWxIxO { fmt::any, \
151                      { fmt::OIhw8i8o, fmt::OIhw16i16o, fmt::format_undef } }
152 #define ANY_GOIHWxIxO { fmt::any,\
153                       { fmt::gOIhw8i8o, fmt::gOIhw16i16o, fmt::format_undef } }
154
155 //INSTANTIATE_TEST_CASE_P(TestConvolutionAnyFmtForward, conv_any_fmt_test_float,
156 //    ::testing::Values(conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
157 //    ANY_NCHW, ANY_OIHW, ANY_X, ANY_NCHW,
158 //    { 2, 1, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1 } }));
159
160 INSTANTIATE_TEST_CASE_P(
161         TestConvolutionAlexnetAnyFmtForwardxlocked, conv_any_fmt_test_float,
162         ::testing::Values(
163                 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
164                         ANY_NCHW, ANY_OHWIxO, ANY_X, ANY_NCHWxC,
165                         { 2, 1, 3, 227, 227, 96, 55, 55, 11, 11, 0, 0, 4, 4 } },
166                 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
167                         ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
168                         { 2, 2, 96, 27, 27, 256, 27, 27, 5, 5, 2, 2, 1, 1 } },
169                 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
170                         ANY_NCHWxC, ANY_OIHWxIxO, ANY_X, ANY_NCHWxC,
171                         { 2, 1, 256, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1 } },
172                 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
173                         ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
174                         { 2, 2, 384, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1 } },
175                 conv_any_fmt_test_params_float{ PROP_KIND, ENGINE, ALG,
176                     ANY_NCHWxC, ANY_GOIHWxIxO, ANY_X, ANY_NCHWxC,
177                     { 2, 2, 384, 13, 13, 256, 13, 13, 3, 3, 1, 1, 1, 1 } }));
178 }