Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_convolution_forward_common.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef TEST_CONVOLUTION_FORWARD_COMMON_H
18 #define TEST_CONVOLUTION_FORWARD_COMMON_H
19
20 #include "mkldnn_test_common.hpp"
21 #include "gtest/gtest.h"
22
23 #include "mkldnn.hpp"
24 #include <stdint.h>
25
26 #include <math.h>
27
28 namespace mkldnn {
29
30 template <typename data_t_src, typename data_t_wei,
31           typename data_t_acc, typename data_t_dst>
32 void compute_ref_conv_fwd(const test_convolution_sizes_t &c,
33         const test_convolution_attr_t &attr,
34         const memory::desc &src_d,
35         const memory::desc &weights_d,
36         const memory::desc &bias_d,
37         const memory::desc &dst_d,
38         const memory &src,
39         const memory &weights,
40         const memory &bias,
41         const memory &dst)
42 {
43     const bool w_bias = bias_d.data.format != memory::format::format_undef;
44     data_t_src *src_data = (data_t_src *)src.get_data_handle();
45     data_t_wei *weights_data = (data_t_wei *)weights.get_data_handle();
46
47     data_t_dst *bias_data = w_bias ? (data_t_dst *)bias.get_data_handle() : nullptr;
48     data_t_dst *dst_data = (data_t_dst *)dst.get_data_handle();
49
50     size_t padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
51     size_t padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
52
53     size_t padded_ic_w = weights_d.data.format == mkldnn_OhIw8o4i ? weights_d.data.layout_desc.blocking.padding_dims[1] :
54                                                                     src_d.data.layout_desc.blocking.padding_dims[1];
55     size_t padded_oc_w = weights_d.data.format == mkldnn_OhIw8o4i ? weights_d.data.layout_desc.blocking.padding_dims[0] :
56                                                                     dst_d.data.layout_desc.blocking.padding_dims[1];
57
58     mkldnn::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow,
59         [&](int n, int g, int oc, int oh, int ow) {
60             data_t_acc a = 0;
61             for (int ic = 0; ic < c.ic / c.ng; ic++) {
62                 for (int kh = 0; kh < c.kh; kh++) {
63                     for (int kw = 0; kw < c.kw; kw++) {
64                         int iw = ow * c.strw
65                               - c.padw + kw * (1 + c.dilw);
66                         int ih = oh * c.strh
67                               - c.padh + kh * (1 + c.dilh);
68                         if (iw < 0 || iw >= c.iw) continue;
69                         if (ih < 0 || ih >= c.ih) continue;
70                         size_t iidx = n * padded_ic * c.ih * c.iw
71                             + g * padded_ic / c.ng * c.ih * c.iw
72                             + ic * c.ih * c.iw + ih * c.iw + iw;
73                         size_t widx = g * padded_oc_w / c.ng * padded_ic_w
74                             / c.ng * c.kh * c.kw
75                             + oc * padded_ic_w / c.ng * c.kh * c.kw
76                             + ic * c.kh * c.kw + kh * c.kw + kw;
77
78                         int iidx_ = map_index(src_d, iidx);
79                         int widx_ = map_index(weights_d, widx);
80
81                         a += ((data_t_acc)
82                             src_data[iidx_]
83                             *  weights_data[widx_]);
84                     }
85                 }
86             }
87
88             float a_fp = (float)a;
89
90             a_fp += (float)(bias_data
91                 ?  bias_data[map_index(bias_d, g * c.oc / c.ng + oc)] : 0);
92
93             if (attr.oscale.is_def()) {
94                 const auto &s = attr.oscale;
95                 using P = test_convolution_attr_t::scale_t;
96                 if (s.policy == P::policy_t::COMMON) {
97                     a_fp *= s.scale;
98                 }
99             }
100
101             using D = memory::data_type;
102             if (data_traits<data_t_dst>::data_type != D::f32){
103                 using R = mkldnn::round_mode;
104                 switch (attr.rmode) {
105                     case R::round_down: a_fp = floorf(a_fp); break;
106                     case R::round_nearest: a_fp = nearbyintf(a_fp); break;
107                 }
108             }
109
110             size_t oidx = n * padded_oc * c.oh * c.ow
111                      + g * padded_oc / c.ng * c.oh * c.ow
112                      + oc * c.oh * c.ow + oh * c.ow + ow;
113             dst_data[map_index(dst_d, oidx)] = (data_t_dst)a_fp;
114         }
115     );
116 }
117
118 template <typename data_t_src, typename data_t_wei,
119           typename data_t_acc, typename data_t_dst>
120 class convolution_forward_test
121         : public ::testing::TestWithParam<test_convolution_params_t> {
122 protected:
123     virtual void SetUp() {
124         auto p = ::testing::TestWithParam<test_convolution_params_t>::GetParam();
125         catch_expected_failures([=](){Test();}, p.expect_to_fail,
126                     p.expected_status);
127     }
128
129     void Test() {
130         auto p = ::testing::TestWithParam<test_convolution_params_t>::GetParam();
131         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
132         ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
133         auto eng = engine(p.engine_kind, 0);
134
135         memory::data_type data_type_src = data_traits<data_t_src>::data_type;
136         memory::data_type data_type_dst = data_traits<data_t_dst>::data_type;
137         memory::data_type data_type_wei = data_traits<data_t_wei>::data_type;
138
139         test_convolution_sizes_t cd = p.sizes;
140
141         test_convolution_attr_t attr = p.attr;
142         attr.mkldnn_attr_recreate();
143
144         auto aprop_kind = prop_kind::forward;
145         bool with_bias = p.formats.bias_format != memory::format::format_undef;
146
147         auto c_src_desc = create_md({ cd.mb, cd.ic, cd.ih, cd.iw },
148             data_type_src, p.formats.src_format);
149         auto c_weights_desc = cd.ng > 1 ?
150                 create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
151                         data_type_wei, p.formats.weights_format) :
152                 create_md({ cd.oc, cd.ic, cd.kh, cd.kw },
153                         data_type_wei,p.formats.weights_format);
154         auto c_dst_desc = create_md({ cd.mb, cd.oc, cd.oh, cd.ow },
155                 data_type_dst, p.formats.dst_format);
156         auto c_bias_desc = with_bias ?
157                 create_md({ cd.oc }, data_type_dst, p.formats.bias_format) :
158                 create_md({}, data_type_dst, p.formats.bias_format);
159
160         auto c_src = test_memory(c_src_desc, eng);
161         auto c_weights = test_memory(c_weights_desc, eng);
162         auto c_bias = test_memory(c_bias_desc, eng);
163         auto c_dst = test_memory(c_dst_desc, eng);
164
165         std::vector<data_t_dst> ref_dst_data(c_dst.get_size());
166
167         // Only true for dense format
168         fill_data<data_t_dst>(c_dst.get_size() / sizeof(data_t_dst),
169                 (data_t_dst *)c_dst.get().get_data_handle());
170         fill_data<data_t_src>(c_src.get_size() / sizeof(data_t_src),
171                 (data_t_src *)c_src.get().get_data_handle());
172         fill_data<data_t_wei>(c_weights.get_size() / sizeof(data_t_wei),
173                 (data_t_wei *)c_weights.get().get_data_handle());
174         if (with_bias) {
175             fill_data<data_t_dst>(c_bias.get_size() / sizeof(data_t_dst),
176                     (data_t_dst *)c_bias.get().get_data_handle());
177         }
178         check_zero_tail<data_t_src>(1, c_src.get());
179         check_zero_tail<data_t_wei>(1, c_weights.get());
180         check_zero_tail<data_t_dst>(1, c_dst.get());
181
182         std::vector<ptrdiff_t> padR = {
183             right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh),
184             right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)
185         };
186
187         auto conv_desc = with_bias
188             ? convolution_forward::desc(aprop_kind, p.aalgorithm,
189                     c_src_desc, c_weights_desc, c_bias_desc, c_dst_desc,
190                     { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
191                     { cd.padh, cd.padw }, padR, padding_kind::zero)
192             : convolution_forward::desc(aprop_kind, p.aalgorithm,
193                     c_src_desc, c_weights_desc, c_dst_desc,
194                     { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
195                     { cd.padh, cd.padw }, padR, padding_kind::zero);
196
197         auto conv_primitive_desc = convolution_forward::primitive_desc(
198                 conv_desc, attr.mkl_attr, eng);
199
200         auto conv = with_bias ?
201             convolution_forward(conv_primitive_desc, c_src.get(),
202                     c_weights.get(), c_bias.get(), c_dst.get()) :
203             convolution_forward(conv_primitive_desc, c_src.get(),
204                     c_weights.get(), c_dst.get());
205
206         std::vector<primitive> pipeline;
207         pipeline.push_back(conv);
208         auto s = stream(stream::kind::lazy);
209         s.submit(pipeline).wait();
210
211         auto ref_memory = memory(memory::primitive_desc(c_dst_desc, eng),
212                 &ref_dst_data[0]);
213         compute_ref_conv_fwd<data_t_src,data_t_wei,data_t_acc,data_t_dst>(
214                 cd, attr, c_src_desc, c_weights_desc, c_bias_desc, c_dst_desc,
215                 c_src.get(), c_weights.get(), c_bias.get(), ref_memory);
216         check_zero_tail<data_t_dst>(1, ref_memory);
217
218         compare_data<data_t_dst>(ref_memory, c_dst.get());
219         check_zero_tail<data_t_dst>(0, c_dst.get());
220     }
221 };
222
223 }
224 #endif