/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #ifndef TEST_CONVOLUTION_BACKWARD_DATA_COMMON_H
18 #define TEST_CONVOLUTION_BACKWARD_DATA_COMMON_H
20 #include "mkldnn_test_common.hpp"
21 #include "gtest/gtest.h"
27 template <typename data_t_diff_dst, typename data_t_wei,
28 typename data_t_acc, typename data_t_diff_src>
29 void compute_ref_conv_bwd_data(const test_convolution_sizes_t &c,
30 const memory &diff_src, const memory &weights, const memory &diff_dst)
32 data_t_diff_dst *diff_dst_data = (data_t_diff_dst *)diff_dst.get_data_handle();
33 data_t_wei *weights_data = (data_t_wei *)weights.get_data_handle();
34 data_t_diff_src *diff_src_data = (data_t_diff_src *)diff_src.get_data_handle();
36 const memory::desc diff_src_d = diff_src.get_primitive_desc().desc();
37 const memory::desc weights_d = weights.get_primitive_desc().desc();
38 const memory::desc diff_dst_d = diff_dst.get_primitive_desc().desc();
40 size_t padded_ic = diff_src_d.data.layout_desc.blocking.padding_dims[1];
41 size_t padded_oc = diff_dst_d.data.layout_desc.blocking.padding_dims[1];
43 mkldnn::impl::parallel_nd(c.mb, c.ng, c.ic / c.ng, c.ih, c.iw,
44 [&](int mb, int g, int ic, int ih, int iw) {
45 size_t sidx = mb * padded_ic * c.ih * c.iw
46 + g * padded_ic / c.ng * c.ih * c.iw
47 + ic * c.ih * c.iw + ih * c.iw + iw;
48 data_t_acc a = data_t_acc(0);
49 for (int oc = 0; oc < c.oc / c.ng; oc++) {
50 for (int kh = 0; kh < c.kh; kh++) {
51 for (int kw = 0; kw < c.kw; kw++) {
52 if (iw + c.padw < kw * (1 + c.dilw)
53 || ih + c.padh < kh * (1 + c.dilh))
55 int ow = iw - kw * (1 + c.dilw) + c.padw;
56 int oh = ih - kh * (1 + c.dilh) + c.padh;
57 if (ow % c.strw != 0 || oh % c.strh != 0)
61 if (oh < c.oh && ow < c.ow) {
62 size_t didx = mb * padded_oc * c.oh * c.ow
63 + g * padded_oc / c.ng * c.oh * c.ow
64 + oc * c.oh * c.ow + oh * c.ow + ow;
66 g * padded_oc / c.ng * padded_ic
68 + oc * padded_ic / c.ng * c.kh * c.kw
69 + ic * c.kh * c.kw + kh * c.kw + kw;
72 diff_dst_data[map_index(diff_dst_d, didx)]
73 * weights_data[map_index(weights_d, widx)]);
78 diff_src_data[map_index(diff_src_d, sidx)] = (data_t_diff_src)a;
82 template <typename data_t_diff_dst, typename data_t_wei,
83 typename data_t_acc, typename data_t_diff_src>
84 class convolution_backward_data_test
85 : public ::testing::TestWithParam<test_convolution_params_t> {
87 virtual void SetUp() {
88 auto p = ::testing::TestWithParam<test_convolution_params_t>::GetParam();
89 catch_expected_failures([=](){Test();}, p.expect_to_fail,
94 auto p = ::testing::TestWithParam<test_convolution_params_t>::GetParam();
95 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
96 ASSERT_EQ(p.aalgorithm, convolution_direct);
97 auto eng = engine(p.engine_kind, 0);
98 auto data_type_diff_src = data_traits<data_t_diff_src>::data_type;
99 auto data_type_diff_dst = data_traits<data_t_diff_dst>::data_type;
100 auto data_type_wei = data_traits<data_t_wei>::data_type;
102 test_convolution_sizes_t cd = p.sizes;
104 auto c_src_desc = create_md({ cd.mb, cd.ic, cd.ih, cd.iw },
105 data_type_diff_src, p.formats.src_format);
106 auto c_weights_desc = cd.ng > 1
107 ? create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
108 data_type_wei, p.formats.weights_format)
109 : create_md({ cd.oc, cd.ic, cd.kh, cd.kw },
110 data_type_wei, p.formats.weights_format);
111 auto c_dst_desc = create_md({ cd.mb, cd.oc, cd.oh, cd.ow },
112 data_type_diff_dst, p.formats.dst_format);
113 auto c_src_desc_f = create_md({ cd.mb, cd.ic, cd.ih, cd.iw },
114 data_type_diff_dst, p.formats.src_format);
115 auto c_dst_desc_f = create_md({ cd.mb, cd.oc, cd.oh, cd.ow },
116 data_type_diff_src, p.formats.dst_format);
118 auto c_diff_src = test_memory(c_src_desc, eng);
119 auto c_weights = test_memory(c_weights_desc, eng);
120 auto c_diff_dst = test_memory(c_dst_desc, eng);
122 std::vector<ptrdiff_t> padR = {
123 right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh),
124 right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)
127 // Only true for dense format
128 fill_data<data_t_wei>(c_weights.get_size() / sizeof(data_t_wei),
129 (data_t_wei *)c_weights.get().get_data_handle());
130 fill_data<data_t_diff_dst>(
131 c_diff_dst.get_size() / sizeof(data_t_diff_dst),
132 (data_t_diff_dst *)c_diff_dst.get().get_data_handle());
133 fill_data<data_t_diff_src>(
134 c_diff_src.get_size() / sizeof(data_t_diff_src),
135 (data_t_diff_src *)c_diff_src.get().get_data_handle());
136 check_zero_tail<data_t_diff_dst>(1, c_diff_dst.get());
137 check_zero_tail<data_t_wei>(1, c_weights.get());
138 check_zero_tail<data_t_diff_src>(1, c_diff_src.get());
140 auto conv_desc = convolution_forward::desc(
141 prop_kind::forward_training, p.aalgorithm, c_src_desc_f,
142 c_weights_desc, c_dst_desc_f,
143 { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
144 { cd.padh, cd.padw }, padR, padding_kind::zero);
145 auto conv_primitive_desc = convolution_forward::primitive_desc(
148 auto conv_bwd_data_desc = convolution_backward_data::desc(
149 p.aalgorithm, c_src_desc, c_weights_desc, c_dst_desc,
150 { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
151 { cd.padh, cd.padw }, padR, padding_kind::zero);
152 auto conv_bwd_data_primitive_desc
153 = convolution_backward_data::primitive_desc(
154 conv_bwd_data_desc, eng, conv_primitive_desc);
155 auto conv_bwd_data = convolution_backward_data(
156 conv_bwd_data_primitive_desc,
157 c_diff_dst.get(), c_weights.get(), c_diff_src.get());
159 std::vector<primitive> pipeline;
160 pipeline.push_back(conv_bwd_data);
161 stream(stream::kind::lazy).submit(pipeline).wait();
163 auto ref_memory = memory(memory::primitive_desc(c_src_desc, eng));
164 compute_ref_conv_bwd_data
165 <data_t_diff_dst, data_t_wei, data_t_acc, data_t_diff_src>(
166 cd, ref_memory, c_weights.get(), c_diff_dst.get());
167 check_zero_tail<data_t_diff_src>(1, ref_memory);
169 compare_data<data_t_diff_src>(ref_memory, c_diff_src.get());
170 check_zero_tail<data_t_diff_src>(0, c_diff_src.get());