Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_convolution_backward_weights_common.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 #ifndef TEST_CONVOLUTION_BACKWARD_WEIGHTS_COMMON_H
17 #define TEST_CONVOLUTION_BACKWARD_WEIGHTS_COMMON_H
18
19 #include "mkldnn_test_common.hpp"
20 #include "gtest/gtest.h"
21
22 #include "mkldnn.hpp"
23
24 namespace mkldnn {
25
26 template <typename data_t_src, typename data_t_diff_dst,
27           typename data_t_diff_bias>
28 void compute_ref_conv_bwd_bias(const test_convolution_sizes_t &c,
29         const memory &diff_dst, const memory &diff_bias)
30 {
31     data_t_diff_bias *diff_bias_data
32         = (data_t_diff_bias *)diff_bias.get_data_handle();
33     data_t_diff_dst *diff_dst_data
34         = (data_t_diff_dst *)diff_dst.get_data_handle();
35
36     const memory::desc bias_d = diff_bias.get_primitive_desc().desc();
37     const memory::desc dst_d = diff_dst.get_primitive_desc().desc();
38
39     size_t padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
40
41     mkldnn::impl::parallel_nd(c.ng, c.oc / c.ng, [&](int g, int oc) {
42         size_t bidx = g * padded_oc / c.ng + oc;
43         diff_bias_data[map_index(bias_d, bidx)] = 0.0;
44         for (int mb = 0; mb < c.mb; ++mb) {
45             for (int oh = 0; oh < c.oh; ++oh) {
46                 for (int ow = 0; ow < c.ow; ++ow) {
47                     size_t oidx = mb * padded_oc * c.oh * c.ow
48                             + g * padded_oc / c.ng * c.oh * c.ow
49                             + oc * c.oh * c.ow + oh * c.ow + ow;
50                     diff_bias_data[map_index(bias_d, bidx)]
51                         += diff_dst_data[map_index(dst_d, oidx)];
52                 }
53             }
54         }
55     });
56 }
57
58 template <typename data_t_src, typename data_t_diff_dst,
59           typename data_t_diff_weights>
60 void compute_ref_conv_bwd_weights(const test_convolution_sizes_t &c,
61         const memory &src, const memory &diff_dst, const memory &diff_weights)
62 {
63     data_t_src *src_data = (data_t_src *)src.get_data_handle();
64     data_t_diff_weights *diff_weights_data
65         = (data_t_diff_weights *)diff_weights.get_data_handle();
66     data_t_diff_dst *diff_dst_data
67         = (data_t_diff_dst *)diff_dst.get_data_handle();
68
69     const memory::desc src_d = src.get_primitive_desc().desc();
70     const memory::desc weights_d = diff_weights.get_primitive_desc().desc();
71     const memory::desc dst_d = diff_dst.get_primitive_desc().desc();
72
73     size_t padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
74     size_t padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
75
76     mkldnn::impl::parallel_nd(c.ng, c.oc / c.ng, c.ic / c.ng, c.kh, c.kw,
77         [&](int g, int oc, int ic, int kh, int kw) {
78         size_t widx = g * padded_oc / c.ng * padded_ic / c.ng * c.kh * c.kw
79                 + oc * padded_ic / c.ng * c.kh * c.kw
80                 + ic * c.kh * c.kw + kh * c.kw + kw;
81         diff_weights_data[map_index(weights_d, widx)] = 0.0;
82         for (int mb = 0; mb < c.mb; ++mb) {
83             for (int oh = 0; oh < c.oh; ++oh) {
84                 for (int ow = 0; ow < c.ow; ++ow) {
85                     if (ow*c.strw + kw * (1 + c.dilw) < c.padw ||
86                         oh*c.strh + kh * (1 + c.dilh) < c.padh ||
87                         ow*c.strw + kw * (1 + c.dilw) >= c.iw + c.padw ||
88                         oh*c.strh + kh * (1 + c.dilh)>= c.ih + c.padh)
89                         continue;
90
91                     int ih = oh * c.strh - c.padh + kh
92                             * (1 + c.dilh);
93                     int iw = ow * c.strw - c.padw + kw
94                             * (1 + c.dilw);
95                     size_t sidx = mb * padded_ic * c.ih * c.iw
96                         + g * padded_ic / c.ng * c.ih * c.iw
97                         + ic * c.ih * c.iw + ih * c.iw + iw;
98                     size_t didx = mb * padded_oc * c.oh * c.ow
99                         + g * padded_oc / c.ng * c.oh * c.ow
100                         + oc * c.oh * c.ow + oh * c.ow + ow;
101
102                     diff_weights_data[map_index(weights_d, widx)]
103                         += src_data[map_index(src_d, sidx)]
104                         * diff_dst_data[map_index(dst_d, didx)];
105                 }
106             }
107         }
108     });
109 }
110
111 template <typename data_t_src, typename data_t_diff_dst,
112           typename data_t_diff_weights, typename data_t_diff_bias>
113 class convolution_backward_weights_test
114             : public ::testing::TestWithParam<test_convolution_params_t> {
115 protected:
116     virtual void SetUp() {
117         auto p = ::testing::TestWithParam<test_convolution_params_t>::GetParam();
118         catch_expected_failures([=](){Test();}, p.expect_to_fail,
119                     p.expected_status);
120     }
121
122     void Test() {
123         auto p = ::testing::TestWithParam<test_convolution_params_t>::GetParam();
124
125         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
126         ASSERT_EQ(p.aalgorithm, convolution_direct);
127         auto eng = engine(p.engine_kind, 0);
128         memory::data_type data_type_src = data_traits<data_t_src>::data_type;
129         memory::data_type data_type_diff_dst
130             = data_traits<data_t_diff_dst>::data_type;
131         memory::data_type data_type_diff_weights
132             = data_traits<data_t_diff_weights>::data_type;
133         memory::data_type data_type_diff_bias
134             = data_traits<data_t_diff_bias>::data_type;
135
136         test_convolution_sizes_t cd = p.sizes;
137
138         auto c_src_desc = create_md({ cd.mb, cd.ic, cd.ih, cd.iw },
139             data_type_src, p.formats.src_format);
140         auto c_diff_weights_desc = cd.ng > 1
141             ? create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
142               data_type_diff_weights, p.formats.weights_format)
143             : create_md({ cd.oc, cd.ic, cd.kh, cd.kw }, data_type_diff_weights,
144               p.formats.weights_format);
145         auto c_diff_bias_desc = create_md({ cd.oc }, data_type_diff_bias,
146             p.formats.bias_format);
147         auto c_diff_dst_desc = create_md({ cd.mb, cd.oc, cd.oh, cd.ow },
148             data_type_diff_dst, p.formats.dst_format);
149         auto c_weights_desc_f = cd.ng > 1
150             ? create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
151               data_type_diff_dst, p.formats.weights_format)
152             : create_md({ cd.oc, cd.ic, cd.kh, cd.kw }, data_type_diff_dst,
153               p.formats.weights_format);
154         auto c_dst_desc_f = create_md({ cd.mb, cd.oc, cd.oh, cd.ow },
155             data_type_diff_weights, p.formats.dst_format);
156         auto c_src = test_memory(c_src_desc, eng);
157         auto c_diff_weights = test_memory(c_diff_weights_desc, eng);
158         auto c_diff_bias = test_memory(c_diff_bias_desc, eng);
159         auto c_diff_dst = test_memory(c_diff_dst_desc, eng);
160         auto weights_primitive_desc_f = test_memory(c_weights_desc_f, eng);
161         auto dst_primitive_desc_f = test_memory(c_dst_desc_f, eng);
162         fill_data<data_t_diff_dst>(
163             c_diff_dst.get_size() / sizeof(data_t_diff_dst),
164             (data_t_diff_dst *)c_diff_dst.get().get_data_handle());
165         fill_data<data_t_src>(c_src.get_size() / sizeof(data_t_src),
166             (data_t_src *)c_src.get().get_data_handle());
167         fill_data<data_t_diff_weights>(
168             c_diff_weights.get_size() / sizeof(data_t_diff_weights),
169             (data_t_diff_weights *)c_diff_weights.get().get_data_handle());
170
171         check_zero_tail<data_t_diff_dst>(1, c_diff_dst.get());
172         check_zero_tail<data_t_src>(1, c_src.get());
173         check_zero_tail<data_t_diff_weights>(1, c_diff_weights.get());
174
175         std::vector<ptrdiff_t> padR = {
176             right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh),
177             right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)
178         };
179
180         auto conv_desc = convolution_forward::desc(
181                 prop_kind::forward_training, p.aalgorithm, c_src_desc,
182                 c_weights_desc_f, c_diff_bias_desc, c_dst_desc_f,
183                 { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
184                 { cd.padh, cd.padw }, padR, padding_kind::zero);
185
186         auto conv_bwd_weights_desc = convolution_backward_weights::desc(
187                 p.aalgorithm, c_src_desc, c_diff_weights_desc,
188                 c_diff_bias_desc, c_diff_dst_desc,
189                 { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
190                 { cd.padh, cd.padw }, padR, padding_kind::zero);
191
192         auto conv_primitive_desc = convolution_forward::primitive_desc(
193                 conv_desc, eng);
194
195         auto conv_bwd_weights_primitive_desc =
196             convolution_backward_weights::primitive_desc(
197                     conv_bwd_weights_desc, eng, conv_primitive_desc);
198
199         auto conv_bwd_weights =
200             convolution_backward_weights(conv_bwd_weights_primitive_desc,
201                     c_src.get(), c_diff_dst.get(), c_diff_weights.get(),
202                     c_diff_bias.get());
203
204         std::vector<primitive> pipeline;
205         pipeline.push_back(conv_bwd_weights);
206         stream(stream::kind::lazy).submit(pipeline).wait();
207
208         auto ref_diff_weights = memory({c_diff_weights_desc, eng});
209         auto ref_diff_bias = memory({c_diff_bias_desc, eng});
210
211         compute_ref_conv_bwd_weights<data_t_src, data_t_diff_dst,
212             data_t_diff_weights>(cd, c_src.get(), c_diff_dst.get(),
213                     ref_diff_weights);
214         check_zero_tail<data_t_diff_weights>(1, ref_diff_weights);
215         compare_data<data_t_diff_weights>(ref_diff_weights,
216                 c_diff_weights.get());
217         check_zero_tail<data_t_diff_weights>(1, c_diff_weights.get());
218
219         compute_ref_conv_bwd_bias<data_t_src, data_t_diff_dst,
220             data_t_diff_bias>(cd, c_diff_dst.get(), ref_diff_bias);
221
222         compare_data<data_t_diff_bias>(ref_diff_bias, c_diff_bias.get());
223     }
224 };
225
226 }
227 #endif