Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "math_utils.hpp"
22
23 #include "ref_convolution.hpp"
24
25 namespace mkldnn {
26 namespace impl {
27 namespace cpu {
28
29 using math::saturate;
30 using math::get_bias;
31
32 template <data_type_t src_type, data_type_t wei_type,
33          data_type_t dst_type, data_type_t acc_type>
34 void ref_convolution_fwd_t<src_type, wei_type, dst_type, acc_type>
35         ::execute_forward() const {
36     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
37     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
38     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
39     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
40
41     const memory_desc_wrapper src_d(pd()->src_pd());
42     const memory_desc_wrapper dst_d(pd()->dst_pd());
43     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
44     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
45
46     const bool with_groups = pd()->with_groups();
47
48     const int G = pd()->G();
49     const int MB = pd()->MB();
50     const int OD = pd()->OD();
51     const int OH = pd()->OH();
52     const int OW = pd()->OW();
53     const int ID = pd()->ID();
54     const int IH = pd()->IH();
55     const int IW = pd()->IW();
56
57     const int OC = pd()->OC() / G;
58     const int IC = pd()->IC() / G;
59     const int KD = pd()->KD();
60     const int KH = pd()->KH();
61     const int KW = pd()->KW();
62
63     const int KSD = pd()->KSD();
64     const int KSH = pd()->KSH();
65     const int KSW = pd()->KSW();
66
67     const int KDD = pd()->KDD();
68     const int KDH = pd()->KDH();
69     const int KDW = pd()->KDW();
70
71     const int padFront = pd()->padFront();
72     const int padT = pd()->padT();
73     const int padL = pd()->padL();
74
75     const bool with_relu = 0; // TODO: change if support post_ops
76     const float nslope = 0.f;
77
78     const int ndims = pd()->desc()->src_desc.ndims;
79
80     auto ker = [=](int g, int mb, int oc, int od, int oh,
81             int ow) {
82         acc_data_t d = 0;
83         for (int ic = 0; ic < IC; ++ic)
84         for (int kd = 0; kd < KD; ++kd)
85         for (int kh = 0; kh < KH; ++kh)
86         for (int kw = 0; kw < KW; ++kw) {
87             const int id = od * KSD - padFront + kd * (1 + KDD);
88             const int ih = oh * KSH - padT + kh * (1 + KDH);
89             const int iw = ow * KSW - padL + kw * (1 + KDW);
90
91             if (id < 0 || id >= ID) continue;
92             if (ih < 0 || ih >= IH) continue;
93             if (iw < 0 || iw >= IW) continue;
94
95             if (ndims == 5)
96                 d += (acc_data_t)src[src_d.off(mb, g*IC + ic, id, ih, iw)]
97                     * (with_groups
98                     ? weights[weights_d.off(g, oc, ic, kd, kh, kw)]
99                     : weights[weights_d.off(oc, ic, kd, kh, kw)]);
100             else if (ndims == 4)
101                 d += (acc_data_t)src[src_d.off(mb, g*IC + ic, ih, iw)]
102                     * (with_groups
103                     ? weights[weights_d.off(g, oc, ic, kh, kw)]
104                     : weights[weights_d.off(oc, ic, kh, kw)]);
105             else if (ndims == 3)
106                 d += (acc_data_t)src[src_d.off(mb, g*IC + ic, iw)]
107                     * (with_groups
108                     ? weights[weights_d.off(g, oc, ic, kw)]
109                     : weights[weights_d.off(oc, ic, kw)]);
110            else
111                assert(false);
112
113         }
114         return d;
115     };
116
117     parallel_nd(G, MB, OC, OD, OH, OW,
118         [&](int g, int mb, int oc, int od, int oh, int ow) {
119         float a_fp = ker(g, mb, oc, od, oh, ow);
120
121         if (bias)
122             a_fp += get_bias(bias, bias_d.off(g * OC + oc),
123                              pd()->desc()->bias_desc.data_type);
124
125         if (with_relu && a_fp < 0)
126             a_fp *= nslope;
127
128         if (data_traits<dst_data_t>::data_type != data_type::f32) {
129             switch (pd()->attr()->round_mode_) {
130                 case round_mode::down:    a_fp = floorf(a_fp); break;
131                 case round_mode::nearest: a_fp = nearbyintf(a_fp); break;
132             }
133         }
134
135         if (ndims == 5)
136             dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = saturate<dst_data_t>(a_fp);
137         else if (ndims == 4)
138             dst[dst_d.off(mb, g*OC + oc, oh, ow)] = saturate<dst_data_t>(a_fp);
139         else if (ndims == 3)
140             dst[dst_d.off(mb, g*OC + oc, ow)] = saturate<dst_data_t>(a_fp);
141         else
142             assert(false);
143    });
144 }
145
146 template <data_type_t diff_src_type, data_type_t wei_type,
147          data_type_t diff_dst_type, data_type_t acc_type>
148 void ref_convolution_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
149      acc_type>::execute_backward_data() const {
150     auto diff_dst = reinterpret_cast<const diff_dst_data_t*>(
151             this->input_memory(0));
152     auto weights = reinterpret_cast<const wei_data_t*>(this->input_memory(1));
153     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
154     auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
155
156     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
157     const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
158     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
159     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
160
161     const bool with_groups = pd()->with_groups();
162
163     const int G = pd()->G();
164     const int MB = pd()->MB();
165     const int OD = pd()->OD();
166     const int OH = pd()->OH();
167     const int OW = pd()->OW();
168     const int ID = pd()->ID();
169     const int IH = pd()->IH();
170     const int IW = pd()->IW();
171
172     const int OC = pd()->OC() / G;
173     const int IC = pd()->IC() / G;
174     const int KD = pd()->KD();
175     const int KH = pd()->KH();
176     const int KW = pd()->KW();
177
178     const int KSD = pd()->KSD();
179     const int KSH = pd()->KSH();
180     const int KSW = pd()->KSW();
181
182     const int KDD = pd()->KDD();
183     const int KDH = pd()->KDH();
184     const int KDW = pd()->KDW();
185
186     const int padFront = pd()->padFront();
187     const int padT = pd()->padT();
188     const int padL = pd()->padL();
189
190     const int ndims = pd()->desc()->diff_src_desc.ndims;
191
192     auto ker = [=](int g, int mb, int ic, int id, int ih,
193             int iw) {
194         acc_data_t d = 0;
195         for (int oc = 0; oc < OC; ++oc)
196         for (int kd = 0; kd < KD; ++kd)
197         for (int kh = 0; kh < KH; ++kh)
198         for (int kw = 0; kw < KW; ++kw) {
199             if (iw + padL < kw * (1 + KDW)
200                 || ih + padT < kh * (1 + KDH)
201                 || id + padFront < kd * (1 + KDD))
202                 continue;
203             int ow = iw - kw * (1 + KDW) + padL;
204             int oh = ih - kh * (1 + KDH) + padT;
205             int od = id - kd * (1 + KDD) + padFront;
206             if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0)
207                 continue;
208
209             ow /= KSW;
210             oh /= KSH;
211             od /= KSD;
212
213             if (od < OD && oh < OH && ow < OW) {
214                 if (ndims == 5)
215                     d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
216                         + oc, od, oh, ow)] * (with_groups
217                         ? weights[weights_d.off(g, oc, ic, kd, kh, kw)]
218                         : weights[weights_d.off(oc, ic, kd, kh, kw)]);
219                 else if (ndims == 4)
220                     d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
221                         + oc, oh, ow)] * (with_groups
222                         ? weights[weights_d.off(g, oc, ic, kh, kw)]
223                         : weights[weights_d.off(oc, ic, kh, kw)]);
224                 else if (ndims == 3)
225                     d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
226                         + oc, ow)] * (with_groups
227                         ? weights[weights_d.off(g, oc, ic, kw)]
228                         : weights[weights_d.off(oc, ic, kw)]);
229                 else
230                     assert(false);
231             }
232         }
233         return d;
234     };
235
236     parallel_nd(G, MB, IC, ID, IH, IW,
237         [&](int g, int mb, int ic, int id, int ih, int iw) {
238         auto ds_idx = (ndims == 5)
239             ? diff_src_d.off(mb, g*IC + ic, id, ih, iw)
240             : (ndims == 4)
241             ? diff_src_d.off(mb, g*IC + ic, ih, iw)
242             : diff_src_d.off(mb, g*IC + ic, iw);
243         float a = bias
244             ? get_bias(bias, bias_d.off(g * IC + ic),
245                     pd()->desc()->bias_desc.data_type)
246             : 0;
247         a += ker(g, mb, ic, id, ih, iw);
248         diff_src[ds_idx] = saturate<diff_src_data_t>(a);
249     });
250 }
251
252 template <data_type_t src_type, data_type_t diff_wei_type,
253          data_type_t diff_dst_type, data_type_t acc_type>
254 void ref_convolution_bwd_weights_t<src_type, diff_wei_type, diff_dst_type,
255      acc_type>::execute_backward_weights() const {
256     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
257     auto diff_dst = reinterpret_cast<const diff_dst_data_t *>(
258             this->input_memory(1));
259     auto diff_weights = reinterpret_cast<diff_wei_data_t*>(this->memory(0));
260     auto diff_bias = reinterpret_cast<diff_wei_data_t *>(this->memory(1));
261
262     const memory_desc_wrapper src_d(pd()->src_pd());
263     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
264     const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
265     const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
266
267     const bool with_groups = pd()->with_groups();
268
269     const int G = pd()->G();
270     const int MB = pd()->MB();
271     const int OD = pd()->OD();
272     const int OH = pd()->OH();
273     const int OW = pd()->OW();
274     const int ID = pd()->ID();
275     const int IH = pd()->IH();
276     const int IW = pd()->IW();
277
278     const int OC = pd()->OC() / G;
279     const int IC = pd()->IC() / G;
280     const int KD = pd()->KD();
281     const int KH = pd()->KH();
282     const int KW = pd()->KW();
283
284     const int KSD = pd()->KSD();
285     const int KSH = pd()->KSH();
286     const int KSW = pd()->KSW();
287
288     const int KDD = pd()->KDD();
289     const int KDH = pd()->KDH();
290     const int KDW = pd()->KDW();
291
292     const int padFront = pd()->padFront();
293     const int padT = pd()->padT();
294     const int padL = pd()->padL();
295
296     const int ndims = pd()->desc()->src_desc.ndims;
297
298 auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) {
299         for (int mb = 0; mb < MB; ++mb)
300         for (int od = 0; od < OD; ++od)
301         for (int oh = 0; oh < OH; ++oh)
302         for (int ow = 0; ow < OW; ++ow) {
303             if (ow*KSW + kw * (1 + KDW) < padL
304                 || oh*KSH + kh * (1 + KDH) < padT
305                 || od*KSD + kd * (1 + KDD) < padFront
306                 || ow*KSW + kw * (1 + KDW) >= IW + padL
307                 || oh*KSH + kh * (1 + KDH) >= IH + padT
308                 || od*KSD + kd * (1 + KDD) >= ID + padFront)
309                 continue;
310
311             int id = od*KSD - padFront + kd * (1 + KDD);
312             int ih = oh*KSH - padT + kh * (1 + KDH);
313             int iw = ow*KSW - padL + kw * (1 + KDW);
314             if (ndims == 5)
315                 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od,
316                     oh, ow)] * src[src_d.off(mb, g*IC + ic, id, ih, iw)];
317             else if (ndims == 4)
318                 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, ow)]
319                     * src[src_d.off(mb, g*IC + ic, ih, iw)];
320             else if (ndims == 3)
321                 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)]
322                     * src[src_d.off(mb, g*IC + ic, iw)];
323             else
324                 assert(false);
325         }
326     };
327
328     auto ker_bias = [=](acc_data_t &d, int g, int oc) {
329         for (int mb = 0; mb < MB; ++mb)
330         for (int od = 0; od < OD; ++od)
331         for (int oh = 0; oh < OH; ++oh)
332         for (int ow = 0; ow < OW; ++ow) {
333             if (ndims == 5)
334                 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, oh,
335                      ow)];
336             else if (ndims == 4)
337                 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh,
338                      ow)];
339             else if (ndims == 3)
340                 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)];
341             else
342                 assert(false);
343         }
344     };
345
346     parallel_nd(G, OC, [&](int g, int oc) {
347         if (diff_bias) {
348             // XXX: loss of precision when bias is a float...
349             acc_data_t db = 0;
350             ker_bias(db, g, oc);
351             diff_bias[diff_bias_d.off(g*OC+oc)]
352                 = saturate<diff_wei_data_t>(db);
353         }
354
355         for (int ic = 0; ic < IC; ++ic)
356         for (int kd = 0; kd < KD; ++kd)
357         for (int kh = 0; kh < KH; ++kh)
358         for (int kw = 0; kw < KW; ++kw) {
359             acc_data_t dw = 0;
360             ker(dw, g, oc, ic, kd, kh, kw);
361
362             if (ndims == 5) {
363                 auto idx = with_groups
364                     ? diff_weights_d.off(g, oc, ic, kd, kh, kw)
365                     : diff_weights_d.off(oc, ic, kd, kh, kw);
366                     diff_weights[idx] = saturate<diff_wei_data_t>(dw);
367             } else if (ndims == 4) {
368                 auto idx = with_groups
369                     ? diff_weights_d.off(g, oc, ic, kh, kw)
370                     : diff_weights_d.off(oc, ic, kh, kw);
371                     diff_weights[idx] = saturate<diff_wei_data_t>(dw);
372             } else if (ndims == 3) {
373                 auto idx = with_groups
374                     ? diff_weights_d.off(g, oc, ic, kw)
375                     : diff_weights_d.off(oc, ic, kw);
376                     diff_weights[idx] = saturate<diff_wei_data_t>(dw);
377             } else {
378                  assert(false);
379             }
380         }
381     });
382 }
383
384 using namespace data_type;
385
386 template struct ref_convolution_fwd_t<f32>;
387 template struct ref_convolution_fwd_t<s16, s16, s32, s32>;
388
389 template struct ref_convolution_fwd_t<u8, s8, f32, s32>;
390 template struct ref_convolution_fwd_t<u8, s8, s32, s32>;
391 template struct ref_convolution_fwd_t<u8, s8, s8, s32>;
392 template struct ref_convolution_fwd_t<u8, s8, u8, s32>;
393
394 template struct ref_convolution_bwd_data_t<f32, f32, f32, f32>;
395 template struct ref_convolution_bwd_data_t<s32, s16, s16, s32>;
396
397 template struct ref_convolution_bwd_data_t<f32, s8, u8, s32>;
398 template struct ref_convolution_bwd_data_t<s32, s8, u8, s32>;
399 template struct ref_convolution_bwd_data_t<s8, s8, u8, s32>;
400 template struct ref_convolution_bwd_data_t<u8, s8, u8, s32>;
401
402 template struct ref_convolution_bwd_weights_t<f32, f32, f32, f32>;
403 template struct ref_convolution_bwd_weights_t<s16, s32, s16, s32>;
404
405 }
406 }
407 }
408
409 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s