1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
18 #include "conv/conv_common.hpp"
/* Translates a multi-dimensional index into a reference to an element of a
 * flat buffer. Telem is the element type and Tdims is the number of extents
 * kept in _dims. Several lines of this struct (braces, the constructor body,
 * the declaration of _base_ptr) are not visible in this excerpt. */
22 template <typename Telem, size_t Tdims>
23 struct array_offset_calculator {
/* Captures the per-dimension extents from the variadic arguments.
 * NOTE(review): `base` is presumably stored into _base_ptr in the constructor
 * body, which is not visible here -- confirm. */
24 template <typename... Targs>
25 array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... }
/* Element access: one index per dimension, slowest-varying index first.
 * The offset recursion starts at dimension 1, so the visible code never
 * reads _dims[0]. */
29 template <typename... Targs>
30 inline Telem &operator()(Targs... Fargs)
32 return *(_base_ptr + _offset(1, Fargs...));
/* Overload taking a single index; its body is not visible in this excerpt. */
36 template <typename... Targs>
37 inline size_t _offset(size_t const dimension, size_t element)
/* Last folding step: `theta` is the offset accumulated so far, scaled by the
 * extent of the current dimension before the final index is added. */
42 template <typename... Targs>
43 inline size_t _offset(size_t const dimension, size_t theta, size_t element)
45 return element + (_dims[dimension] * theta);
/* General step: fold one more index into the running offset and recurse on
 * the next dimension with the remaining indices. */
48 template <typename... Targs>
49 inline size_t _offset(size_t const dimension, size_t theta, size_t element,
52 size_t t_prime = element + (_dims[dimension] * theta);
53 return _offset(dimension + 1, t_prime, Fargs...);
/* Extents as passed to the constructor, indexed by dimension. */
57 const int _dims[Tdims];
/* Transforms a 6x6 input tile I into Iw by applying the same 6-point linear
 * combination twice: first across the row index of I (result in T), then
 * across the column index of T. Callers in this file label this step
 * "v <- B_t * d * B". The declarations of T and t0..t5 are not visible in
 * this excerpt. */
60 void trans_I_4x4_3x3(float Iw[6][6], float I[6][6]) {
/* Pass 1: for every column i, combine the six rows of I into T[0..5][i]. */
69 for (int i = 0; i < 6; i++) {
70 t0 = I[2][i] * -2.25f + I[4][i];
71 t1 = I[1][i] * -2.25f + I[3][i];
72 t2 = I[2][i] * -0.390625f + I[4][i];
73 t3 = I[1][i] * -0.390625f + I[3][i];
74 t4 = I[0][i] * 0.87890625f + I[4][i];
75 t5 = I[1][i] * 0.87890625f + I[5][i];
77 T[0][i] = I[2][i] * -2.640625f + t4;
78 T[1][i] = t1 * 0.625f + t0;
79 T[2][i] = t1 * -0.625f + t0;
80 T[3][i] = t3 * 1.5f + t2;
81 T[4][i] = t3 * -1.5f + t2;
82 T[5][i] = I[3][i] * -2.640625f + t5;
/* Pass 2: same coefficients, now combining the six columns of each row of T
 * into the corresponding row of Iw. */
85 for (int i = 0; i < 6; i++) {
86 t0 = T[i][2] * -2.25f + T[i][4];
87 t1 = T[i][1] * -2.25f + T[i][3];
88 t2 = T[i][2] * -0.390625f + T[i][4];
89 t3 = T[i][1] * -0.390625f + T[i][3];
90 t4 = T[i][0] * 0.87890625f + T[i][4];
91 t5 = T[i][1] * 0.87890625f + T[i][5];
93 Iw[i][0] = T[i][2] * -2.640625f + t4;
94 Iw[i][1] = t1 * 0.625f + t0;
95 Iw[i][2] = t1 * -0.625f + t0;
96 Iw[i][3] = t3 * 1.5f + t2;
97 Iw[i][4] = t3 * -1.5f + t2;
98 Iw[i][5] = T[i][3] * -2.640625f + t5;
/* Transforms a 3x3 kernel F into a 6x6 tile Fw_. Callers in this file label
 * this step "u <- G * g * G_t". The declarations of T, Fw and t0..t2 are not
 * visible in this excerpt. */
102 void trans_W_4x4_3x3(float Fw_[6][6], float F[3][3]) {
/* Pass 1: for each of the three columns of F, expand its three rows into
 * rows of T. Only the stores to T[0..4][i] are visible here; the second pass
 * reads rows 0..5 of T, so a store to T[5][i] is presumably on a line that is
 * not shown -- confirm. */
109 for (int i = 0; i < 3; i++) {
110 t0 = 0.26890756302521f * F[2][i];
111 t1 = -t0 - 0.688403361344538f * F[0][i];
112 t2 = t0 + 0.119514472455649f * F[0][i];
114 T[0][i] = 1.13777777777778f * F[0][i];
115 T[1][i] = t1 - 0.430252100840336f * F[1][i];
116 T[2][i] = t1 + 0.430252100840336f * F[1][i];
117 T[3][i] = t2 + 0.179271708683473f * F[1][i];
118 T[4][i] = t2 - 0.179271708683473f * F[1][i];
/* Pass 2: same coefficients applied to the three columns of every row of T,
 * producing one row of results in the temporary Fw. */
122 for (int i = 0; i < 6; i++) {
123 t0 = 0.26890756302521f * T[i][2];
124 t1 = -t0 - 0.688403361344538f * T[i][0];
125 t2 = t0 + 0.119514472455649f * T[i][0];
127 Fw[0] = 1.13777777777778f * T[i][0];
128 Fw[1] = t1 - 0.430252100840336f * T[i][1];
129 Fw[2] = t1 + 0.430252100840336f * T[i][1];
130 Fw[3] = t2 + 0.179271708683473f * T[i][1];
131 Fw[4] = t2 - 0.179271708683473f * T[i][1];
/* NOTE(review): loop body not visible; presumably copies the temporary row
 * Fw into Fw_[i] -- confirm. */
133 for (int l = 0; l < 6; l++) {
/* Reduces a 6x6 tile Mw to a 4x4 tile O by applying a 6-to-4 linear
 * combination along the row index and then along the column index.
 * The declarations of T and t0..t3 are not visible in this excerpt. */
139 void trans_O_4x4_3x3(float Mw[6][6], float O[4][4]) {
/* Pass 1: for every column i of Mw, collapse the six rows into T[0..3][i]. */
146 for (int i = 0; i < 6; i++) {
147 t0 = Mw[1][i] + Mw[2][i];
148 t1 = Mw[3][i] + Mw[4][i];
149 t2 = Mw[1][i] - Mw[2][i];
150 t3 = Mw[3][i] - Mw[4][i];
152 T[0][i] = t0 + t1 + Mw[0][i];
153 T[1][i] = t2 * 0.625f + t3 * 1.5f;
154 T[2][i] = t0 * 0.390625f + t1 * 2.25f;
155 T[3][i] = t2 * 0.244140625f + t3 * 3.375f + Mw[5][i];
/* Pass 2: for each of the four rows of T, collapse its six columns into the
 * four entries of the matching row of O using the same coefficients. */
158 for (int i = 0; i < 4; i++) {
159 t0 = T[i][1] + T[i][2];
160 t1 = T[i][3] + T[i][4];
161 t2 = T[i][1] - T[i][2];
162 t3 = T[i][3] - T[i][4];
164 O[i][0] = t0 + t1 + T[i][0];
165 O[i][1] = t2 * 0.625f + t3 * 1.5f;
166 O[i][2] = t0 * 0.390625f + t1 * 2.25f;
167 O[i][3] = t2 * 0.244140625f + t3 * 3.375f + T[i][5];
/* Transform used by the weight-update path: expands F into the 6x6 tile Fw.
 * The first pass reads rows 0..3 of the first four columns of F; the second
 * pass reads columns 0..3 of each of six rows of T. Only the first store of
 * each pass is visible in this excerpt; the stores that consume t1..t4 and
 * the declarations of T and t0..t4 are on lines that are not shown. */
171 void trans_W_3x3_4x4_wu(float Fw[6][6], float F[4][6]) {
/* Pass 1: combine the four rows of F, column by column. */
179 for (int i = 0; i < 4; i++) {
180 t0 = F[2][i] * 0.26890756302521f;
181 t1 = F[0][i] * -0.688403361344538f - t0;
182 t2 = F[0][i] * 0.119514472455649f + t0;
183 t3 = F[1][i] * 0.430252100840336f + F[3][i] * 0.168067226890756f;
184 t4 = F[1][i] * 0.179271708683473f + F[3][i] * 0.403361344537815f;
186 T[0][i] = F[0][i] * 1.13777777777778f;
/* Pass 2: same coefficients applied to the four columns of every row of T. */
194 for (int i = 0; i < 6; i++) {
195 t0 = T[i][2] * 0.26890756302521f;
196 t1 = T[i][0] * -0.688403361344538f - t0;
197 t2 = T[i][0] * 0.119514472455649f + t0;
198 t3 = T[i][1] * 0.430252100840336f + T[i][3] * 0.168067226890756f;
199 t4 = T[i][1] * 0.179271708683473f + T[i][3] * 0.403361344537815f;
201 Fw[i][0] = T[i][0] * 1.13777777777778f;
/* Transform used by the weight-update path: reduces the 6x6 tile Mw to the
 * 3x3 result M. The declarations of T, M_ and t0..t2 are not visible in this
 * excerpt. */
210 void trans_O_3x3_4x4_wu(float Mw[6][6], float M[3][3]) {
/* Pass 1: for every column i of Mw, collapse the six rows into T[0..2][i]. */
217 for (int i = 0; i < 6; i++) {
218 t0 = Mw[1][i] + Mw[2][i];
219 t1 = Mw[3][i] + Mw[4][i];
220 t2 = t1 * 2.25f + Mw[5][i];
222 T[0][i] = Mw[0][i] + t0 + t1;
223 T[1][i] = 0.625f * (Mw[1][i] - Mw[2][i]) +
224 1.5f * (Mw[3][i] - Mw[4][i]);
225 T[2][i] = t0 * 0.390625f + t2;
/* Pass 2: for each of the three rows of T, collapse its six columns into the
 * temporary M_[0..2] with the same coefficients. */
227 for (int i = 0; i < 3; i++) {
228 t0 = T[i][1] + T[i][2];
229 t1 = T[i][3] + T[i][4];
230 t2 = t1 * 2.25f + T[i][5];
232 M_[0] = T[i][0] + t0 + t1;
233 M_[1] = 0.625f * (T[i][1] - T[i][2]) +
234 1.5f * (T[i][3] - T[i][4]);
235 M_[2] = t0 * 0.390625f + t2;
/* NOTE(review): loop body not visible; presumably copies M_ into row i of M
 * -- confirm. */
237 for (int k = 0; k < 3; k++) {
/* Working storage shared by the compute_wino_ref_* routines. Only out_dim is
 * visible in this excerpt; the other members this file uses (alpha, h_tiles,
 * w_tiles, _u_ptr, _v_ptr, _m_ptr) are declared on lines that are not shown. */
243 struct scratchpad_t {
/* Tile coordinates advance by out_dim per tile, and the transformed output
 * is written as out_dim x out_dim tiles. */
252 const int out_dim = 4;
/* Computes the tile counts for problem `p` and allocates the three float
 * buffers (_u_ptr, _v_ptr, _m_ptr) of scratchpad `sp`. Yields
 * mkldnn_out_of_memory when any allocation fails; the other return statements
 * are on lines not visible in this excerpt. */
255 int init_scratchpad(const prb_t *p, scratchpad_t &sp) {
/* Only the out_dim == 4 / alpha == 6 configuration is accepted (the statement
 * executed on mismatch is not visible here). */
256 if (sp.out_dim != 4 || sp.alpha != 6)
/* Tile counts come from the output shape when p->dir equals FLAG_FWD and from
 * the input shape otherwise.
 * NOTE(review): this is an equality test, while the bias flag is tested with
 * `&` elsewhere in this file -- confirm that direction values carrying extra
 * flag bits are meant to take the ih/iw branch. */
259 sp.h_tiles = p->dir == FLAG_FWD ? div_up(p->oh, sp.out_dim) :
260 div_up(p->ih, sp.out_dim);
261 sp.w_tiles = p->dir == FLAG_FWD ? div_up(p->ow, sp.out_dim) :
262 div_up(p->iw, sp.out_dim);
/* Buffer sizes in floats:
 *   _u_ptr: alpha * alpha * oc * ic
 *   _v_ptr: alpha * alpha * ic * mb * h_tiles * w_tiles
 *   _m_ptr: alpha * alpha * oc * mb * h_tiles * w_tiles
 * The second zmalloc argument (64) is presumably an alignment in bytes --
 * confirm against zmalloc. */
264 sp._u_ptr = (float *)zmalloc(
265 sizeof(float) * sp.alpha * sp.alpha * p->oc * p->ic, 64);
266 sp._v_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * p->ic
267 * p->mb * sp.h_tiles * sp.w_tiles, 64);
268 sp._m_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * p->oc
269 * p->mb * sp.h_tiles * sp.w_tiles, 64);
/* NOTE(review): when only some of the allocations succeeded, this path
 * returns without releasing them; whether the caller frees them depends on
 * SAFE_V, which is not visible here -- possible leak, confirm. */
271 if (sp._u_ptr == NULL || sp._v_ptr == NULL || sp._m_ptr == NULL)
272 return mkldnn_out_of_memory;
/* array_set receives each buffer together with its size in bytes; presumably
 * it initializes the memory (the gemm calls in this file use a 1.0 scale on
 * their output operand) -- confirm against array_set. */
274 array_set((char *)sp._u_ptr,
275 sizeof(float) * sp.alpha * sp.alpha * p->oc * p->ic);
276 array_set((char *)sp._v_ptr, sizeof(float) * sp.alpha * sp.alpha * p->ic
277 * p->mb * sp.h_tiles * sp.w_tiles);
278 array_set((char *)sp._m_ptr, sizeof(float) * sp.alpha * sp.alpha * p->oc
279 * p->mb * sp.h_tiles * sp.w_tiles);
/* Releases the three scratchpad buffers with zfree, skipping any pointer that
 * is NULL. The pointers themselves are left unchanged afterwards. */
284 void free_scratchpad(scratchpad_t *sp) {
285 if(sp->_u_ptr != NULL) zfree(sp->_u_ptr);
286 if(sp->_v_ptr != NULL) zfree(sp->_v_ptr);
287 if(sp->_m_ptr != NULL) zfree(sp->_m_ptr);
/* Reference forward pass of the "wino" convolution path for 3x3 kernels.
 * Phases, in order: transform source tiles into V, transform weights into U,
 * multiply U and V per (j, k) coefficient into M with gemm, transform M back
 * into 4x4 output tiles, then add bias and apply post-ops before the result
 * is produced for dst_m. The declaration of `sp`, the opening of the OpenMP
 * region and the local tile buffers (I, _v, F, _u, _m, O) are on lines not
 * visible in this excerpt. */
290 void compute_wino_ref_fwd(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &wei_m,
291 dnn_mem_t &bia_m, dnn_mem_t &dst_m) {
293 SAFE_V(init_scratchpad(p, sp));
/* Multi-dimensional views over the scratchpad buffers:
 *   U(j, k, oc, ic), V(j, k, ic, mb, h_tile, w_tile),
 *   M(j, k, oc, mb, h_tile, w_tile). */
295 array_offset_calculator<float, 4> U(
296 sp._u_ptr, sp.alpha, sp.alpha, p->oc, p->ic);
297 array_offset_calculator<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha, p->ic,
298 p->mb, sp.h_tiles, sp.w_tiles);
299 array_offset_calculator<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha, p->oc,
300 p->mb, sp.h_tiles, sp.w_tiles);
/* Only 3x3 kernels are handled. */
302 SAFE_V(p->kh == 3 ? OK : FAIL);
303 SAFE_V(p->kw == 3 ? OK : FAIL);
305 bool with_bias = p->dir & FLAG_BIA;
/* Tile coordinates live in padded space: a coordinate maps to a real source
 * element only when it lies in [t_pad, hp_max) x [l_pad, wp_max). */
306 const int t_pad = p->ph;
307 const int l_pad = p->pw;
308 const int wp_max = p->iw + l_pad;
309 const int hp_max = p->ih + t_pad;
/* Number of tiles over the whole minibatch; used as a gemm size and leading
 * dimension below. */
310 const int p_dim = p->mb * sp.h_tiles * sp.w_tiles;
322 #pragma omp for collapse(4)
323 /* src_transform v <- B_t * d * B */
324 for (int img = 0; img < p->mb; img++) {
325 for (int c = 0; c < p->ic; c++) {
326 for (int hfm = 0; hfm < sp.h_tiles; hfm++) {
327 for (int wfm = 0; wfm < sp.w_tiles; wfm++) {
/* Gather one alpha x alpha tile. Tiles start every out_dim elements but span
 * alpha elements, so neighbouring tiles overlap. */
328 for (int j = 0; j < sp.alpha; j++) {
329 int ydim = hfm * sp.out_dim + j;
330 if ((t_pad <= ydim) && (ydim < hp_max)) {
331 for (int k = 0; k < sp.alpha; k++) {
332 int xdim = wfm * sp.out_dim + k;
333 if ((l_pad <= xdim) && (xdim < wp_max)) {
334 size_t src_off = src_off_f(p, img, 0, c, 0,
335 ydim - t_pad, xdim - l_pad);
336 I[j][k] = ((float *)src_m)[src_off];
/* NOTE(review): body not visible; presumably fills row j of I for coordinates
 * that fall into the padding -- confirm. */
342 for (int k = 0; k < sp.alpha; k++) {
348 trans_I_4x4_3x3(_v, I);
/* Scatter the transformed tile into V. */
351 for (int j = 0; j < sp.alpha; j++) {
352 for (int k = 0; k < sp.alpha; k++) {
353 V(j, k, c, img, hfm, wfm) = _v[j][k];
361 #pragma omp for collapse(2)
362 /* wei_transform u <- G * g * G_t */
363 for (int oc = 0; oc < p->oc; ++oc) {
364 for (int ic = 0; ic < p->ic; ++ic) {
/* Load one kh x kw kernel for the (oc, ic) pair. */
365 for (int j = 0; j < p->kh; j++) {
366 for (int i = 0; i < p->kw; i++) {
367 size_t wei_off = wei_off_f(p, 0, oc, ic, 0, j, i);
368 F[j][i] = ((float *)wei_m)[wei_off];
372 trans_W_4x4_3x3(_u, F);
375 for (int j = 0; j < sp.alpha; j++) {
376 for (int k = 0; k < sp.alpha; k++) {
377 U(j, k, oc, ic) = _u[j][k];
/* One gemm per (j, k) coefficient over the U, V and M slices at that
 * coefficient. Assuming BLAS-style (m, n, k) argument order this is an
 * (oc x ic) by (ic x p_dim) product; the trailing 1.0 is applied to the M
 * operand -- confirm against the gemm helper. */
383 #pragma omp for collapse(2)
385 for (int j = 0; j < sp.alpha; ++j) {
386 for (int k = 0; k < sp.alpha; ++k) {
387 gemm("C", "N", "N", p->oc, p_dim, p->ic, 1.0,
388 (float *)&(U(j, k, 0, 0)), p->ic,
389 (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0,
390 (float *)&(M(j, k, 0, 0, 0, 0)), p_dim);
/* Output phase: gather each alpha x alpha tile of M, reduce it to an
 * out_dim x out_dim tile and finish every output element. */
394 #pragma omp for collapse(4)
396 for (int oc = 0; oc < p->oc; ++oc) {
397 for (int img = 0; img < p->mb; ++img) {
398 for (int hfm = 0; hfm < sp.h_tiles; ++hfm) {
399 for (int wfm = 0; wfm < sp.w_tiles; ++wfm) {
400 for (int j = 0; j < sp.alpha; j++) {
401 for (int k = 0; k < sp.alpha; k++) {
402 _m[j][k] = M(j, k, oc, img, hfm, wfm);
405 trans_O_4x4_3x3(_m, O);
407 for (int j = 0; j < sp.out_dim; j++) {
408 int ydim = hfm * sp.out_dim + j;
410 for (int k = 0; k < sp.out_dim; k++) {
412 float conv_res = O[j][k];
414 int xdim = wfm * sp.out_dim + k;
/* NOTE(review): no bounds check on ydim/xdim against the output shape is
 * visible between these lines; it may sit on a line not shown -- confirm. */
416 const size_t dst_off = dst_off_f(
417 p, img, 0, oc, 0, ydim, xdim);
/* dst is bound by reference so the sum post-op below can read the value
 * already stored in the destination. */
418 float &dst = ((float *)dst_m)[dst_off];
420 const size_t bia_off = bia_off_f(p, 0, oc);
421 conv_res += with_bias ?
422 ((float *)bia_m)[bia_off] :
/* Post-ops are applied in list order on top of the biased result. Only
 * fragments of the per-kind handling are visible: a sum entry adds the scaled
 * existing destination value, an eltwise entry scales a value that is 0 for
 * negative inputs. */
425 const auto &ops = p->attr.post_ops;
426 for (int idx = 0; idx < ops.len; ++idx) {
427 using pk = attr_t::post_ops_t::kind_t;
428 const auto &e = ops.entry[idx];
431 conv_res += e.sum.scale * dst;
434 conv_res = e.eltwise.scale
435 * (conv_res < 0 ? 0 :
440 "attr::post_ops::kind");
455 free_scratchpad(&sp);
/* Reference backward-by-data pass of the "wino" convolution path for 3x3
 * kernels. It mirrors the forward routine with the roles of the channels
 * swapped: tiles of diff_dst are transformed, multiplied with transformed,
 * spatially flipped weights, and the inverse-transformed result is written to
 * diff_src. The declaration of `sp`, the opening of the OpenMP region and the
 * local tile buffers are on lines not visible in this excerpt. */
458 void compute_wino_ref_bwd_d(const prb_t *p, dnn_mem_t &diff_src_m,
459 dnn_mem_t &wei_m, dnn_mem_t &bia_m, dnn_mem_t &diff_dst_m) {
461 SAFE_V(init_scratchpad(p, sp));
/* Views over the scratchpad. Compared with the forward routine, U is indexed
 * (j, k, ic, oc), V is laid over _m_ptr (the oc-sized buffer) and M over
 * _v_ptr (the ic-sized buffer). */
463 array_offset_calculator<float, 4> U(
464 sp._u_ptr, sp.alpha, sp.alpha, p->ic, p->oc);
465 array_offset_calculator<float, 6> V(sp._m_ptr, sp.alpha, sp.alpha, p->oc,
466 p->mb, sp.h_tiles, sp.w_tiles);
467 array_offset_calculator<float, 6> M(sp._v_ptr, sp.alpha, sp.alpha, p->ic,
468 p->mb, sp.h_tiles, sp.w_tiles);
/* Only 3x3 kernels are handled. */
470 SAFE_V(p->kh == 3 ? OK : FAIL);
471 SAFE_V(p->kw == 3 ? OK : FAIL);
/* Padding applied around diff_dst for this pass, derived from the problem
 * shapes. A padded coordinate maps to a real diff_dst element only when it
 * lies in [t_pad, hp_max) x [l_pad, wp_max). */
473 const int r_pad = MAX2(0, p->ow - 1 + p->kw - p->iw - p->pw);
474 const int l_pad = p->iw + r_pad - p->ow;
475 const int t_pad = p->ih + p->ph - p->oh;
476 const int wp_max = p->ow + l_pad;
477 const int hp_max = p->oh + t_pad;
478 const int p_dim = p->mb * sp.h_tiles * sp.w_tiles;
480 bool with_bias = p->dir & FLAG_BIA;
/* NOTE(review): the label on the next comment says diff_src, but the loop
 * reads diff_dst_m and iterates over oc channels. */
492 #pragma omp for collapse(4)
493 /* diff_src transform v <- B_t * d * B */
494 for (int img = 0; img < p->mb; img++) {
495 for (int c = 0; c < p->oc; c++) {
496 for (int hfm = 0; hfm < sp.h_tiles; hfm++) {
497 for (int wfm = 0; wfm < sp.w_tiles; wfm++) {
/* Gather one alpha x alpha tile; tiles start every out_dim elements and
 * therefore overlap. */
499 for (int j = 0; j < sp.alpha; j++) {
500 int ydim = hfm * sp.out_dim + j;
501 if ((t_pad <= ydim) && (ydim < hp_max)) {
502 for (int k = 0; k < sp.alpha; k++) {
503 int xdim = wfm * sp.out_dim + k;
504 if ((l_pad <= xdim) && (xdim < wp_max)) {
505 size_t dst_off = dst_off_f(p, img, 0, c, 0,
506 ydim - t_pad, xdim - l_pad);
507 I[j][k] = ((float *)diff_dst_m)[dst_off];
/* NOTE(review): body not visible; presumably fills row j of I for coordinates
 * that fall into the padding -- confirm. */
513 for (int k = 0; k < sp.alpha; k++) {
519 trans_I_4x4_3x3(_v, I);
522 for (int j = 0; j < sp.alpha; j++) {
523 for (int k = 0; k < sp.alpha; k++) {
524 V(j, k, c, img, hfm, wfm) = _v[j][k];
532 #pragma omp for collapse(2)
533 /* wei_transform u <- G * g * G_t */
534 for (int ic = 0; ic < p->ic; ++ic) {
535 for (int oc = 0; oc < p->oc; ++oc) {
/* The kernel is read flipped in both spatial dimensions: F[j][i] takes the
 * weight at (kh - 1 - j, kw - 1 - i). */
536 for (int j = 0; j < p->kh; j++) {
537 for (int i = 0; i < p->kw; i++) {
538 size_t wei_off = wei_off_f(
539 p, 0, oc, ic, 0, p->kh - j - 1, p->kw - i - 1);
540 F[j][i] = ((float *)wei_m)[wei_off];
543 trans_W_4x4_3x3(_u, F);
546 for (int j = 0; j < sp.alpha; j++) {
547 for (int k = 0; k < sp.alpha; k++) {
548 U(j, k, ic, oc) = _u[j][k];
/* One gemm per (j, k) coefficient. Assuming BLAS-style (m, n, k) argument
 * order this is an (ic x oc) by (oc x p_dim) product into M -- confirm
 * against the gemm helper. */
554 #pragma omp for collapse(2)
556 for (int j = 0; j < sp.alpha; ++j) {
557 for (int k = 0; k < sp.alpha; ++k) {
558 gemm("C", "N", "N", p->ic, p_dim, p->oc, 1.0,
559 (float *)&(U(j, k, 0, 0)), p->oc,
560 (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0,
561 (float *)&(M(j, k, 0, 0, 0, 0)), p_dim);
/* NOTE(review): the label on the next comment says diff_dst, but the loop
 * writes diff_src_m and iterates over ic channels. */
565 #pragma omp for collapse(4)
566 /* diff_dst: Y = A_t *m * A */
567 for (int c = 0; c < p->ic; ++c) {
568 for (int img = 0; img < p->mb; ++img) {
569 for (int hfm = 0; hfm < sp.h_tiles; ++hfm) {
570 for (int wfm = 0; wfm < sp.w_tiles; ++wfm) {
571 for (int j = 0; j < sp.alpha; j++) {
572 for (int k = 0; k < sp.alpha; k++) {
573 _m[j][k] = M(j, k, c, img, hfm, wfm);
576 trans_O_4x4_3x3(_m, O);
/* The bias is indexed directly by channel here, whereas the forward routine
 * goes through bia_off_f. */
578 float bia = with_bias ? ((float *)bia_m)[c] : 0.f;
580 for (int j = 0; j < sp.out_dim; j++) {
581 int ydim = hfm * sp.out_dim + j;
583 for (int k = 0; k < sp.out_dim; k++) {
584 int xdim = wfm * sp.out_dim + k;
/* NOTE(review): no bounds check on ydim/xdim against the diff_src shape is
 * visible between these lines; it may sit on a line not shown -- confirm. */
586 size_t src_off = src_off_f(
587 p, img, 0, c, 0, ydim, xdim);
588 ((float *)diff_src_m)[src_off] = O[j][k]
600 free_scratchpad(&sp);
/* Reference backward-by-weights pass of the "wino" convolution path for 3x3
 * kernels. Source tiles are transformed into V and diff_dst tiles into M,
 * the two are multiplied per (j, k) coefficient into U with gemm, and each
 * alpha x alpha slice of U is reduced to a kh x kw block of diff_wei. When
 * FLAG_BIA is set the bias gradient is delegated to compute_ref_bwd_bias.
 * The declaration of `sp`, the opening of the OpenMP region and the local
 * tile buffers are on lines not visible in this excerpt. */
603 void compute_wino_ref_bwd_w(const prb_t *p, dnn_mem_t &src_m,
604 dnn_mem_t &diff_wei_m, dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m) {
606 SAFE_V(init_scratchpad(p, sp));
/* Views over the scratchpad. Unlike the other two routines, V keeps the
 * input channel as its last (fastest-varying) index:
 *   V(j, k, mb, h_tile, w_tile, ic). */
608 array_offset_calculator<float, 4> U(
609 sp._u_ptr, sp.alpha, sp.alpha, p->oc, p->ic);
610 array_offset_calculator<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha, p->mb,
611 sp.h_tiles, sp.w_tiles, p->ic);
612 array_offset_calculator<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha, p->oc,
613 p->mb, sp.h_tiles, sp.w_tiles);
/* Only 3x3 kernels are handled. */
615 SAFE_V(p->kh == 3 ? OK : FAIL);
616 SAFE_V(p->kw == 3 ? OK : FAIL);
/* A padded coordinate maps to a real source element only when it lies in
 * [t_pad, hp_max) x [l_pad, wp_max). */
618 const int t_pad = p->ph;
619 const int l_pad = p->pw;
620 const int wp_max = p->iw + l_pad;
621 const int hp_max = p->ih + t_pad;
622 const int p_dim = p->mb * sp.h_tiles * sp.w_tiles;
634 #pragma omp for collapse(4)
635 /* src transform v <- B_t * d * B */
636 for (int img = 0; img < p->mb; img++) {
637 for (int hfm = 0; hfm < sp.h_tiles; hfm++) {
638 for (int wfm = 0; wfm < sp.w_tiles; wfm++) {
639 for (int ic = 0; ic < p->ic; ic++) {
/* Gather one alpha x alpha source tile; tiles start every out_dim elements
 * and therefore overlap. */
640 for (int j = 0; j < sp.alpha; j++) {
641 int ydim = hfm * sp.out_dim + j;
642 if ((t_pad <= ydim) && (ydim < hp_max)) {
643 for (int k = 0; k < sp.alpha; k++) {
644 int xdim = wfm * sp.out_dim + k;
645 if ((l_pad <= xdim) && (xdim < wp_max)) {
646 size_t src_off = src_off_f(p, img, 0, ic, 0,
647 ydim - t_pad, xdim - l_pad);
648 I[j][k] = ((float *)src_m)[src_off];
/* NOTE(review): body not visible; presumably fills row j of I for coordinates
 * that fall into the padding -- confirm. */
654 for (int k = 0; k < sp.alpha; k++) {
660 trans_I_4x4_3x3(_v, I);
663 for (int j = 0; j < sp.alpha; j++) {
664 for (int k = 0; k < sp.alpha; k++) {
665 V(j, k, img, hfm, wfm, ic) = _v[j][k];
673 #pragma omp for collapse(4)
674 /* diff_dst transform */
675 for (int oc = 0; oc < p->oc; oc++) {
676 for (int img = 0; img < p->mb; img++) {
677 for (int hfm = 0; hfm < sp.h_tiles; hfm++) {
678 for (int wfm = 0; wfm < sp.w_tiles; wfm++) {
/* Gather a diff_dst tile into O.
 * NOTE(review): j and k run up to alpha and no bounds check against the
 * diff_dst shape is visible between these lines; it may sit on lines not
 * shown -- confirm. */
679 for (int j = 0; j < sp.alpha; j++) {
680 int ydim = hfm * sp.out_dim + j;
682 for (int k = 0; k < sp.alpha; k++) {
683 int xdim = wfm * sp.out_dim + k;
685 size_t dst_off = dst_off_f(
686 p, img, 0, oc, 0, ydim, xdim);
687 O[j][k] = ((float *)diff_dst_m)[dst_off];
/* NOTE(review): body not visible; presumably fills the part of O that was not
 * loaded above -- confirm. */
693 for (int k = 0; k < sp.alpha; k++) {
698 trans_W_3x3_4x4_wu(_m, O);
701 for (int j = 0; j < sp.alpha; j++) {
702 for (int k = 0; k < sp.alpha; k++) {
703 M(j, k, oc, img, hfm, wfm) = _m[j][k];
/* One gemm per (j, k) coefficient. Assuming BLAS-style (m, n, k) argument
 * order this is an (oc x p_dim) by (p_dim x ic) product into U, which is why
 * V keeps ic as its last index -- confirm against the gemm helper. */
711 #pragma omp for collapse(2)
713 for (int j = 0; j < sp.alpha; ++j) {
714 for (int k = 0; k < sp.alpha; ++k) {
715 gemm("C", "N", "N", p->oc, p->ic, p_dim, 1.0,
716 (float *)&(M(j, k, 0, 0, 0, 0)), p_dim,
717 (float *)&(V(j, k, 0, 0, 0, 0)), p->ic, 1.0,
718 (float *)&(U(j, k, 0, 0)), p->ic);
/* Reduce every (oc, ic) slice of U to a kh x kw block and store it into
 * diff_wei. Note the buffer roles here: F holds the alpha x alpha input and
 * _u receives the reduced result. */
722 #pragma omp for collapse(2)
723 for (int oc = 0; oc < p->oc; ++oc) {
724 for (int ic = 0; ic < p->ic; ++ic) {
725 for (int j = 0; j < sp.alpha; j++) {
726 for (int k = 0; k < sp.alpha; k++) {
727 F[j][k] = U(j, k, oc, ic);
731 trans_O_3x3_4x4_wu(F, _u);
734 for (int kh = 0; kh < p->kh; kh++) {
735 for (int kw = 0; kw < p->kw; kw++) {
736 size_t wei_off = wei_off_f(p, 0, oc, ic, 0, kh, kw);
737 ((float *)diff_wei_m)[wei_off] = _u[kh][kw];
744 free_scratchpad(&sp);
/* The bias gradient does not use the scratchpad, so it is computed after the
 * scratchpad has been released. */
746 if (p->dir & FLAG_BIA)
747 compute_ref_bwd_bias(p, diff_bia_m, diff_dst_m);