1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 #ifndef TEST_CONVOLUTION_BACKWARD_WEIGHTS_COMMON_H
17 #define TEST_CONVOLUTION_BACKWARD_WEIGHTS_COMMON_H
19 #include "mkldnn_test_common.hpp"
20 #include "gtest/gtest.h"
26 template <typename data_t_src, typename data_t_diff_dst,
27 typename data_t_diff_bias>
28 void compute_ref_conv_bwd_bias(const test_convolution_sizes_t &c,
29 const memory &diff_dst, const memory &diff_bias)
31 data_t_diff_bias *diff_bias_data
32 = (data_t_diff_bias *)diff_bias.get_data_handle();
33 data_t_diff_dst *diff_dst_data
34 = (data_t_diff_dst *)diff_dst.get_data_handle();
36 const memory::desc bias_d = diff_bias.get_primitive_desc().desc();
37 const memory::desc dst_d = diff_dst.get_primitive_desc().desc();
39 size_t padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
41 mkldnn::impl::parallel_nd(c.ng, c.oc / c.ng, [&](int g, int oc) {
42 size_t bidx = g * padded_oc / c.ng + oc;
43 diff_bias_data[map_index(bias_d, bidx)] = 0.0;
44 for (int mb = 0; mb < c.mb; ++mb) {
45 for (int oh = 0; oh < c.oh; ++oh) {
46 for (int ow = 0; ow < c.ow; ++ow) {
47 size_t oidx = mb * padded_oc * c.oh * c.ow
48 + g * padded_oc / c.ng * c.oh * c.ow
49 + oc * c.oh * c.ow + oh * c.ow + ow;
50 diff_bias_data[map_index(bias_d, bidx)]
51 += diff_dst_data[map_index(dst_d, oidx)];
58 template <typename data_t_src, typename data_t_diff_dst,
59 typename data_t_diff_weights>
60 void compute_ref_conv_bwd_weights(const test_convolution_sizes_t &c,
61 const memory &src, const memory &diff_dst, const memory &diff_weights)
63 data_t_src *src_data = (data_t_src *)src.get_data_handle();
64 data_t_diff_weights *diff_weights_data
65 = (data_t_diff_weights *)diff_weights.get_data_handle();
66 data_t_diff_dst *diff_dst_data
67 = (data_t_diff_dst *)diff_dst.get_data_handle();
69 const memory::desc src_d = src.get_primitive_desc().desc();
70 const memory::desc weights_d = diff_weights.get_primitive_desc().desc();
71 const memory::desc dst_d = diff_dst.get_primitive_desc().desc();
73 size_t padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
74 size_t padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
76 mkldnn::impl::parallel_nd(c.ng, c.oc / c.ng, c.ic / c.ng, c.kh, c.kw,
77 [&](int g, int oc, int ic, int kh, int kw) {
78 size_t widx = g * padded_oc / c.ng * padded_ic / c.ng * c.kh * c.kw
79 + oc * padded_ic / c.ng * c.kh * c.kw
80 + ic * c.kh * c.kw + kh * c.kw + kw;
81 diff_weights_data[map_index(weights_d, widx)] = 0.0;
82 for (int mb = 0; mb < c.mb; ++mb) {
83 for (int oh = 0; oh < c.oh; ++oh) {
84 for (int ow = 0; ow < c.ow; ++ow) {
85 if (ow*c.strw + kw * (1 + c.dilw) < c.padw ||
86 oh*c.strh + kh * (1 + c.dilh) < c.padh ||
87 ow*c.strw + kw * (1 + c.dilw) >= c.iw + c.padw ||
88 oh*c.strh + kh * (1 + c.dilh)>= c.ih + c.padh)
91 int ih = oh * c.strh - c.padh + kh
93 int iw = ow * c.strw - c.padw + kw
95 size_t sidx = mb * padded_ic * c.ih * c.iw
96 + g * padded_ic / c.ng * c.ih * c.iw
97 + ic * c.ih * c.iw + ih * c.iw + iw;
98 size_t didx = mb * padded_oc * c.oh * c.ow
99 + g * padded_oc / c.ng * c.oh * c.ow
100 + oc * c.oh * c.ow + oh * c.ow + ow;
102 diff_weights_data[map_index(weights_d, widx)]
103 += src_data[map_index(src_d, sidx)]
104 * diff_dst_data[map_index(dst_d, didx)];
111 template <typename data_t_src, typename data_t_diff_dst,
112 typename data_t_diff_weights, typename data_t_diff_bias>
113 class convolution_backward_weights_test
114 : public ::testing::TestWithParam<test_convolution_params_t> {
116 virtual void SetUp() {
117 auto p = ::testing::TestWithParam<test_convolution_params_t>::GetParam();
118 catch_expected_failures([=](){Test();}, p.expect_to_fail,
123 auto p = ::testing::TestWithParam<test_convolution_params_t>::GetParam();
125 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
126 ASSERT_EQ(p.aalgorithm, convolution_direct);
127 auto eng = engine(p.engine_kind, 0);
128 memory::data_type data_type_src = data_traits<data_t_src>::data_type;
129 memory::data_type data_type_diff_dst
130 = data_traits<data_t_diff_dst>::data_type;
131 memory::data_type data_type_diff_weights
132 = data_traits<data_t_diff_weights>::data_type;
133 memory::data_type data_type_diff_bias
134 = data_traits<data_t_diff_bias>::data_type;
136 test_convolution_sizes_t cd = p.sizes;
138 auto c_src_desc = create_md({ cd.mb, cd.ic, cd.ih, cd.iw },
139 data_type_src, p.formats.src_format);
140 auto c_diff_weights_desc = cd.ng > 1
141 ? create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
142 data_type_diff_weights, p.formats.weights_format)
143 : create_md({ cd.oc, cd.ic, cd.kh, cd.kw }, data_type_diff_weights,
144 p.formats.weights_format);
145 auto c_diff_bias_desc = create_md({ cd.oc }, data_type_diff_bias,
146 p.formats.bias_format);
147 auto c_diff_dst_desc = create_md({ cd.mb, cd.oc, cd.oh, cd.ow },
148 data_type_diff_dst, p.formats.dst_format);
149 auto c_weights_desc_f = cd.ng > 1
150 ? create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
151 data_type_diff_dst, p.formats.weights_format)
152 : create_md({ cd.oc, cd.ic, cd.kh, cd.kw }, data_type_diff_dst,
153 p.formats.weights_format);
154 auto c_dst_desc_f = create_md({ cd.mb, cd.oc, cd.oh, cd.ow },
155 data_type_diff_weights, p.formats.dst_format);
156 auto c_src = test_memory(c_src_desc, eng);
157 auto c_diff_weights = test_memory(c_diff_weights_desc, eng);
158 auto c_diff_bias = test_memory(c_diff_bias_desc, eng);
159 auto c_diff_dst = test_memory(c_diff_dst_desc, eng);
160 auto weights_primitive_desc_f = test_memory(c_weights_desc_f, eng);
161 auto dst_primitive_desc_f = test_memory(c_dst_desc_f, eng);
162 fill_data<data_t_diff_dst>(
163 c_diff_dst.get_size() / sizeof(data_t_diff_dst),
164 (data_t_diff_dst *)c_diff_dst.get().get_data_handle());
165 fill_data<data_t_src>(c_src.get_size() / sizeof(data_t_src),
166 (data_t_src *)c_src.get().get_data_handle());
167 fill_data<data_t_diff_weights>(
168 c_diff_weights.get_size() / sizeof(data_t_diff_weights),
169 (data_t_diff_weights *)c_diff_weights.get().get_data_handle());
171 check_zero_tail<data_t_diff_dst>(1, c_diff_dst.get());
172 check_zero_tail<data_t_src>(1, c_src.get());
173 check_zero_tail<data_t_diff_weights>(1, c_diff_weights.get());
175 std::vector<ptrdiff_t> padR = {
176 right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh),
177 right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)
180 auto conv_desc = convolution_forward::desc(
181 prop_kind::forward_training, p.aalgorithm, c_src_desc,
182 c_weights_desc_f, c_diff_bias_desc, c_dst_desc_f,
183 { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
184 { cd.padh, cd.padw }, padR, padding_kind::zero);
186 auto conv_bwd_weights_desc = convolution_backward_weights::desc(
187 p.aalgorithm, c_src_desc, c_diff_weights_desc,
188 c_diff_bias_desc, c_diff_dst_desc,
189 { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
190 { cd.padh, cd.padw }, padR, padding_kind::zero);
192 auto conv_primitive_desc = convolution_forward::primitive_desc(
195 auto conv_bwd_weights_primitive_desc =
196 convolution_backward_weights::primitive_desc(
197 conv_bwd_weights_desc, eng, conv_primitive_desc);
199 auto conv_bwd_weights =
200 convolution_backward_weights(conv_bwd_weights_primitive_desc,
201 c_src.get(), c_diff_dst.get(), c_diff_weights.get(),
204 std::vector<primitive> pipeline;
205 pipeline.push_back(conv_bwd_weights);
206 stream(stream::kind::lazy).submit(pipeline).wait();
208 auto ref_diff_weights = memory({c_diff_weights_desc, eng});
209 auto ref_diff_bias = memory({c_diff_bias_desc, eng});
211 compute_ref_conv_bwd_weights<data_t_src, data_t_diff_dst,
212 data_t_diff_weights>(cd, c_src.get(), c_diff_dst.get(),
214 check_zero_tail<data_t_diff_weights>(1, ref_diff_weights);
215 compare_data<data_t_diff_weights>(ref_diff_weights,
216 c_diff_weights.get());
217 check_zero_tail<data_t_diff_weights>(1, c_diff_weights.get());
219 compute_ref_conv_bwd_bias<data_t_src, data_t_diff_dst,
220 data_t_diff_bias>(cd, c_diff_dst.get(), ref_diff_bias);
222 compare_data<data_t_diff_bias>(ref_diff_bias, c_diff_bias.get());