Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_convolution_backward_data_common.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef TEST_CONVOLUTION_BACKWARD_DATA_COMMON_H
18 #define TEST_CONVOLUTION_BACKWARD_DATA_COMMON_H
19
20 #include "mkldnn_test_common.hpp"
21 #include "gtest/gtest.h"
22
23 #include "mkldnn.hpp"
24
25 namespace mkldnn {
26
27 template <typename data_t_diff_dst, typename data_t_wei,
28           typename data_t_acc, typename data_t_diff_src>
29 void compute_ref_conv_bwd_data(const test_convolution_sizes_t &c,
30         const memory &diff_src, const memory &weights, const memory &diff_dst)
31 {
32     data_t_diff_dst *diff_dst_data = (data_t_diff_dst *)diff_dst.get_data_handle();
33     data_t_wei *weights_data = (data_t_wei *)weights.get_data_handle();
34     data_t_diff_src *diff_src_data = (data_t_diff_src *)diff_src.get_data_handle();
35
36     const memory::desc diff_src_d = diff_src.get_primitive_desc().desc();
37     const memory::desc weights_d = weights.get_primitive_desc().desc();
38     const memory::desc diff_dst_d = diff_dst.get_primitive_desc().desc();
39
40     size_t padded_ic = diff_src_d.data.layout_desc.blocking.padding_dims[1];
41     size_t padded_oc = diff_dst_d.data.layout_desc.blocking.padding_dims[1];
42
43     mkldnn::impl::parallel_nd(c.mb, c.ng, c.ic / c.ng, c.ih, c.iw,
44         [&](int mb, int g, int ic, int ih, int iw) {
45             size_t sidx = mb * padded_ic * c.ih * c.iw
46                     + g * padded_ic / c.ng * c.ih * c.iw
47                     + ic * c.ih * c.iw + ih * c.iw + iw;
48             data_t_acc a = data_t_acc(0);
49             for (int oc = 0; oc < c.oc / c.ng; oc++) {
50                 for (int kh = 0; kh < c.kh; kh++) {
51                     for (int kw = 0; kw < c.kw; kw++) {
52                         if (iw + c.padw < kw * (1 + c.dilw)
53                            || ih + c.padh < kh * (1 + c.dilh))
54                             continue;
55                         int ow = iw - kw * (1 + c.dilw) + c.padw;
56                         int oh = ih - kh * (1 + c.dilh) + c.padh;
57                         if (ow % c.strw != 0 || oh % c.strh != 0)
58                             continue;
59                         ow /= c.strw;
60                         oh /= c.strh;
61                         if (oh < c.oh && ow < c.ow) {
62                             size_t didx = mb * padded_oc * c.oh * c.ow
63                                 + g * padded_oc / c.ng * c.oh * c.ow
64                                 + oc * c.oh * c.ow + oh * c.ow + ow;
65                             size_t widx =
66                                 g * padded_oc / c.ng * padded_ic
67                                 / c.ng * c.kh * c.kw
68                                 + oc * padded_ic / c.ng * c.kh * c.kw
69                                 + ic * c.kh * c.kw + kh * c.kw + kw;
70
71                             a += (data_t_acc)(
72                                 diff_dst_data[map_index(diff_dst_d, didx)]
73                                 * weights_data[map_index(weights_d, widx)]);
74                         }
75                     }
76                 }
77             }
78             diff_src_data[map_index(diff_src_d, sidx)] = (data_t_diff_src)a;
79     });
80 }
81
82 template <typename data_t_diff_dst, typename data_t_wei,
83           typename data_t_acc, typename data_t_diff_src>
84 class convolution_backward_data_test
85             : public ::testing::TestWithParam<test_convolution_params_t> {
86 protected:
87     virtual void SetUp() {
88         auto p = ::testing::TestWithParam<test_convolution_params_t>::GetParam();
89         catch_expected_failures([=](){Test();}, p.expect_to_fail,
90                     p.expected_status);
91     }
92
93     void Test() {
94         auto p = ::testing::TestWithParam<test_convolution_params_t>::GetParam();
95         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
96         ASSERT_EQ(p.aalgorithm, convolution_direct);
97         auto eng =  engine(p.engine_kind, 0);
98         auto data_type_diff_src = data_traits<data_t_diff_src>::data_type;
99         auto data_type_diff_dst = data_traits<data_t_diff_dst>::data_type;
100         auto data_type_wei = data_traits<data_t_wei>::data_type;
101
102         test_convolution_sizes_t cd = p.sizes;
103
104         auto c_src_desc = create_md({ cd.mb, cd.ic, cd.ih, cd.iw },
105                 data_type_diff_src, p.formats.src_format);
106         auto c_weights_desc = cd.ng > 1
107             ? create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
108                     data_type_wei, p.formats.weights_format)
109             : create_md({ cd.oc, cd.ic, cd.kh, cd.kw },
110                     data_type_wei, p.formats.weights_format);
111         auto c_dst_desc = create_md({ cd.mb, cd.oc, cd.oh, cd.ow },
112                 data_type_diff_dst, p.formats.dst_format);
113         auto c_src_desc_f = create_md({ cd.mb, cd.ic, cd.ih, cd.iw },
114                 data_type_diff_dst, p.formats.src_format);
115         auto c_dst_desc_f = create_md({ cd.mb, cd.oc, cd.oh, cd.ow },
116                 data_type_diff_src, p.formats.dst_format);
117
118         auto c_diff_src = test_memory(c_src_desc, eng);
119         auto c_weights = test_memory(c_weights_desc, eng);
120         auto c_diff_dst = test_memory(c_dst_desc, eng);
121
122         std::vector<ptrdiff_t> padR = {
123             right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh),
124             right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)
125         };
126
127         // Only true for dense format
128         fill_data<data_t_wei>(c_weights.get_size() / sizeof(data_t_wei),
129                 (data_t_wei *)c_weights.get().get_data_handle());
130         fill_data<data_t_diff_dst>(
131                 c_diff_dst.get_size() / sizeof(data_t_diff_dst),
132                 (data_t_diff_dst *)c_diff_dst.get().get_data_handle());
133         fill_data<data_t_diff_src>(
134                 c_diff_src.get_size() / sizeof(data_t_diff_src),
135                 (data_t_diff_src *)c_diff_src.get().get_data_handle());
136         check_zero_tail<data_t_diff_dst>(1, c_diff_dst.get());
137         check_zero_tail<data_t_wei>(1, c_weights.get());
138         check_zero_tail<data_t_diff_src>(1, c_diff_src.get());
139
140         auto conv_desc = convolution_forward::desc(
141                 prop_kind::forward_training, p.aalgorithm, c_src_desc_f,
142                 c_weights_desc, c_dst_desc_f,
143                 { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
144                 { cd.padh, cd.padw }, padR, padding_kind::zero);
145         auto conv_primitive_desc = convolution_forward::primitive_desc(
146                 conv_desc, eng);
147
148         auto conv_bwd_data_desc = convolution_backward_data::desc(
149                 p.aalgorithm, c_src_desc, c_weights_desc, c_dst_desc,
150                 { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
151                 { cd.padh, cd.padw }, padR, padding_kind::zero);
152         auto conv_bwd_data_primitive_desc
153             = convolution_backward_data::primitive_desc(
154                     conv_bwd_data_desc, eng, conv_primitive_desc);
155         auto conv_bwd_data = convolution_backward_data(
156                 conv_bwd_data_primitive_desc,
157                 c_diff_dst.get(), c_weights.get(), c_diff_src.get());
158
159         std::vector<primitive> pipeline;
160         pipeline.push_back(conv_bwd_data);
161         stream(stream::kind::lazy).submit(pipeline).wait();
162
163         auto ref_memory = memory(memory::primitive_desc(c_src_desc, eng));
164         compute_ref_conv_bwd_data
165             <data_t_diff_dst, data_t_wei, data_t_acc, data_t_diff_src>(
166                     cd, ref_memory, c_weights.get(), c_diff_dst.get());
167         check_zero_tail<data_t_diff_src>(1, ref_memory);
168
169         compare_data<data_t_diff_src>(ref_memory, c_diff_src.get());
170         check_zero_tail<data_t_diff_src>(0, c_diff_src.get());
171     }
172 };
173
174 }
175 #endif