Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / conv / ref_wino.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "common.hpp"
18 #include "conv/conv_common.hpp"
19
20 namespace conv {
21
22 template <typename Telem, size_t Tdims>
23 struct array_offset_calculator {
24     template <typename... Targs>
25     array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... }
26     {
27         _base_ptr = base;
28     }
29     template <typename... Targs>
30     inline Telem &operator()(Targs... Fargs)
31     {
32         return *(_base_ptr + _offset(1, Fargs...));
33     }
34
35 private:
36     template <typename... Targs>
37     inline size_t _offset(size_t const dimension, size_t element)
38     {
39         return element;
40     }
41
42     template <typename... Targs>
43     inline size_t _offset(size_t const dimension, size_t theta, size_t element)
44     {
45         return element + (_dims[dimension] * theta);
46     }
47
48     template <typename... Targs>
49     inline size_t _offset(size_t const dimension, size_t theta, size_t element,
50             Targs... Fargs)
51     {
52         size_t t_prime = element + (_dims[dimension] * theta);
53         return _offset(dimension + 1, t_prime, Fargs...);
54     }
55
56     Telem *_base_ptr;
57     const int _dims[Tdims];
58 };
59
60 void trans_I_4x4_3x3(float Iw[6][6], float I[6][6]) {
61     float T[6][6];
62     float t0;
63     float t1;
64     float t2;
65     float t3;
66     float t4;
67     float t5;
68
69     for (int i = 0; i < 6; i++) {
70         t0 = I[2][i] * -2.25f + I[4][i];
71         t1 = I[1][i] * -2.25f + I[3][i];
72         t2 = I[2][i] * -0.390625f + I[4][i];
73         t3 = I[1][i] * -0.390625f + I[3][i];
74         t4 = I[0][i] * 0.87890625f + I[4][i];
75         t5 = I[1][i] * 0.87890625f + I[5][i];
76
77         T[0][i] = I[2][i] * -2.640625f + t4;
78         T[1][i] = t1 * 0.625f + t0;
79         T[2][i] = t1 * -0.625f + t0;
80         T[3][i] = t3 * 1.5f + t2;
81         T[4][i] = t3 * -1.5f + t2;
82         T[5][i] = I[3][i] * -2.640625f + t5;
83     }
84
85     for (int i = 0; i < 6; i++) {
86         t0 = T[i][2] * -2.25f + T[i][4];
87         t1 = T[i][1] * -2.25f + T[i][3];
88         t2 = T[i][2] * -0.390625f + T[i][4];
89         t3 = T[i][1] * -0.390625f + T[i][3];
90         t4 = T[i][0] * 0.87890625f + T[i][4];
91         t5 = T[i][1] * 0.87890625f + T[i][5];
92
93         Iw[i][0] = T[i][2] * -2.640625f + t4;
94         Iw[i][1] = t1 * 0.625f + t0;
95         Iw[i][2] = t1 * -0.625f + t0;
96         Iw[i][3] = t3 * 1.5f + t2;
97         Iw[i][4] = t3 * -1.5f + t2;
98         Iw[i][5] = T[i][3] * -2.640625f + t5;
99     }
100 }
101
102 void trans_W_4x4_3x3(float Fw_[6][6], float F[3][3]) {
103     float Fw[6];
104     float T[6][3];
105     float t0;
106     float t1;
107     float t2;
108
109     for (int i = 0; i < 3; i++) {
110         t0 = 0.26890756302521f * F[2][i];
111         t1 = -t0 - 0.688403361344538f * F[0][i];
112         t2 = t0 + 0.119514472455649f * F[0][i];
113
114         T[0][i] = 1.13777777777778f * F[0][i];
115         T[1][i] = t1 - 0.430252100840336f * F[1][i];
116         T[2][i] = t1 + 0.430252100840336f * F[1][i];
117         T[3][i] = t2 + 0.179271708683473f * F[1][i];
118         T[4][i] = t2 - 0.179271708683473f * F[1][i];
119         T[5][i] = F[2][i];
120     }
121
122     for (int i = 0; i < 6; i++) {
123         t0 = 0.26890756302521f * T[i][2];
124         t1 = -t0 - 0.688403361344538f * T[i][0];
125         t2 = t0 + 0.119514472455649f * T[i][0];
126
127         Fw[0] = 1.13777777777778f * T[i][0];
128         Fw[1] = t1 - 0.430252100840336f * T[i][1];
129         Fw[2] = t1 + 0.430252100840336f * T[i][1];
130         Fw[3] = t2 + 0.179271708683473f * T[i][1];
131         Fw[4] = t2 - 0.179271708683473f * T[i][1];
132         Fw[5] = T[i][2];
133         for (int l = 0; l < 6; l++) {
134             Fw_[i][l] = Fw[l];
135         }
136     }
137 }
138
139 void trans_O_4x4_3x3(float Mw[6][6], float O[4][4]) {
140     float T[4][6];
141     float t0;
142     float t1;
143     float t2;
144     float t3;
145
146     for (int i = 0; i < 6; i++) {
147         t0 = Mw[1][i] + Mw[2][i];
148         t1 = Mw[3][i] + Mw[4][i];
149         t2 = Mw[1][i] - Mw[2][i];
150         t3 = Mw[3][i] - Mw[4][i];
151
152         T[0][i] = t0 + t1 + Mw[0][i];
153         T[1][i] = t2 * 0.625f + t3 * 1.5f;
154         T[2][i] = t0 * 0.390625f + t1 * 2.25f;
155         T[3][i] = t2 * 0.244140625f + t3 * 3.375f + Mw[5][i];
156     }
157
158     for (int i = 0; i < 4; i++) {
159         t0 = T[i][1] + T[i][2];
160         t1 = T[i][3] + T[i][4];
161         t2 = T[i][1] - T[i][2];
162         t3 = T[i][3] - T[i][4];
163
164         O[i][0] = t0 + t1 + T[i][0];
165         O[i][1] = t2 * 0.625f + t3 * 1.5f;
166         O[i][2] = t0 * 0.390625f + t1 * 2.25f;
167         O[i][3] = t2 * 0.244140625f + t3 * 3.375f + T[i][5];
168     }
169 }
170
171 void trans_W_3x3_4x4_wu(float Fw[6][6], float F[4][6]) {
172     float T[6][4];
173     float t0;
174     float t1;
175     float t2;
176     float t3;
177     float t4;
178
179     for (int i = 0; i < 4; i++) {
180         t0 = F[2][i] * 0.26890756302521f;
181         t1 = F[0][i] * -0.688403361344538f - t0;
182         t2 = F[0][i] * 0.119514472455649f + t0;
183         t3 = F[1][i] * 0.430252100840336f + F[3][i] * 0.168067226890756f;
184         t4 = F[1][i] * 0.179271708683473f + F[3][i] * 0.403361344537815f;
185
186         T[0][i] = F[0][i] * 1.13777777777778f;
187         T[1][i] = t1 - t3;
188         T[2][i] = t1 + t3;
189         T[3][i] = t2 + t4;
190         T[4][i] = t2 - t4;
191         T[5][i] = F[3][i];
192     }
193
194     for (int i = 0; i < 6; i++) {
195         t0 = T[i][2] * 0.26890756302521f;
196         t1 = T[i][0] * -0.688403361344538f - t0;
197         t2 = T[i][0] * 0.119514472455649f + t0;
198         t3 = T[i][1] * 0.430252100840336f + T[i][3] * 0.168067226890756f;
199         t4 = T[i][1] * 0.179271708683473f + T[i][3] * 0.403361344537815f;
200
201         Fw[i][0] = T[i][0] * 1.13777777777778f;
202         Fw[i][1] = t1 - t3;
203         Fw[i][2] = t1 + t3;
204         Fw[i][3] = t2 + t4;
205         Fw[i][4] = t2 - t4;
206         Fw[i][5] = T[i][3];
207     }
208 }
209
210 void trans_O_3x3_4x4_wu(float Mw[6][6], float M[3][3]) {
211     float T[3][6];
212     float t0;
213     float t1;
214     float t2;
215     float M_[3];
216
217     for (int i = 0; i < 6; i++) {
218         t0 = Mw[1][i] + Mw[2][i];
219         t1 = Mw[3][i] + Mw[4][i];
220         t2 = t1 * 2.25f + Mw[5][i];
221
222         T[0][i] = Mw[0][i] + t0 + t1;
223         T[1][i] = 0.625f * (Mw[1][i] - Mw[2][i]) +
224             1.5f * (Mw[3][i] - Mw[4][i]);
225         T[2][i] = t0 * 0.390625f + t2;
226     }
227     for (int i = 0; i < 3; i++) {
228         t0 = T[i][1] + T[i][2];
229         t1 = T[i][3] + T[i][4];
230         t2 = t1 * 2.25f + T[i][5];
231
232         M_[0] = T[i][0] + t0 + t1;
233         M_[1] = 0.625f * (T[i][1] - T[i][2]) +
234             1.5f * (T[i][3] - T[i][4]);
235         M_[2] = t0 * 0.390625f + t2;
236
237         for (int k = 0; k < 3; k++) {
238             M[i][k] = M_[k];
239         }
240     }
241 }
242
243 struct scratchpad_t {
244     float *_u_ptr;
245     float *_m_ptr;
246     float *_v_ptr;
247
248     int h_tiles;
249     int w_tiles;
250
251     const int alpha = 6;
252     const int out_dim = 4;
253 };
254
255 int init_scratchpad(const  prb_t *p, scratchpad_t &sp) {
256     if (sp.out_dim != 4 || sp.alpha != 6)
257         return FAIL;
258
259     sp.h_tiles = p->dir == FLAG_FWD ? div_up(p->oh, sp.out_dim) :
260                                              div_up(p->ih, sp.out_dim);
261     sp.w_tiles = p->dir == FLAG_FWD ? div_up(p->ow, sp.out_dim) :
262                                              div_up(p->iw, sp.out_dim);
263
264     sp._u_ptr = (float *)zmalloc(
265             sizeof(float) * sp.alpha * sp.alpha * p->oc * p->ic, 64);
266     sp._v_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * p->ic
267                     * p->mb * sp.h_tiles * sp.w_tiles, 64);
268     sp._m_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * p->oc
269                     * p->mb * sp.h_tiles * sp.w_tiles, 64);
270
271     if (sp._u_ptr == NULL || sp._v_ptr == NULL || sp._m_ptr == NULL)
272         return mkldnn_out_of_memory;
273
274     array_set((char *)sp._u_ptr,
275             sizeof(float) * sp.alpha * sp.alpha * p->oc * p->ic);
276     array_set((char *)sp._v_ptr, sizeof(float) * sp.alpha * sp.alpha * p->ic
277                     * p->mb * sp.h_tiles * sp.w_tiles);
278     array_set((char *)sp._m_ptr, sizeof(float) * sp.alpha * sp.alpha * p->oc
279                     * p->mb * sp.h_tiles * sp.w_tiles);
280
281     return OK;
282 }
283
284 void free_scratchpad(scratchpad_t *sp) {
285     if(sp->_u_ptr != NULL) zfree(sp->_u_ptr);
286     if(sp->_v_ptr != NULL) zfree(sp->_v_ptr);
287     if(sp->_m_ptr != NULL) zfree(sp->_m_ptr);
288 }
289
290 void compute_wino_ref_fwd(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &wei_m,
291         dnn_mem_t &bia_m, dnn_mem_t &dst_m) {
292     scratchpad_t sp{};
293     SAFE_V(init_scratchpad(p, sp));
294
295     array_offset_calculator<float, 4> U(
296             sp._u_ptr, sp.alpha, sp.alpha, p->oc, p->ic);
297     array_offset_calculator<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha, p->ic,
298             p->mb, sp.h_tiles, sp.w_tiles);
299     array_offset_calculator<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha, p->oc,
300             p->mb, sp.h_tiles, sp.w_tiles);
301
302     SAFE_V(p->kh == 3 ? OK : FAIL);
303     SAFE_V(p->kw == 3 ? OK : FAIL);
304
305     bool with_bias = p->dir & FLAG_BIA;
306     const int t_pad = p->ph;
307     const int l_pad = p->pw;
308     const int wp_max = p->iw + l_pad;
309     const int hp_max = p->ih + t_pad;
310     const int p_dim = p->mb * sp.h_tiles * sp.w_tiles;
311
312 #pragma omp parallel
313     {
314     float I[6][6];
315     float F[3][3];
316     float O[4][4];
317
318     float _v[6][6];
319     float _u[6][6];
320     float _m[6][6];
321
322 #pragma omp for collapse(4)
323     /* src_transform v <- B_t * d * B */
324     for (int img = 0; img < p->mb; img++) {
325         for (int c = 0; c < p->ic; c++) {
326             for (int hfm = 0; hfm < sp.h_tiles; hfm++) {
327                 for (int wfm = 0; wfm < sp.w_tiles; wfm++) {
328                     for (int j = 0; j < sp.alpha; j++) {
329                         int ydim = hfm * sp.out_dim + j;
330                         if ((t_pad <= ydim) && (ydim < hp_max)) {
331                             for (int k = 0; k < sp.alpha; k++) {
332                                 int xdim = wfm * sp.out_dim + k;
333                                 if ((l_pad <= xdim) && (xdim < wp_max)) {
334                                     size_t src_off = src_off_f(p, img, 0, c, 0,
335                                             ydim - t_pad, xdim - l_pad);
336                                     I[j][k] = ((float *)src_m)[src_off];
337                                 } else {
338                                     I[j][k] = 0.f;
339                                 }
340                             }
341                         } else {
342                             for (int k = 0; k < sp.alpha; k++) {
343                                 I[j][k] = 0.f;
344                             }
345                         }
346                     }
347
348                     trans_I_4x4_3x3(_v, I);
349
350                     /* scatter v:V */
351                     for (int j = 0; j < sp.alpha; j++) {
352                         for (int k = 0; k < sp.alpha; k++) {
353                             V(j, k, c, img, hfm, wfm) = _v[j][k];
354                         }
355                     }
356                 }
357             }
358         }
359     }
360
361 #pragma omp for collapse(2)
362     /* wei_transform u <- G * g * G_t */
363     for (int oc = 0; oc < p->oc; ++oc) {
364         for (int ic = 0; ic < p->ic; ++ic) {
365             for (int j = 0; j < p->kh; j++) {
366                 for (int i = 0; i < p->kw; i++) {
367                     size_t wei_off = wei_off_f(p, 0, oc, ic, 0, j, i);
368                     F[j][i] = ((float *)wei_m)[wei_off];
369                 }
370             }
371
372             trans_W_4x4_3x3(_u, F);
373
374             /* scatter u:U */
375             for (int j = 0; j < sp.alpha; j++) {
376                 for (int k = 0; k < sp.alpha; k++) {
377                     U(j, k, oc, ic) = _u[j][k];
378                 }
379             }
380         }
381     }
382
383 #pragma omp for collapse(2)
384     /* M = U * V */
385     for (int j = 0; j < sp.alpha; ++j) {
386         for (int k = 0; k < sp.alpha; ++k) {
387             gemm("C", "N", "N", p->oc, p_dim, p->ic, 1.0,
388                     (float *)&(U(j, k, 0, 0)), p->ic,
389                     (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0,
390                     (float *)&(M(j, k, 0, 0, 0, 0)), p_dim);
391         }
392     }
393
394 #pragma omp for collapse(4)
395     /* Y = A_t *m * A */
396     for (int oc = 0; oc < p->oc; ++oc) {
397         for (int img = 0; img < p->mb; ++img) {
398             for (int hfm = 0; hfm < sp.h_tiles; ++hfm) {
399                 for (int wfm = 0; wfm < sp.w_tiles; ++wfm) {
400                     for (int j = 0; j < sp.alpha; j++) {
401                         for (int k = 0; k < sp.alpha; k++) {
402                             _m[j][k] = M(j, k, oc, img, hfm, wfm);
403                         }
404                     }
405                     trans_O_4x4_3x3(_m, O);
406
407                     for (int j = 0; j < sp.out_dim; j++) {
408                         int ydim = hfm * sp.out_dim + j;
409                         if (ydim < p->oh) {
410                             for (int k = 0; k < sp.out_dim; k++) {
411
412                                 float conv_res = O[j][k];
413
414                                 int xdim = wfm * sp.out_dim + k;
415                                 if (xdim < p->ow) {
416                                     const size_t dst_off = dst_off_f(
417                                             p, img, 0, oc, 0, ydim, xdim);
418                                     float &dst = ((float *)dst_m)[dst_off];
419
420                                     const size_t bia_off = bia_off_f(p, 0, oc);
421                                     conv_res += with_bias ?
422                                             ((float *)bia_m)[bia_off] :
423                                             0.f;
424
425                                     const auto &ops = p->attr.post_ops;
426                                     for (int idx = 0; idx < ops.len; ++idx) {
427                                         using pk = attr_t::post_ops_t::kind_t;
428                                         const auto &e = ops.entry[idx];
429                                         switch (e.kind) {
430                                         case pk::SUM:
431                                             conv_res += e.sum.scale * dst;
432                                             break;
433                                         case pk::RELU:
434                                             conv_res = e.eltwise.scale
435                                                     * (conv_res < 0 ? 0 :
436                                                                       conv_res);
437                                             break;
438                                         default:
439                                             assert(!"unknown "
440                                                     "attr::post_ops::kind");
441                                         }
442                                     }
443
444                                     dst = conv_res;
445                                 }
446                             }
447                         }
448                     }
449                 }
450             }
451         }
452     }
453     }
454
455     free_scratchpad(&sp);
456 }
457
458 void compute_wino_ref_bwd_d(const prb_t *p, dnn_mem_t &diff_src_m,
459         dnn_mem_t &wei_m, dnn_mem_t &bia_m, dnn_mem_t &diff_dst_m) {
460     scratchpad_t sp{};
461     SAFE_V(init_scratchpad(p, sp));
462
463     array_offset_calculator<float, 4> U(
464             sp._u_ptr, sp.alpha, sp.alpha, p->ic, p->oc);
465     array_offset_calculator<float, 6> V(sp._m_ptr, sp.alpha, sp.alpha, p->oc,
466             p->mb, sp.h_tiles, sp.w_tiles);
467     array_offset_calculator<float, 6> M(sp._v_ptr, sp.alpha, sp.alpha, p->ic,
468             p->mb, sp.h_tiles, sp.w_tiles);
469
470     SAFE_V(p->kh == 3 ? OK : FAIL);
471     SAFE_V(p->kw == 3 ? OK : FAIL);
472
473     const int r_pad = MAX2(0, p->ow - 1 + p->kw - p->iw - p->pw);
474     const int l_pad = p->iw + r_pad - p->ow;
475     const int t_pad = p->ih + p->ph - p->oh;
476     const int wp_max = p->ow + l_pad;
477     const int hp_max = p->oh + t_pad;
478     const int p_dim = p->mb * sp.h_tiles * sp.w_tiles;
479
480     bool with_bias = p->dir & FLAG_BIA;
481
482 #pragma omp parallel
483     {
484     float I[6][6];
485     float F[3][3];
486     float O[4][4];
487
488     float _v[6][6];
489     float _u[6][6];
490     float _m[6][6];
491
492 #pragma omp for collapse(4)
493     /* diff_src transform v <- B_t * d * B */
494     for (int img = 0; img < p->mb; img++) {
495         for (int c = 0; c < p->oc; c++) {
496             for (int hfm = 0; hfm < sp.h_tiles; hfm++) {
497                 for (int wfm = 0; wfm < sp.w_tiles; wfm++) {
498
499                     for (int j = 0; j < sp.alpha; j++) {
500                         int ydim = hfm * sp.out_dim + j;
501                         if ((t_pad <= ydim) && (ydim < hp_max)) {
502                             for (int k = 0; k < sp.alpha; k++) {
503                                 int xdim = wfm * sp.out_dim + k;
504                                 if ((l_pad <= xdim) && (xdim < wp_max)) {
505                                     size_t dst_off = dst_off_f(p, img, 0, c, 0,
506                                             ydim - t_pad, xdim - l_pad);
507                                     I[j][k] = ((float *)diff_dst_m)[dst_off];
508                                 } else {
509                                     I[j][k] = 0.f;
510                                 }
511                             }
512                         } else {
513                             for (int k = 0; k < sp.alpha; k++) {
514                                 I[j][k] = 0.f;
515                             }
516                         }
517                     }
518
519                     trans_I_4x4_3x3(_v, I);
520
521                     /* scatter v:V */
522                     for (int j = 0; j < sp.alpha; j++) {
523                         for (int k = 0; k < sp.alpha; k++) {
524                             V(j, k, c, img, hfm, wfm) = _v[j][k];
525                         }
526                     }
527                 }
528             }
529         }
530     }
531
532 #pragma omp for collapse(2)
533     /* wei_transform u <- G * g * G_t */
534     for (int ic = 0; ic < p->ic; ++ic) {
535         for (int oc = 0; oc < p->oc; ++oc) {
536             for (int j = 0; j < p->kh; j++) {
537                 for (int i = 0; i < p->kw; i++) {
538                     size_t wei_off = wei_off_f(
539                             p, 0, oc, ic, 0, p->kh - j - 1, p->kw - i - 1);
540                     F[j][i] = ((float *)wei_m)[wei_off];
541                 }
542             }
543             trans_W_4x4_3x3(_u, F);
544
545             /* scatter u:U */
546             for (int j = 0; j < sp.alpha; j++) {
547                 for (int k = 0; k < sp.alpha; k++) {
548                     U(j, k, ic, oc) = _u[j][k];
549                 }
550             }
551         }
552     }
553
554 #pragma omp for collapse(2)
555     /* M = U * V */
556     for (int j = 0; j < sp.alpha; ++j) {
557         for (int k = 0; k < sp.alpha; ++k) {
558             gemm("C", "N", "N", p->ic, p_dim, p->oc, 1.0,
559                     (float *)&(U(j, k, 0, 0)), p->oc,
560                     (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0,
561                     (float *)&(M(j, k, 0, 0, 0, 0)), p_dim);
562         }
563     }
564
565 #pragma omp for collapse(4)
566     /* diff_dst: Y = A_t *m * A */
567     for (int c = 0; c < p->ic; ++c) {
568         for (int img = 0; img < p->mb; ++img) {
569             for (int hfm = 0; hfm < sp.h_tiles; ++hfm) {
570                 for (int wfm = 0; wfm < sp.w_tiles; ++wfm) {
571                     for (int j = 0; j < sp.alpha; j++) {
572                         for (int k = 0; k < sp.alpha; k++) {
573                             _m[j][k] = M(j, k, c, img, hfm, wfm);
574                         }
575                     }
576                     trans_O_4x4_3x3(_m, O);
577
578                     float bia = with_bias ? ((float *)bia_m)[c] : 0.f;
579
580                     for (int j = 0; j < sp.out_dim; j++) {
581                         int ydim = hfm * sp.out_dim + j;
582                         if (ydim < p->ih) {
583                             for (int k = 0; k < sp.out_dim; k++) {
584                                 int xdim = wfm * sp.out_dim + k;
585                                 if (xdim < p->iw) {
586                                     size_t src_off = src_off_f(
587                                             p, img, 0, c, 0, ydim, xdim);
588                                     ((float *)diff_src_m)[src_off] = O[j][k]
589                                             + bia;
590                                 }
591                             }
592                         }
593                     }
594                 }
595             }
596         }
597     }
598     }
599
600     free_scratchpad(&sp);
601 }
602
603 void compute_wino_ref_bwd_w(const prb_t *p, dnn_mem_t &src_m,
604         dnn_mem_t &diff_wei_m, dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m) {
605     scratchpad_t sp{};
606     SAFE_V(init_scratchpad(p, sp));
607
608     array_offset_calculator<float, 4> U(
609             sp._u_ptr, sp.alpha, sp.alpha, p->oc, p->ic);
610     array_offset_calculator<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha, p->mb,
611             sp.h_tiles, sp.w_tiles, p->ic);
612     array_offset_calculator<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha, p->oc,
613             p->mb, sp.h_tiles, sp.w_tiles);
614
615     SAFE_V(p->kh == 3 ? OK : FAIL);
616     SAFE_V(p->kw == 3 ? OK : FAIL);
617
618     const int t_pad = p->ph;
619     const int l_pad = p->pw;
620     const int wp_max = p->iw + l_pad;
621     const int hp_max = p->ih + t_pad;
622     const int p_dim = p->mb * sp.h_tiles * sp.w_tiles;
623
624 #pragma omp parallel
625     {
626     float I[6][6];
627     float F[6][6];
628     float O[6][6];
629
630     float _v[6][6];
631     float _u[3][3];
632     float _m[6][6];
633
634 #pragma omp for collapse(4)
635     /* src transform v <- B_t * d * B */
636     for (int img = 0; img < p->mb; img++) {
637         for (int hfm = 0; hfm < sp.h_tiles; hfm++) {
638             for (int wfm = 0; wfm < sp.w_tiles; wfm++) {
639                 for (int ic = 0; ic < p->ic; ic++) {
640                     for (int j = 0; j < sp.alpha; j++) {
641                         int ydim = hfm * sp.out_dim + j;
642                         if ((t_pad <= ydim) && (ydim < hp_max)) {
643                             for (int k = 0; k < sp.alpha; k++) {
644                                 int xdim = wfm * sp.out_dim + k;
645                                 if ((l_pad <= xdim) && (xdim < wp_max)) {
646                                     size_t src_off = src_off_f(p, img, 0, ic, 0,
647                                             ydim - t_pad, xdim - l_pad);
648                                     I[j][k] = ((float *)src_m)[src_off];
649                                 } else {
650                                     I[j][k] = 0.f;
651                                 }
652                             }
653                         } else {
654                             for (int k = 0; k < sp.alpha; k++) {
655                                 I[j][k] = 0.f;
656                             }
657                         }
658                     }
659
660                     trans_I_4x4_3x3(_v, I);
661
662                     /* scatter v:V */
663                     for (int j = 0; j < sp.alpha; j++) {
664                         for (int k = 0; k < sp.alpha; k++) {
665                             V(j, k, img, hfm, wfm, ic) = _v[j][k];
666                         }
667                     }
668                 }
669             }
670         }
671     }
672
673 #pragma omp for collapse(4)
674     /* diff_dst transform */
675     for (int oc = 0; oc < p->oc; oc++) {
676         for (int img = 0; img < p->mb; img++) {
677             for (int hfm = 0; hfm < sp.h_tiles; hfm++) {
678                 for (int wfm = 0; wfm < sp.w_tiles; wfm++) {
679                     for (int j = 0; j < sp.alpha; j++) {
680                         int ydim = hfm * sp.out_dim + j;
681                         if (ydim < p->oh) {
682                             for (int k = 0; k < sp.alpha; k++) {
683                                 int xdim = wfm * sp.out_dim + k;
684                                 if (xdim < p->ow) {
685                                     size_t dst_off = dst_off_f(
686                                             p, img, 0, oc, 0, ydim, xdim);
687                                     O[j][k] = ((float *)diff_dst_m)[dst_off];
688                                 } else {
689                                     O[j][k] = 0.f;
690                                 }
691                             }
692                         } else {
693                             for (int k = 0; k < sp.alpha; k++) {
694                                 O[j][k] = 0.f;
695                             }
696                         }
697                     }
698                     trans_W_3x3_4x4_wu(_m, O);
699
700                     /* scatter v:V */
701                     for (int j = 0; j < sp.alpha; j++) {
702                         for (int k = 0; k < sp.alpha; k++) {
703                             M(j, k, oc, img, hfm, wfm) = _m[j][k];
704                         }
705                     }
706                 }
707             }
708         }
709     }
710
711 #pragma omp for collapse(2)
712     /* GeMM U = M * V */
713     for (int j = 0; j < sp.alpha; ++j) {
714         for (int k = 0; k < sp.alpha; ++k) {
715             gemm("C", "N", "N", p->oc, p->ic, p_dim, 1.0,
716                     (float *)&(M(j, k, 0, 0, 0, 0)), p_dim,
717                     (float *)&(V(j, k, 0, 0, 0, 0)), p->ic, 1.0,
718                     (float *)&(U(j, k, 0, 0)), p->ic);
719         }
720     }
721
722 #pragma omp for collapse(2)
723     for (int oc = 0; oc < p->oc; ++oc) {
724         for (int ic = 0; ic < p->ic; ++ic) {
725             for (int j = 0; j < sp.alpha; j++) {
726                 for (int k = 0; k < sp.alpha; k++) {
727                     F[j][k] = U(j, k, oc, ic);
728                 }
729             }
730
731             trans_O_3x3_4x4_wu(F, _u);
732
733             /* scatter u:U */
734             for (int kh = 0; kh < p->kh; kh++) {
735                 for (int kw = 0; kw < p->kw; kw++) {
736                     size_t wei_off = wei_off_f(p, 0, oc, ic, 0, kh, kw);
737                     ((float *)diff_wei_m)[wei_off] = _u[kh][kw];
738                 }
739             }
740         }
741     }
742     }
743
744     free_scratchpad(&sp);
745
746     if (p->dir & FLAG_BIA)
747         compute_ref_bwd_bias(p, diff_bia_m, diff_dst_m);
748 }
749
750 }