updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_inner_product.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "math_utils.hpp"
22
23 #include "ref_inner_product.hpp"
24
25 namespace mkldnn {
26 namespace impl {
27 namespace cpu {
28
29 using math::saturate;
30 using math::get_bias;
31
32 template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type,
33          data_type_t acc_type>
34 void ref_inner_product_fwd_t<src_type, wei_type, dst_type, acc_type>
35         ::execute_forward() const {
36     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
37     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
38     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
39     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
40
41     const memory_desc_wrapper src_d(pd()->src_pd());
42     const memory_desc_wrapper dst_d(pd()->dst_pd());
43     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
44     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
45
46     const int MB = pd()->MB();
47     const int OC = pd()->OC();
48     const int IC = pd()->IC();
49
50     const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5);
51     const int ndims = src_d.ndims() - 2;
52
53     const auto &post_ops = pd()->attr()->post_ops_;
54     const bool do_relu = post_ops.len_ == 1;
55     const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;
56
57     auto ker_has_spatial = [=](int mb, int oc) {
58         acc_data_t d = 0;
59         const int KD = pd()->KD();
60         const int KH = pd()->KH();
61         const int KW = pd()->KW();
62         for (int ic = 0; ic < IC; ++ic) {
63             for (int kd = 0; kd < KD; ++kd) {
64                 for (int kh = 0; kh < KH; ++kh) {
65                     for (int kw = 0; kw < KW; ++kw) {
66                         switch (ndims) {
67                         case 3:
68                             d += (acc_data_t)src[src_d.off(mb, ic, kd, kh, kw)]
69                                     * weights[weights_d.off(
70                                               oc, ic, kd, kh, kw)];
71                             break;
72                         case 2:
73                             d += (acc_data_t)src[src_d.off(mb, ic, kh, kw)]
74                                     * weights[weights_d.off(oc, ic, kh, kw)];
75                             break;
76                         case 1:
77                             d += (acc_data_t)src[src_d.off(mb, ic, kw)]
78                                     * weights[weights_d.off(oc, ic, kw)];
79                             break;
80                         default: assert(!"unsupported ndims size");
81                         }
82                     }
83                 }
84             }
85         }
86         return d;
87     };
88
89     auto ker_no_spatial = [=](int mb, int oc) {
90         acc_data_t d = 0;
91         for (int ic = 0; ic < IC; ++ic) {
92             d += (acc_data_t)src[src_d.off(mb, ic)]
93                 * weights[weights_d.off(oc, ic)];
94         }
95         return d;
96     };
97
98     parallel_nd(MB, OC, [&](int mb, int oc) {
99         float a = bias
100             ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type)
101             : 0;
102         if (src_has_spatial)
103             a += ker_has_spatial(mb, oc);
104         else
105             a += ker_no_spatial(mb, oc);
106         if (do_relu && a < (acc_data_t)0)
107             a *= nslope;
108         dst[dst_d.off(mb, oc)] = saturate<dst_data_t>(a);
109     });
110 }
111 using namespace data_type;
112 template struct ref_inner_product_fwd_t<f32>;
113 template struct ref_inner_product_fwd_t<s16, s16, s32, s32>;
114 template struct ref_inner_product_fwd_t<u8, s8, f32, s32>;
115 template struct ref_inner_product_fwd_t<u8, s8, s32, s32>;
116 template struct ref_inner_product_fwd_t<u8, s8, s8, s32>;
117 template struct ref_inner_product_fwd_t<u8, s8, u8, s32>;
118
119 template <data_type_t diff_src_type, data_type_t wei_type,
120          data_type_t diff_dst_type, data_type_t acc_type>
121 void ref_inner_product_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
122      acc_type>::execute_backward_data() const {
123     auto diff_dst = reinterpret_cast<const diff_dst_data_t *>(
124             this->input_memory(0));
125     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
126     auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
127
128     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
129     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
130     const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
131
132     const int MB = pd()->MB();
133     const int OC = pd()->OC();
134     const int IC = pd()->IC();
135
136     const bool diff_src_has_spatial
137             = utils::one_of(diff_src_d.ndims(), 3, 4, 5);
138     const int ndims = diff_src_d.ndims() - 2;
139
140     parallel_nd(MB, IC, [&](int mb, int ic) {
141         if (diff_src_has_spatial) {
142             const int KD = pd()->KD();
143             const int KH = pd()->KH();
144             const int KW = pd()->KW();
145             for (int kd = 0; kd < KD; ++kd)
146             for (int kh = 0; kh < KH; ++kh)
147             for (int kw = 0; kw < KW; ++kw) {
148                 acc_data_t ds = acc_data_t(0);
149                 for (int oc = 0; oc < OC; ++oc) {
150                     switch (ndims) {
151                     case 3:
152                         ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
153                                 * weights[weights_d.off(oc, ic, kd, kh, kw)]);
154                         break;
155                     case 2:
156                         ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
157                                 * weights[weights_d.off(oc, ic, kh, kw)]);
158                         break;
159                     case 1:
160                         ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
161                                 * weights[weights_d.off(oc, ic, kw)]);
162                         break;
163                     default: assert(!"unsupported ndims size");
164                     }
165                 }
166                 switch (ndims) {
167                 case 3:
168                     diff_src[diff_src_d.off(mb, ic, kd, kh, kw)]
169                             = (diff_src_data_t)ds;
170                     break;
171                 case 2:
172                     diff_src[diff_src_d.off(mb, ic, kh, kw)]
173                             = (diff_src_data_t)ds;
174                     break;
175                 case 1:
176                     diff_src[diff_src_d.off(mb, ic, kw)] = (diff_src_data_t)ds;
177                     break;
178                 default: assert(!"unsupported ndims size");
179                 }
180             }
181         } else {
182             acc_data_t ds = acc_data_t(0);
183             for (int oc = 0; oc < OC; ++oc) {
184                 ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] *
185                     weights[weights_d.off(oc, ic)]);
186             }
187             diff_src[diff_src_d.off(mb, ic)] = (diff_src_data_t)ds;
188         }
189     });
190 }
191
192 template struct ref_inner_product_bwd_data_t<f32, f32, f32, f32>;
193 template struct ref_inner_product_bwd_data_t<s32, s16, s16, s32>;
194
195 template <impl::data_type_t data_type>
196 void ref_inner_product_bwd_weights_t<data_type>::execute_backward_weights() const {
197     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
198     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
199     auto diff_weights = reinterpret_cast<data_t*>(this->memory(0));
200     auto diff_bias = reinterpret_cast<data_t*>(this->memory(1));
201
202     const memory_desc_wrapper src_d(pd()->src_pd());
203     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
204     const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
205     const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
206
207     const int MB = pd()->MB();
208     const int OC = pd()->OC();
209     const int IC = pd()->IC();
210
211     const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4 ,5);
212     const int ndims = src_d.ndims() - 2;
213
214     parallel_nd(OC, IC, [&](int oc, int ic) {
215         if (src_has_spatial) {
216             const int KD = pd()->KD();
217             const int KH = pd()->KH();
218             const int KW = pd()->KW();
219             for (int kd = 0; kd < KD; ++kd) {
220                 for (int kh = 0; kh < KH; ++kh) {
221                     for (int kw = 0; kw < KW; ++kw) {
222                         data_t *dw(nullptr);
223                         switch (ndims) {
224                         case 3:
225                             dw = &diff_weights[diff_weights_d.off(
226                                     oc, ic, kd, kh, kw)];
227                             break;
228                         case 2:
229                             dw = &diff_weights[diff_weights_d.off(
230                                     oc, ic, kh, kw)];
231                             break;
232                         case 1:
233                             dw = &diff_weights[diff_weights_d.off(oc, ic, kw)];
234                             break;
235                         default: assert(!"unsupported ndims size");
236                         }
237                         *dw = data_t(0);
238                         for (int mb = 0; mb < MB; ++mb) {
239                             switch (ndims) {
240                             case 3:
241                                 *dw += diff_dst[diff_dst_d.off(mb, oc)]
242                                         * src[src_d.off(mb, ic, kd, kh, kw)];
243                                 break;
244                             case 2:
245                                 *dw += diff_dst[diff_dst_d.off(mb, oc)]
246                                         * src[src_d.off(mb, ic, kh, kw)];
247                                 break;
248                             case 1:
249                                 *dw += diff_dst[diff_dst_d.off(mb, oc)]
250                                         * src[src_d.off(mb, ic, kw)];
251                                 break;
252                             default: assert(!"unsupported ndims size");
253                             }
254                         }
255                     }
256                 }
257             }
258         } else {
259             data_t *dw = &diff_weights[diff_weights_d.off(oc, ic)];
260             *dw = data_t(0);
261             for (int mb = 0; mb < MB; ++mb) {
262                 *dw += diff_dst[diff_dst_d.off(mb, oc)] *
263                     src[src_d.off(mb, ic)];
264             }
265         }
266     });
267
268     if (diff_bias) {
269         diff_bias += diff_bias_d.blocking_desc().offset_padding;
270
271         parallel_nd(OC, [&](int oc) {
272             data_t *db = &diff_bias[oc];
273             *db = data_t(0);
274             for (int mb = 0; mb < MB; ++mb)
275                 *db += diff_dst[diff_dst_d.off(mb, oc)];
276         });
277     }
278 }
279
280 template struct ref_inner_product_bwd_weights_t<data_type::f32>;
281
282 }
283 }
284 }
285
286 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s