Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_inner_product.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "math_utils.hpp"
22
23 #include "ref_inner_product.hpp"
24
25 namespace mkldnn {
26 namespace impl {
27 namespace cpu {
28
29 using math::saturate;
30 using math::get_bias;
31
32 template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type,
33          data_type_t acc_type>
34 void ref_inner_product_fwd_t<src_type, wei_type, dst_type, acc_type>
35         ::execute_forward() const {
36     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
37     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
38     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
39     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
40
41     const memory_desc_wrapper src_d(pd()->src_pd());
42     const memory_desc_wrapper dst_d(pd()->dst_pd());
43     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
44     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
45
46     const int MB = pd()->MB();
47     const int OC = pd()->OC();
48     const int IC = pd()->IC();
49
50     const bool src_has_spatial = utils::one_of(src_d.ndims(), 4, 5);
51
52     const bool is_3d = src_d.ndims() == 5;
53
54     const auto &post_ops = pd()->attr()->post_ops_;
55     const bool do_relu = post_ops.len_ == 1;
56     const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;
57
58     auto ker_has_spatial = [=](int mb, int oc) {
59         acc_data_t d = 0;
60         const int KD = pd()->KD();
61         const int KH = pd()->KH();
62         const int KW = pd()->KW();
63         for (int ic = 0; ic < IC; ++ic) {
64             for (int kd = 0; kd < KD; ++kd) {
65                 for (int kh = 0; kh < KH; ++kh) {
66                     for (int kw = 0; kw < KW; ++kw) {
67                         if (is_3d)
68                             d += (acc_data_t)src[src_d.off(mb, ic, kd, kh, kw)]
69                                 * weights[weights_d.off(oc, ic, kd, kh, kw)];
70                         else
71                             d += (acc_data_t)src[src_d.off(mb, ic, kh, kw)]
72                                 * weights[weights_d.off(oc, ic, kh, kw)];
73                     }
74                 }
75             }
76         }
77         return d;
78     };
79
80     auto ker_no_spatial = [=](int mb, int oc) {
81         acc_data_t d = 0;
82         for (int ic = 0; ic < IC; ++ic) {
83             d += (acc_data_t)src[src_d.off(mb, ic)]
84                 * weights[weights_d.off(oc, ic)];
85         }
86         return d;
87     };
88
89     parallel_nd(MB, OC, [&](int mb, int oc) {
90         float a = bias
91             ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type)
92             : 0;
93         if (src_has_spatial)
94             a += ker_has_spatial(mb, oc);
95         else
96             a += ker_no_spatial(mb, oc);
97         if (do_relu && a < (acc_data_t)0)
98             a *= nslope;
99         dst[dst_d.off(mb, oc)] = saturate<dst_data_t>(a);
100     });
101 }
102 using namespace data_type;
103 template struct ref_inner_product_fwd_t<f32>;
104 template struct ref_inner_product_fwd_t<s16, s16, s32, s32>;
105 template struct ref_inner_product_fwd_t<u8, s8, f32, s32>;
106 template struct ref_inner_product_fwd_t<u8, s8, s32, s32>;
107 template struct ref_inner_product_fwd_t<u8, s8, s8, s32>;
108 template struct ref_inner_product_fwd_t<u8, s8, u8, s32>;
109
110 template <data_type_t diff_src_type, data_type_t wei_type,
111          data_type_t diff_dst_type, data_type_t acc_type>
112 void ref_inner_product_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
113      acc_type>::execute_backward_data() const {
114     auto diff_dst = reinterpret_cast<const diff_dst_data_t *>(
115             this->input_memory(0));
116     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
117     auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
118
119     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
120     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
121     const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
122
123     const int MB = pd()->MB();
124     const int OC = pd()->OC();
125     const int IC = pd()->IC();
126
127     const bool diff_src_has_spatial = utils::one_of(diff_src_d.ndims(), 4, 5);
128
129     const bool is_3d = diff_src_d.ndims() == 5;
130
131     parallel_nd(MB, IC, [&](int mb, int ic) {
132         if (diff_src_has_spatial) {
133             const int KD = pd()->KD();
134             const int KH = pd()->KH();
135             const int KW = pd()->KW();
136             for (int kd = 0; kd < KD; ++kd)
137             for (int kh = 0; kh < KH; ++kh)
138             for (int kw = 0; kw < KW; ++kw) {
139                 acc_data_t ds = acc_data_t(0);
140                 for (int oc = 0; oc < OC; ++oc) {
141                     if (is_3d)
142                         ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
143                             * weights[weights_d.off(oc, ic, kd, kh, kw)]);
144                     else
145                         ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
146                             * weights[weights_d.off(oc, ic, kh, kw)]);
147                 }
148                 if (is_3d) diff_src[diff_src_d.off(mb, ic, kd, kh, kw)] =
149                     (diff_src_data_t)ds;
150                 else diff_src[diff_src_d.off(mb, ic, kh, kw)] =
151                     (diff_src_data_t)ds;
152             }
153         } else {
154             acc_data_t ds = acc_data_t(0);
155             for (int oc = 0; oc < OC; ++oc) {
156                 ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] *
157                     weights[weights_d.off(oc, ic)]);
158             }
159             diff_src[diff_src_d.off(mb, ic)] = (diff_src_data_t)ds;
160         }
161     });
162 }
163
164 template struct ref_inner_product_bwd_data_t<f32, f32, f32, f32>;
165 template struct ref_inner_product_bwd_data_t<s32, s16, s16, s32>;
166
167 template <impl::data_type_t data_type>
168 void ref_inner_product_bwd_weights_t<data_type>::execute_backward_weights() const {
169     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
170     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
171     auto diff_weights = reinterpret_cast<data_t*>(this->memory(0));
172     auto diff_bias = reinterpret_cast<data_t*>(this->memory(1));
173
174     const memory_desc_wrapper src_d(pd()->src_pd());
175     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
176     const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
177     const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
178
179     const int MB = pd()->MB();
180     const int OC = pd()->OC();
181     const int IC = pd()->IC();
182
183     const bool src_has_spatial = utils::one_of(src_d.ndims(), 4 ,5);
184
185     const bool is_3d = src_d.ndims() == 5;
186
187     parallel_nd(OC, IC, [&](int oc, int ic) {
188         if (src_has_spatial) {
189             const int KD = pd()->KD();
190             const int KH = pd()->KH();
191             const int KW = pd()->KW();
192             for (int kd = 0; kd < KD; ++kd) {
193                 for (int kh = 0; kh < KH; ++kh) {
194                     for (int kw = 0; kw < KW; ++kw) {
195                         data_t *dw = is_3d
196                             ? &diff_weights[
197                             diff_weights_d.off(oc, ic, kd, kh, kw)]
198                             : &diff_weights[
199                             diff_weights_d.off(oc, ic, kh, kw)];
200                         *dw = data_t(0);
201                         for (int mb = 0; mb < MB; ++mb) {
202                             if (is_3d)
203                                 *dw += diff_dst[diff_dst_d.off(mb, oc)] *
204                                     src[src_d.off(mb, ic, kd, kh, kw)];
205                             else
206                                 *dw += diff_dst[diff_dst_d.off(mb, oc)] *
207                                     src[src_d.off(mb, ic, kh, kw)];
208                         }
209                     }
210                 }
211             }
212         } else {
213             data_t *dw = &diff_weights[diff_weights_d.off(oc, ic)];
214             *dw = data_t(0);
215             for (int mb = 0; mb < MB; ++mb) {
216                 *dw += diff_dst[diff_dst_d.off(mb, oc)] *
217                     src[src_d.off(mb, ic)];
218             }
219         }
220     });
221
222     if (diff_bias) {
223         diff_bias += diff_bias_d.blocking_desc().offset_padding;
224
225         parallel_nd(OC, [&](int oc) {
226             data_t *db = &diff_bias[oc];
227             *db = data_t(0);
228             for (int mb = 0; mb < MB; ++mb)
229                 *db += diff_dst[diff_dst_d.off(mb, oc)];
230         });
231     }
232 }
233
234 template struct ref_inner_product_bwd_weights_t<data_type::f32>;
235
236 }
237 }
238 }
239
240 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s