1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "math_utils.hpp"
23 #include "ref_convolution.hpp"
// Reference forward convolution.
// For each output point the visible code accumulates src * weights over the
// per-group input channels and the kernel taps, adds the value returned by
// get_bias(), optionally rounds, and stores the saturated result into dst.
// NOTE(review): several lines of this definition are not present in this view
// (the tail of the lambda signatures, the accumulator declaration, the
// conditions selecting between the 5-/4-/3-coordinate forms, closing braces).
// The comments below describe only what the visible lines establish.
32 template <data_type_t src_type, data_type_t wei_type,
33 data_type_t dst_type, data_type_t acc_type>
34 void ref_convolution_fwd_t<src_type, wei_type, dst_type, acc_type>
35 ::execute_forward() const {
// Inputs 0/1/2 are src, weights and bias; the primitive's output is dst.
// bias is kept as a byte pointer: its element type is passed to get_bias()
// from bias_desc.data_type at the point of use.
36 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
37 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
38 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
39 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
// Descriptor wrappers; their off() results are used below as element indices.
41 const memory_desc_wrapper src_d(pd()->src_pd());
42 const memory_desc_wrapper dst_d(pd()->dst_pd());
43 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
44 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
// When true, weights are addressed with a leading group coordinate.
46 const bool with_groups = pd()->with_groups();
// Problem extents as reported by the primitive descriptor.
48 const int G = pd()->G();
49 const int MB = pd()->MB();
50 const int OD = pd()->OD();
51 const int OH = pd()->OH();
52 const int OW = pd()->OW();
53 const int ID = pd()->ID();
54 const int IH = pd()->IH();
55 const int IW = pd()->IW();
// OC and IC are per-group channel counts (descriptor totals divided by G).
57 const int OC = pd()->OC() / G;
58 const int IC = pd()->IC() / G;
59 const int KD = pd()->KD();
60 const int KH = pd()->KH();
61 const int KW = pd()->KW();
// KS* multiply the output coordinate when computing input positions below.
63 const int KSD = pd()->KSD();
64 const int KSH = pd()->KSH();
65 const int KSW = pd()->KSW();
// (1 + KD*) multiplies the kernel coordinate when computing input positions.
67 const int KDD = pd()->KDD();
68 const int KDH = pd()->KDH();
69 const int KDW = pd()->KDW();
// Leading padding per spatial axis; subtracted when mapping output -> input.
71 const int padFront = pd()->padFront();
72 const int padT = pd()->padT();
73 const int padL = pd()->padL();
// with_relu is hard-wired to false, so the `with_relu && a_fp < 0` test
// further down can never be true.
75 const bool with_relu = 0; // TODO: change if support post_ops
76 const float nslope = 0.f;
// Number of dimensions of the source tensor.
// NOTE(review): presumably selects between the 5-/4-/3-coordinate off() calls
// below; the selecting conditions are not visible here.
78 const int ndims = pd()->desc()->src_desc.ndims;
// ker: accumulates the products for one output point. Captures by value.
80 auto ker = [=](int g, int mb, int oc, int od, int oh,
// Walk every per-group input channel and every kernel tap.
83 for (int ic = 0; ic < IC; ++ic)
84 for (int kd = 0; kd < KD; ++kd)
85 for (int kh = 0; kh < KH; ++kh)
86 for (int kw = 0; kw < KW; ++kw) {
// Input coordinate addressed by this tap for the current output point.
87 const int id = od * KSD - padFront + kd * (1 + KDD);
88 const int ih = oh * KSH - padT + kh * (1 + KDH);
89 const int iw = ow * KSW - padL + kw * (1 + KDW);
// Taps that land outside the input extent are skipped.
91 if (id < 0 || id >= ID) continue;
92 if (ih < 0 || ih >= IH) continue;
93 if (iw < 0 || iw >= IW) continue;
// Three accumulation forms follow, using 5-, 4- and 3-coordinate offsets.
// Each picks the grouped or ungrouped weights offset via a conditional.
96 d += (acc_data_t)src[src_d.off(mb, g*IC + ic, id, ih, iw)]
98 ? weights[weights_d.off(g, oc, ic, kd, kh, kw)]
99 : weights[weights_d.off(oc, ic, kd, kh, kw)]);
101 d += (acc_data_t)src[src_d.off(mb, g*IC + ic, ih, iw)]
103 ? weights[weights_d.off(g, oc, ic, kh, kw)]
104 : weights[weights_d.off(oc, ic, kh, kw)]);
106 d += (acc_data_t)src[src_d.off(mb, g*IC + ic, iw)]
108 ? weights[weights_d.off(g, oc, ic, kw)]
109 : weights[weights_d.off(oc, ic, kw)]);
// Driver: parallel_nd is given the six output extents and a per-point lambda.
117 parallel_nd(G, MB, OC, OD, OH, OW,
118 [&](int g, int mb, int oc, int od, int oh, int ow) {
// The accumulator result is converted to float for the remaining steps.
119 float a_fp = ker(g, mb, oc, od, oh, ow);
// Bias is indexed by absolute output channel (g * OC + oc).
// NOTE(review): any guard around this statement is not visible in this view.
122 a_fp += get_bias(bias, bias_d.off(g * OC + oc),
123 pd()->desc()->bias_desc.data_type);
125 if (with_relu && a_fp < 0)
// Rounding is applied only when the destination type is not f32; the mode
// comes from the primitive attributes.
128 if (data_traits<dst_data_t>::data_type != data_type::f32) {
129 switch (pd()->attr()->round_mode_) {
130 case round_mode::down: a_fp = floorf(a_fp); break;
131 case round_mode::nearest: a_fp = nearbyintf(a_fp); break;
// Store the saturated value using the 5-, 4- or 3-coordinate dst offset.
136 dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = saturate<dst_data_t>(a_fp);
138 dst[dst_d.off(mb, g*OC + oc, oh, ow)] = saturate<dst_data_t>(a_fp);
140 dst[dst_d.off(mb, g*OC + oc, ow)] = saturate<dst_data_t>(a_fp);
// Reference backward-by-data convolution.
// For each diff_src point the visible code accumulates diff_dst * weights
// over the per-group output channels and the kernel taps, adds that sum to a
// start value `a`, and stores the saturated result into diff_src.
// NOTE(review): several lines of this definition are not present in this view
// (lambda signature tail, accumulator declaration, the statements executed
// when the guard conditions hold, the declaration of `a`, closing braces).
// The comments below describe only what the visible lines establish.
146 template <data_type_t diff_src_type, data_type_t wei_type,
147 data_type_t diff_dst_type, data_type_t acc_type>
148 void ref_convolution_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
149 acc_type>::execute_backward_data() const {
// Inputs 0/1/2 are diff_dst, weights and bias; the output is diff_src.
// bias is kept as a byte pointer; its element type is supplied to get_bias().
150 auto diff_dst = reinterpret_cast<const diff_dst_data_t*>(
151 this->input_memory(0));
152 auto weights = reinterpret_cast<const wei_data_t*>(this->input_memory(1));
153 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
154 auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
// Descriptor wrappers; their off() results are used below as element indices.
156 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
157 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
158 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
159 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
// When true, weights are addressed with a leading group coordinate.
161 const bool with_groups = pd()->with_groups();
// Problem extents as reported by the primitive descriptor.
163 const int G = pd()->G();
164 const int MB = pd()->MB();
165 const int OD = pd()->OD();
166 const int OH = pd()->OH();
167 const int OW = pd()->OW();
168 const int ID = pd()->ID();
169 const int IH = pd()->IH();
170 const int IW = pd()->IW();
// OC and IC are per-group channel counts (descriptor totals divided by G).
172 const int OC = pd()->OC() / G;
173 const int IC = pd()->IC() / G;
174 const int KD = pd()->KD();
175 const int KH = pd()->KH();
176 const int KW = pd()->KW();
// KS* are used below as divisors in the `% KS* != 0` test.
178 const int KSD = pd()->KSD();
179 const int KSH = pd()->KSH();
180 const int KSW = pd()->KSW();
// (1 + KD*) multiplies the kernel coordinate when mapping input -> output.
182 const int KDD = pd()->KDD();
183 const int KDH = pd()->KDH();
184 const int KDW = pd()->KDW();
// Leading padding per spatial axis; added when mapping input -> output.
186 const int padFront = pd()->padFront();
187 const int padT = pd()->padT();
188 const int padL = pd()->padL();
// Number of dimensions of the diff_src tensor; chooses the off() call shape.
190 const int ndims = pd()->desc()->diff_src_desc.ndims;
// ker: accumulates the products for one diff_src point. Captures by value.
192 auto ker = [=](int g, int mb, int ic, int id, int ih,
// Walk every per-group output channel and every kernel tap.
195 for (int oc = 0; oc < OC; ++oc)
196 for (int kd = 0; kd < KD; ++kd)
197 for (int kh = 0; kh < KH; ++kh)
198 for (int kw = 0; kw < KW; ++kw) {
// True when the padded input coordinate is smaller than the dilated tap
// offset, i.e. when the subtractions below would yield a negative value.
// NOTE(review): the statement executed in that case is not visible
// (presumably the tap is skipped).
199 if (iw + padL < kw * (1 + KDW)
200 || ih + padT < kh * (1 + KDH)
201 || id + padFront < kd * (1 + KDD))
// Padded input coordinate minus the dilated tap offset.
203 int ow = iw - kw * (1 + KDW) + padL;
204 int oh = ih - kh * (1 + KDH) + padT;
205 int od = id - kd * (1 + KDD) + padFront;
// True when any coordinate is not a multiple of its KS* value.
// NOTE(review): the statement executed in that case is not visible, nor is
// any subsequent rescaling of ow/oh/od before the range check below.
206 if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0)
// Only positions inside the diff_dst extent contribute.
213 if (od < OD && oh < OH && ow < OW) {
// Three accumulation forms follow, using 5-, 4- and 3-coordinate offsets;
// each picks the grouped or ungrouped weights offset via with_groups.
215 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
216 + oc, od, oh, ow)] * (with_groups
217 ? weights[weights_d.off(g, oc, ic, kd, kh, kw)]
218 : weights[weights_d.off(oc, ic, kd, kh, kw)]);
220 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
221 + oc, oh, ow)] * (with_groups
222 ? weights[weights_d.off(g, oc, ic, kh, kw)]
223 : weights[weights_d.off(oc, ic, kh, kw)]);
225 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
226 + oc, ow)] * (with_groups
227 ? weights[weights_d.off(g, oc, ic, kw)]
228 : weights[weights_d.off(oc, ic, kw)]);
// Driver: parallel_nd is given the six diff_src extents and a per-point lambda.
236 parallel_nd(G, MB, IC, ID, IH, IW,
237 [&](int g, int mb, int ic, int id, int ih, int iw) {
// Destination index, computed with the offset form matching ndims.
238 auto ds_idx = (ndims == 5)
239 ? diff_src_d.off(mb, g*IC + ic, id, ih, iw)
241 ? diff_src_d.off(mb, g*IC + ic, ih, iw)
242 : diff_src_d.off(mb, g*IC + ic, iw);
// Note that bias is indexed by absolute *input* channel (g * IC + ic) here.
244 ? get_bias(bias, bias_d.off(g * IC + ic),
245 pd()->desc()->bias_desc.data_type)
// Add the accumulated sum and store the saturated result.
247 a += ker(g, mb, ic, id, ih, iw);
248 diff_src[ds_idx] = saturate<diff_src_data_t>(a);
// Reference backward-by-weights convolution.
// For each (g, oc) the visible code stores a saturated value into diff_bias,
// then for each (ic, kd, kh, kw) calls ker() and stores a saturated value
// into diff_weights.
// NOTE(review): several lines of this definition are not present in this view
// (the statements executed when the bounds guard holds, the `db` and `dw`
// declarations, the call to ker_bias, the first `ndims` branch, closing
// braces). The comments below describe only what the visible lines establish.
252 template <data_type_t src_type, data_type_t diff_wei_type,
253 data_type_t diff_dst_type, data_type_t acc_type>
254 void ref_convolution_bwd_weights_t<src_type, diff_wei_type, diff_dst_type,
255 acc_type>::execute_backward_weights() const {
// Inputs 0/1 are src and diff_dst; outputs 0/1 are diff_weights and diff_bias.
// diff_bias is addressed with the same element type as diff_weights.
256 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
257 auto diff_dst = reinterpret_cast<const diff_dst_data_t *>(
258 this->input_memory(1));
259 auto diff_weights = reinterpret_cast<diff_wei_data_t*>(this->memory(0));
260 auto diff_bias = reinterpret_cast<diff_wei_data_t *>(this->memory(1));
// Descriptor wrappers; their off() results are used below as element indices.
262 const memory_desc_wrapper src_d(pd()->src_pd());
263 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
264 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
265 const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
// When true, diff_weights is addressed with a leading group coordinate.
267 const bool with_groups = pd()->with_groups();
// Problem extents as reported by the primitive descriptor.
269 const int G = pd()->G();
270 const int MB = pd()->MB();
271 const int OD = pd()->OD();
272 const int OH = pd()->OH();
273 const int OW = pd()->OW();
274 const int ID = pd()->ID();
275 const int IH = pd()->IH();
276 const int IW = pd()->IW();
// OC and IC are per-group channel counts (descriptor totals divided by G).
278 const int OC = pd()->OC() / G;
279 const int IC = pd()->IC() / G;
280 const int KD = pd()->KD();
281 const int KH = pd()->KH();
282 const int KW = pd()->KW();
// KS* multiply the output coordinate when computing input positions below.
284 const int KSD = pd()->KSD();
285 const int KSH = pd()->KSH();
286 const int KSW = pd()->KSW();
// (1 + KD*) multiplies the kernel coordinate when computing input positions.
288 const int KDD = pd()->KDD();
289 const int KDH = pd()->KDH();
290 const int KDW = pd()->KDW();
// Leading padding per spatial axis.
292 const int padFront = pd()->padFront();
293 const int padT = pd()->padT();
294 const int padL = pd()->padL();
// Number of dimensions of the source tensor; chooses the off() call shape.
296 const int ndims = pd()->desc()->src_desc.ndims;
// ker: adds diff_dst * src products for one weight element into `d`, which
// is passed by reference. Iterates over the minibatch and all output points.
298 auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) {
299 for (int mb = 0; mb < MB; ++mb)
300 for (int od = 0; od < OD; ++od)
301 for (int oh = 0; oh < OH; ++oh)
302 for (int ow = 0; ow < OW; ++ow) {
// Bounds test on the padded input coordinate: true when it is below the
// padding or at/after extent + padding, i.e. when id/ih/iw computed below
// would fall outside [0, extent).
// NOTE(review): the statement executed in that case is not visible
// (presumably the point is skipped).
303 if (ow*KSW + kw * (1 + KDW) < padL
304 || oh*KSH + kh * (1 + KDH) < padT
305 || od*KSD + kd * (1 + KDD) < padFront
306 || ow*KSW + kw * (1 + KDW) >= IW + padL
307 || oh*KSH + kh * (1 + KDH) >= IH + padT
308 || od*KSD + kd * (1 + KDD) >= ID + padFront)
// Input coordinate paired with this output point and kernel tap.
311 int id = od*KSD - padFront + kd * (1 + KDD);
312 int ih = oh*KSH - padT + kh * (1 + KDH);
313 int iw = ow*KSW - padL + kw * (1 + KDW);
// Three accumulation forms follow, using 5-, 4- and 3-coordinate offsets.
315 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od,
316 oh, ow)] * src[src_d.off(mb, g*IC + ic, id, ih, iw)];
318 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, ow)]
319 * src[src_d.off(mb, g*IC + ic, ih, iw)];
321 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)]
322 * src[src_d.off(mb, g*IC + ic, iw)];
// ker_bias: adds diff_dst values for one output channel into `d`, iterating
// over the minibatch and all output points.
328 auto ker_bias = [=](acc_data_t &d, int g, int oc) {
329 for (int mb = 0; mb < MB; ++mb)
330 for (int od = 0; od < OD; ++od)
331 for (int oh = 0; oh < OH; ++oh)
332 for (int ow = 0; ow < OW; ++ow) {
334 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, oh,
337 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh,
340 d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)];
// Driver: parallel_nd is given the G and OC extents and a per-(g, oc) lambda.
346 parallel_nd(G, OC, [&](int g, int oc) {
348 // XXX: loss of precision when bias is a float...
// Store the saturated bias gradient at absolute output channel g*OC + oc.
// NOTE(review): the declaration of `db` and any guard are not visible here.
351 diff_bias[diff_bias_d.off(g*OC+oc)]
352 = saturate<diff_wei_data_t>(db);
// One diff_weights element per (ic, kd, kh, kw) for this (g, oc).
355 for (int ic = 0; ic < IC; ++ic)
356 for (int kd = 0; kd < KD; ++kd)
357 for (int kh = 0; kh < KH; ++kh)
358 for (int kw = 0; kw < KW; ++kw) {
// NOTE(review): the declaration of `dw` is not visible in this view.
360 ker(dw, g, oc, ic, kd, kh, kw);
// Store the saturated value using the offset form matching ndims, with or
// without the leading group coordinate.
363 auto idx = with_groups
364 ? diff_weights_d.off(g, oc, ic, kd, kh, kw)
365 : diff_weights_d.off(oc, ic, kd, kh, kw);
366 diff_weights[idx] = saturate<diff_wei_data_t>(dw);
367 } else if (ndims == 4) {
368 auto idx = with_groups
369 ? diff_weights_d.off(g, oc, ic, kh, kw)
370 : diff_weights_d.off(oc, ic, kh, kw);
371 diff_weights[idx] = saturate<diff_wei_data_t>(dw);
372 } else if (ndims == 3) {
373 auto idx = with_groups
374 ? diff_weights_d.off(g, oc, ic, kw)
375 : diff_weights_d.off(oc, ic, kw);
376 diff_weights[idx] = saturate<diff_wei_data_t>(dw);
// Explicit instantiations of the three reference convolution templates.
// The using-directive lets the type tags below be written unqualified.
384 using namespace data_type;
// Forward: arguments are <src, weights, dst, accumulator>.
// NOTE(review): the single-argument <f32> form presumably relies on default
// template arguments declared in the header; confirm there.
386 template struct ref_convolution_fwd_t<f32>;
387 template struct ref_convolution_fwd_t<s16, s16, s32, s32>;
// u8 source with s8 weights and an s32 accumulator, for four dst types.
389 template struct ref_convolution_fwd_t<u8, s8, f32, s32>;
390 template struct ref_convolution_fwd_t<u8, s8, s32, s32>;
391 template struct ref_convolution_fwd_t<u8, s8, s8, s32>;
392 template struct ref_convolution_fwd_t<u8, s8, u8, s32>;
// Backward by data: arguments are <diff_src, weights, diff_dst, accumulator>.
394 template struct ref_convolution_bwd_data_t<f32, f32, f32, f32>;
395 template struct ref_convolution_bwd_data_t<s32, s16, s16, s32>;
// s8 weights with u8 diff_dst and an s32 accumulator, for four diff_src types.
397 template struct ref_convolution_bwd_data_t<f32, s8, u8, s32>;
398 template struct ref_convolution_bwd_data_t<s32, s8, u8, s32>;
399 template struct ref_convolution_bwd_data_t<s8, s8, u8, s32>;
400 template struct ref_convolution_bwd_data_t<u8, s8, u8, s32>;
// Backward by weights: arguments are <src, diff_weights, diff_dst, accumulator>.
402 template struct ref_convolution_bwd_weights_t<f32, f32, f32, f32>;
403 template struct ref_convolution_bwd_weights_t<s16, s32, s16, s32>;
409 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s