// NOTE(review): the two lines below are repository-browser residue (commit
// subject and file path) that leaked into the source; kept as a comment so
// the translation unit remains compilable.
// updated readme file due to moving CMake scripts to the root folder
// [platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_lrn.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18
19 #include "c_types_map.hpp"
20 #include "jit_avx512_common_lrn.hpp"
21 #include "jit_generator.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "type_helpers.hpp"
24 #include "utils.hpp"
25
// Register-blocking factors: number of 16-channel pixel positions processed
// per generated-loop iteration by the forward / backward kernels.
#define FWD_RBC 4
#define BWD_RBC 3

// Per-position stack scratch layout (bytes):
//   [ 4-float tail of the previous 16-channel block | 16 channels (one zmm)
//     | 4-float head of the next 16-channel block ]
#define XMM_SIZE (4*sizeof(float))
#define ZMM_SIZE (vlen)
#define BUFFER_BLOCK (XMM_SIZE + ZMM_SIZE + XMM_SIZE)
#define BUFFER_NEXT_OFFSET (XMM_SIZE + ZMM_SIZE)
// Byte offset of the last 4 floats inside a neighboring 16-channel vector.
#define SRC_PREV_OFFSET (vlen - XMM_SIZE)

// Emits `statement` once per register-blocked position; a local `loop_size`
// must be in scope at every expansion site (see compute_loop).
#define IRB_LOOP(statement) for(int irb = 0; irb < loop_size; irb++) { \
    statement;\
}
38
39 namespace mkldnn {
40 namespace impl {
41 namespace cpu {
42
43 using namespace mkldnn::impl::status;
44 using namespace mkldnn::impl::memory_format;
45 using namespace mkldnn::impl::utils;
46
47 using namespace Xbyak;
48
// vsize: channels per nChw16c block (16 f32 lanes); vlen: zmm width in bytes.
enum params { vsize = 16, vlen = 64};
50
/* Argument bundle passed to the generated forward kernel:
 * src is read; dst and the two workspace planes (ws0, ws1) are written. */
struct jit_args_fwd_t {
    const float *src;
    float *dst;
    float *ws0;
    float *ws1;
};
55
/* Argument bundle passed to the generated backward kernel:
 * src, diff_dst and the forward workspace planes are read; diff_src is
 * written. */
struct jit_args_bwd_t {
    const float *src;
    const float *diff_dst;
    const float *ws0;
    const float *ws1;
    float *diff_src;
};
60
/* Describes one 16-channel block's position within the channel dimension.
 * version encodes which neighbors exist:
 *   -1 : first block (channels 0..15, no previous neighbor)
 *   +1 : last block (channels C-16..C-1, no next neighbor)
 *    0 : interior block (both neighbors present)
 *    3 : the only block (no neighbors at all)
 */
struct nChw16c_across {
    int H, W, version;
    nChw16c_across(int h, int w, int v) {
        H = h;
        W = w;
        version = v;
    }
};
71
/* JIT generator for the forward LRN kernel (across-channels, window 5,
 * beta fixed at 0.75) over one nChw16c 16-channel block.  The generated
 * function consumes a jit_args_fwd_t and processes FWD_RBC pixel positions
 * per unrolled loop iteration.  Cross-block border values (4 floats from the
 * previous/next 16-channel block) are staged through a small stack buffer
 * (see BUFFER_BLOCK) so the 5-wide channel window can be read with plain
 * unaligned zmm loads. */
struct jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_kernel_f32:
       public jit_generator {
    int HW, W;          // H*W and W in pixels; stride between channel blocks
    bool is_first;      // no previous 16-channel block
    bool is_last;       // no next 16-channel block
    bool is_single;     // C == 16: neither neighbor exists

    /* general-purpose register assignment of the generated code */
    Reg64 src = rax;
    Reg64 dst = r8;
    Reg64 scratch0 = rdx;       // ws0 plane (training only)
    Reg64 scratch1 = rsi;       // ws1 plane (training only)
    Reg64 imm_addr64 = rbx;     // staging register for float immediates

    /* broadcast constants kept live across the whole kernel (zmm0/zmm1) */
    Zmm zalpha = zmm0;
    Xmm xalpha = xmm0;
    Zmm zk = zmm1;
    Xmm xk = xmm1;

    Reg64 param = abi_param1;   // pointer to jit_args_fwd_t
    Reg64 t = rsp;              // stack scratch base (FWD_RBC * BUFFER_BLOCK)
    Reg64 hw = r9;              // remaining-pixels loop counter

    /* per-irb vector register indices; xreg uses stride 3, zreg stride 7.
     * Note the deliberate aliases: zc == zsrc (asserted below) and the
     * za/zb/zdst/zbase/zsum2 slots are reused once their first value dies. */
    int xsrc_prev = 2;
    int zsrc = 7;
    int xsrc_next = 3;
    int zc = 7;

    int za = 2;
    int zb = 3;
    int zd = 5;
    int ze = 6;
    int zsum = 4;
    int zdst = 2;
    int zbase = 3;
    int zsum2 = 5;

    prop_kind_t pk;
    int use_h_parallelism;      // 1 -> kernel covers one row (W), else H*W

    float alpha, k;             // LRN alpha (already divided by window), k

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32)

    // entry point of the generated code
    void (*ker)(jit_args_fwd_t *);
    void operator()(jit_args_fwd_t *arg) { ker(arg); }

    /* software-prefetch distances (in pixel positions) for L1 (t0) and L2 (t2) */
    enum {
        prf0_offt = 1*FWD_RBC,
        prf2_offt = 8*FWD_RBC
    };

    /* Emits one unrolled step over `loop_size_param` pixel positions:
     * prefetch, gather the 5-channel neighborhood through the stack buffer,
     * compute base = k + alpha*sum(sq), dst = src / base^0.75, and (training
     * only) store ws0 = base^0.75 and ws1 = dst / base for the backward pass.
     * With loop_size_param == 0 only the prefetches are emitted. */
    inline void compute_loop(int loop_size_param)
    {
        // loop_size - param for IRB_LOOP macro
        int loop_size = FWD_RBC;

        auto xreg = [=](int irb, int i) {
            return Xmm(irb*3 + i);
        };

        auto zreg = [=](int irb, int i) {
            return Zmm(irb*7 + i);
        };

        // prefetch neighbors' data one channel block away (+-HW pixels)
        if (!is_first && !is_single) {
            IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt - HW)*vlen]));
            IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt - HW)*vlen]));
        }
        IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(src, (irb + prf0_offt)*vlen)));
        IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(src, (irb + prf2_offt)*vlen)));
        if (!is_last && !is_single) {
            IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt + HW)*vlen]));
            IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt + HW)*vlen]));
        }
        if (pk != prop_kind::forward_inference) {
            IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch0,
                       (irb + prf0_offt)*vlen)));
            IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch0,
                       (irb + prf2_offt)*vlen)));
        }
        IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(dst, (irb + prf0_offt)*vlen)));
        IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(dst, (irb + prf2_offt)*vlen)));
        if (pk != prop_kind::forward_inference) {
            IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch1,
                         (irb + prf0_offt) * vlen)));
            IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch1,
                         (irb + prf2_offt) * vlen)));
        }

        loop_size = loop_size_param;
        if (loop_size == 0)
            return;
        // load the 4-float tail of the previous block and head of the next
        if (!is_first && !is_single) {
            IRB_LOOP(vmovups(xreg(irb, xsrc_prev),
                        ptr[src + (irb - HW) * vlen + SRC_PREV_OFFSET]));
        }
        IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src,irb*vlen)));
        if (!is_last && !is_single) {
            IRB_LOOP(vmovups(xreg(irb, xsrc_next),
                        ptr[src + (irb + HW) * vlen]));
        }

        // stage [prev tail | 16 channels | next head] contiguously on stack
        if (!is_first && !is_single) {
            IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK],
                        xreg(irb, xsrc_prev)));
        }
        IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE),
                    zreg(irb, zsrc)));
        if (!is_last && !is_single) {
            IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
                    xreg(irb, xsrc_next)));
        }

        // reload the staged buffer at +-1/+-2 float offsets: the 5-channel window
        IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK
                        + XMM_SIZE - 2*sizeof(float))));
        IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK
                        + XMM_SIZE - sizeof(float))));
        IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK
                        + XMM_SIZE + sizeof(float))));
        IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK
                        + XMM_SIZE + 2*sizeof(float))));

        // zsum = sum of squares over the 5-wide window (zc aliases zsrc)
        assert(zc == zsrc);
        IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zc), zreg(irb, zc)));

        IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, za), zreg(irb, za)));
        IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zb), zreg(irb, zb)));
        IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zd), zreg(irb, zd)));
        IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, ze), zreg(irb, ze)));

        // zsum = k + alpha * sum_of_squares  (the LRN base)
        IRB_LOOP(vfmadd132ps(zreg(irb, zsum), zk, zalpha));

        IRB_LOOP(vmovaps(zreg(irb, zbase), zreg(irb, zsum)));

        // base^0.75 via sqrt(sqrt(base^3)); beta is hard-wired to 0.75
        IRB_LOOP(vmulps(zreg(irb, zsum2), zreg(irb, zsum), zreg(irb, zsum)));
        IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zsum), zreg(irb, zsum2)));

        IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum)));
        IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum)));

        // training: ws0 = base^0.75, needed by backward
        if (pk != prop_kind::forward_inference) {
            IRB_LOOP(vmovups(EVEX_compress_addr(scratch0, irb*vlen),
                        zreg(irb, zsum)));
        }
        IRB_LOOP(vdivps(zreg(irb, zdst), zreg(irb, zsrc), zreg(irb, zsum)));
        IRB_LOOP(vmovups(EVEX_compress_addr(dst, irb*vlen), zreg(irb, zdst)));
        if (pk != prop_kind::forward_inference) {
            /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */
            IRB_LOOP(vdivps(zreg(irb, zsum), zreg(irb, zdst), zreg(irb, zbase)));
            IRB_LOOP(vmovups(EVEX_compress_addr(scratch1, irb*vlen),
                        zreg(irb, zsum)));
        }
    }

    /* Generates the complete kernel for one block position J: a main loop
     * over FWD_RBC-sized groups of pixels plus one tail compute_loop for the
     * remainder.  A is alpha (pre-divided by local_size), K the LRN k. */
    jit_avx512_common_lrn_kernel_f32(
        const struct nChw16c_across &J,
        prop_kind_t prop_kind,
        int use_h_parallel,
        float A,
        float K,
        void *code_ptr = nullptr,
        size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE)
        : jit_generator(code_ptr, code_size)
        , pk(prop_kind)
        , use_h_parallelism(use_h_parallel)
        , alpha(A)
        , k(K)
    {
        this->preamble();

        mov(src, ptr[param + 0]);
        mov(dst, ptr[param + 8]);
        if (pk != prop_kind::forward_inference)
        {
            mov(scratch0, ptr[param + 16]);
            mov(scratch1, ptr[param + 24]);
        }
        // NOTE(review): -2 is treated as "last" here but as "+2" in the
        // backward kernel; callers only ever pass -1/0/+1/3, so the -2/+2
        // branch is dead either way — confirm before relying on it.
        is_first = J.version == -1 || J.version == -2;
        is_last  = J.version == +1 || J.version == -2;
        is_single = J.version == 3;

        W = J.W;
        HW = J.W*J.H;
        int LSB = use_h_parallelism ? W : HW;

        sub(t, FWD_RBC*BUFFER_BLOCK);
        mov(imm_addr64, float2int(this->alpha));
        movq(xalpha, imm_addr64);
        vbroadcastss(zalpha, xalpha);

        mov(imm_addr64, float2int(this->k));
        movq(xk, imm_addr64);
        vbroadcastss(zk, xk);

        // zero the border slots that have no real neighbor to read from
        if (is_first || is_single) {
            vxorps(xmm2, xmm2, xmm2);
            for(int irb = 0; irb < FWD_RBC; irb++) {
                vmovups(ptr[t + irb*BUFFER_BLOCK], xmm2);
            }
        }
        if (is_last || is_single) {
            vxorps(xmm2, xmm2, xmm2);
            for(int irb = 0; irb < FWD_RBC; irb++) {
                vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
                    xmm2);
            }
        }

        // split pixel count into a FWD_RBC-multiple main part and a tail
        int LSREST = LSB % FWD_RBC;
        int LS = LSB - LSREST;

        Label lrn_loop;

        if (LS > 0) {
            mov(hw, LS);

            L(lrn_loop);
            {
                compute_loop(FWD_RBC);

                add(src, FWD_RBC*vlen);
                add(dst, FWD_RBC*vlen);
                if (pk != prop_kind::forward_inference)
                {
                    add(scratch0, FWD_RBC*vlen);
                    add(scratch1, FWD_RBC*vlen);
                }

                for(int irb = 0; irb < FWD_RBC; irb++)
                    dec(hw);
                cmp(hw, 0);
                jne(lrn_loop, T_NEAR);
            }
        }

        compute_loop(LSREST);

        add(t, FWD_RBC*BUFFER_BLOCK);
        this->postamble();

        ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
                    this->getCode()));
    }
};
316
317 status_t jit_avx512_common_lrn_fwd_t::pd_t::init() {
318     using namespace prop_kind;
319     using namespace alg_kind;
320
321     assert(engine()->kind() == engine_kind::cpu);
322
323     if (!mayiuse(avx512_common)) return unimplemented;
324
325     const memory_desc_wrapper data_d(data_pd_.desc());
326     bool ok = true
327         && one_of(desc()->prop_kind, forward_training, forward_inference)
328         && !has_zero_dim_memory()
329         && everyone_is(data_type::f32, desc()->data_desc.data_type)
330         && data_d.ndims() == 4
331         && data_d.dims()[1] % vsize == 0
332         && attr()->has_default_values();
333     if (!ok) return unimplemented;
334
335     if (desc()->prop_kind == forward_training) {
336         memory_desc_t ws_d;
337         dims_t ws_dims = { MB(), C(), H(), 2*W() };
338         mkldnn_memory_desc_init(&ws_d, 4, ws_dims, data_type::f32,
339             memory_format::nChw16c);
340         ws_pd_ = cpu_memory_t::pd_t(engine_, &ws_d);
341     }
342
343     bool args_ok_across = true
344         && desc()->alg_kind == lrn_across_channels
345         && desc()->local_size == 5
346         && desc()->lrn_beta == 0.75
347         && data_d.format() == nChw16c;
348
349     return args_ok_across ? success : unimplemented;
350 }
351
352 jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_fwd_t(const pd_t *apd,
353         const input_vector &inputs, const output_vector &outputs)
354     : cpu_primitive_t(apd, inputs, outputs)
355     , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr)
356     , ker_last_(nullptr) {
357     using namespace alg_kind;
358     const int C = pd()->C();
359     const int H = pd()->H();
360     const int W = pd()->W();
361     const int ls = pd()->desc()->local_size;
362     const float alpha = pd()->desc()->lrn_alpha / ls;
363     const float k = pd()->desc()->lrn_k;
364
365     auto pk = pd()->desc()->prop_kind;
366
367     use_h_parallelism = H > 28 ? 1 : 0;
368
369     if (C / vsize == 1) {
370         ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), pk,
371             use_h_parallelism, alpha, k);
372     } else {
373         ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), pk,
374             use_h_parallelism, alpha, k);
375         ker_first_ = new jit_avx512_common_lrn_kernel_f32(
376             nChw16c_across(H, W, -1), pk, use_h_parallelism, alpha, k);
377         ker_last_ = new jit_avx512_common_lrn_kernel_f32(
378             nChw16c_across(H, W, +1), pk, use_h_parallelism, alpha, k);
379     }
380 }
381
382 jit_avx512_common_lrn_fwd_t::~jit_avx512_common_lrn_fwd_t()
383 { delete ker_; delete ker_first_; delete ker_last_; }
384
/* Dispatches the generated forward kernels over (mb, channel-block[, row])
 * work items.  The workspace pointer is only dereferenced by the kernels
 * when prop_kind is forward_training (they skip ws stores otherwise). */
void jit_avx512_common_lrn_fwd_t::execute_forward() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto dst = reinterpret_cast<data_t*>(this->memory(0));
    auto ws = reinterpret_cast<data_t*>(this->memory(1));

    const int N = pd()->MB();
    const int C = pd()->C();
    const int H = pd()->H();
    const int W = pd()->W();
    const int C16 = C / vsize;  // number of 16-channel blocks
    // with row parallelism each (n, c16, h) row is a separate work item
    const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16;

    parallel(0, work_amount, [&](const int ithr, const int nthr) {
        size_t start{0}, end{0};

        balance211(work_amount, nthr, ithr, start, end);
        if (use_h_parallelism) {
            int n{0}, c16{0}, h{0};
            nd_iterator_init(start, n, N, c16, C16, h, H);
            for (size_t iwork = start; iwork < end; ++iwork) {
                // element offset of row h inside channel block c16 (nChw16c)
                auto offset = n*C*H*W + c16*H*W*vsize
                    + h*W*vsize;
                // workspace stores 2*W*vsize floats per row: ws0 then ws1
                auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize
                    + h*2*W*vsize;
                auto ws_offset1 = ws_offset0 + W*vsize;

                jit_args_fwd_t args;
                args.src = &src[offset];
                args.dst = &dst[offset];
                args.ws0 = &ws[ws_offset0];
                args.ws1 = &ws[ws_offset1];

                // pick the kernel variant matching this block's position
                if (C16 == 1)
                    (*ker_)(&args);
                else if (c16 == 0)
                    (*ker_first_)(&args);
                else if (c16 == C16 - 1)
                    (*ker_last_)(&args);
                else
                    (*ker_)(&args);
                nd_iterator_step(n, N, c16, C16, h, H);
            }
        } else {
            int n{0}, c16{0};
            nd_iterator_init(start, n, N, c16, C16);
            for (size_t iwork = start; iwork < end; ++iwork) {
                // whole-plane variant: each item covers all H*W pixels
                auto offset = n*C*H*W + c16*H*W*vsize;
                auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize;
                auto ws_offset1 = ws_offset0 + H*W*vsize;

                jit_args_fwd_t args;
                args.src = &src[offset];
                args.dst = &dst[offset];
                args.ws0 = &ws[ws_offset0];
                args.ws1 = &ws[ws_offset1];

                if (C16 == 1)
                    (*ker_)(&args);
                else if (c16 == 0)
                    (*ker_first_)(&args);
                else if (c16 == C16 - 1)
                    (*ker_last_)(&args);
                else
                    (*ker_)(&args);

                nd_iterator_step(n, N, c16, C16);
            }
        }
    });
}
455
/* JIT generator for the backward LRN kernel over one nChw16c block.  Uses the
 * forward-pass workspace: ws0 = base^0.75 and ws1 = dst/base (see the forward
 * kernel), combining them with diff_dst and src into diff_src.  Processes
 * BWD_RBC pixel positions per unrolled iteration; the 5-channel window of the
 * (diff_dst * ws1) product is staged through the same stack-buffer scheme as
 * the forward kernel. */
struct jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_kernel_f32:
    public jit_generator {
    int HW, W;          // spatial sizes; HW is the stride between blocks
    bool is_first;      // no previous 16-channel block
    bool is_last;       // no next 16-channel block
    bool is_single;     // C == 16

    /* general-purpose register assignment */
    Reg64 src = rax;
    Reg64 diffsrc = r8;
    Reg64 diffdst = r9;
    Reg64 workspace0 = rdx;     // ws0 = base^0.75
    Reg64 workspace1 = rsi;     // ws1 = dst / base
    Reg64 imm_addr64 = rbx;

    // broadcast constant -2*alpha*beta, kept in zmm0
    Zmm znalphabeta = zmm0;
    Xmm xnalphabeta = xmm0;

    Reg64 param = abi_param1;   // pointer to jit_args_bwd_t
    Reg64 t = rsp;              // stack scratch base
    Reg64 hw = r10;             // remaining-pixels loop counter

    /* per-irb vector register indices; both xreg and zreg use stride 6.
     * Several slots alias on purpose (zsrc == za, asserted below). */
    int xws1_prev = 1;
    int xdiffdst_prev = 2;
    int zws1 = 1;

    int zsrc = 1;
    int zdiffdst = 5;
    int zdiffsrc = 6;

    int xws1_next = 1;
    int xdiffdst_next = 3;

    int za = 1;
    int zb = 2;
    int zd = 3;
    int ze = 4;
    int zws0 = 2;

    float nalphabeta;           // -2 * alpha * beta

    int use_h_parallelism;      // 1 -> kernel covers one row (W), else H*W

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32)

    // entry point of the generated code
    void (*ker)(jit_args_bwd_t *);
    void operator()(jit_args_bwd_t *arg) { ker(arg); }

    /* software-prefetch distances (pixel positions) for L1/L2 */
    enum {
        prf0_offt = 1*BWD_RBC,
        prf2_offt = 8*BWD_RBC
    };

    /* Emits one unrolled step over `loop_size_param` positions.  prefetchL1 /
     * prefetchL2 gate the t0 / t2 prefetch streams (the tail call disables L2
     * prefetch under row parallelism).  With loop_size_param == 0 only the
     * prefetches are emitted. */
    inline void compute_loop(int loop_size_param, int prefetchL1,
            int prefetchL2)
    {
        // loop_size - param for IRB_LOOP macro
        int loop_size = loop_size_param;

        auto xreg = [=](int irb, int i) {
            return Xmm(irb*6 + i);
        };

        auto zreg = [=](int irb, int i) {
            return Zmm(irb*6 + i);
        };

// ---- prefetching -------------------------------------------
        // ws1 is laid out with a 2*HW stride between channel blocks
        if (!is_first && !is_single) {
            if (prefetchL1)
                IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt
                        - 2 * HW) * vlen]));
            if (prefetchL1)
                IRB_LOOP(mic_prefetcht0(ptr[diffdst    + (irb + prf0_offt
                        - HW) * vlen]));
        }

        if (prefetchL1)
            IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt)*vlen]));
        if (prefetchL2)
            IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt)*vlen]));

        if (prefetchL1)
            IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt)*vlen]));

        if (prefetchL1)
            IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt)*vlen]));

        if (!is_last && !is_single) {
            if (prefetchL1)
                IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt
                        + 2 * HW) * vlen]));
            if (prefetchL2)
                IRB_LOOP(mic_prefetcht2(ptr[workspace1 + (irb + prf2_offt
                        + 2 * HW) * vlen]));

            if (prefetchL1)
                IRB_LOOP(mic_prefetcht0(ptr[diffdst +  (irb + prf0_offt
                          + HW) * vlen]));
            if (prefetchL2)
                IRB_LOOP(mic_prefetcht2(ptr[diffdst +  (irb + prf2_offt
                        + HW) * vlen]));
        }
        if (prefetchL1)
            IRB_LOOP(mic_prefetcht0(ptr[workspace0 + (irb + prf0_offt)*vlen]));
        if (prefetchL2)
            IRB_LOOP(mic_prefetcht2(ptr[workspace0 + (irb + prf2_offt)*vlen]));
// -----------------------------------------------------------

        if (loop_size_param == 0)
            return;

        // border lanes: (diff_dst * ws1) from the previous channel block
        if (!is_first && !is_single) {
            IRB_LOOP(vmovups(xreg(irb, xws1_prev), ptr[workspace1 + (irb
                    - 2 * HW) * vlen + SRC_PREV_OFFSET]));
            IRB_LOOP(vmovups(xreg(irb, xdiffdst_prev), ptr[diffdst + (irb
                    - HW) * vlen + SRC_PREV_OFFSET]));
            IRB_LOOP(vmulps(xreg(irb, xdiffdst_prev), xreg(irb, xdiffdst_prev),
                    xreg(irb, xws1_prev)));
        }

        // center: zdiffsrc = diff_dst * ws1 for the current 16 channels
        IRB_LOOP(vmovups(zreg(irb, zws1),
                EVEX_compress_addr(workspace1, irb*vlen)));
        IRB_LOOP(vmovups(zreg(irb, zdiffdst),
                EVEX_compress_addr(diffdst, irb*vlen)));
        IRB_LOOP(vmulps(zreg(irb, zdiffsrc), zreg(irb, zdiffdst),
                zreg(irb, zws1)));

        // border lanes from the next channel block
        if (!is_last && !is_single) {
            IRB_LOOP(vmovups(xreg(irb, xws1_next), ptr[workspace1 + (irb
                    + 2 * HW) * vlen]));
            IRB_LOOP(vmovups(xreg(irb, xdiffdst_next), ptr[diffdst +  (irb
                    + HW) * vlen]));
            IRB_LOOP(vmulps(xreg(irb, xdiffdst_next), xreg(irb, xdiffdst_next),
                    xreg(irb, xws1_next)));
        }

        // stage [prev tail | center | next head] of the product on the stack
        if (!is_first && !is_single) {
            IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK],
                    xreg(irb, xdiffdst_prev)));
        }
        IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE),
                 zreg(irb, zdiffsrc)));
        if (!is_last && !is_single) {
            IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
                 xreg(irb, xdiffdst_next)));
        }

        // accumulate the 5-wide window (+-1, +-2 channel shifts)
        IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK
                + XMM_SIZE - 2*sizeof(float))));
        IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK
                + XMM_SIZE - 1*sizeof(float))));
        IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK
                + XMM_SIZE + 1*sizeof(float))));
        IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK
                + XMM_SIZE + 2*sizeof(float))));
        IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
                zreg(irb, za)));
        // za is consumed above; its register slot is reused for src
        assert(zsrc == za);
        IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src, irb*vlen)));
        IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
                zreg(irb, zb)));
        IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
                zreg(irb, zd)));
        IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
                zreg(irb, ze)));
        // zsrc *= -2*alpha*beta
        IRB_LOOP(vmulps(zreg(irb, zsrc), zreg(irb, zsrc), znalphabeta));

        // diff_src = diff_dst / ws0 + (-2*alpha*beta*src) * window_sum
        IRB_LOOP(vmovups(zreg(irb, zws0),
                 EVEX_compress_addr(workspace0, irb*vlen)));
        IRB_LOOP(vdivps(zreg(irb, zdiffdst), zreg(irb, zdiffdst),
                 zreg(irb, zws0)));
        IRB_LOOP(vfmadd213ps(zreg(irb, zdiffsrc), zreg(irb, zsrc),
                 zreg(irb, zdiffdst)));

        // non-temporal stores when diffsrc is vlen-aligned, plain otherwise
        Label unaligned_store, end_store;
        test(diffsrc, vlen - 1);
        jnz(unaligned_store, T_NEAR);
        IRB_LOOP(uni_vmovntps(EVEX_compress_addr(diffsrc, irb*vlen),
                 zreg(irb, zdiffsrc)));
        jmp(end_store, T_NEAR);
        L(unaligned_store); {
            IRB_LOOP(uni_vmovups(EVEX_compress_addr(diffsrc, irb*vlen),
                     zreg(irb, zdiffsrc)));
        }
        L(end_store);
    }

    /* Generates the complete backward kernel for block position J.
     * A is alpha (pre-divided by local_size), B is beta. */
    jit_avx512_common_lrn_kernel_f32(
        const struct nChw16c_across &J,
        float A,
        float B,
        int use_h_parallel,
        void *code_ptr = nullptr,
        size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE)
        : jit_generator(code_ptr, code_size)
        , nalphabeta(-2*A*B)
        , use_h_parallelism(use_h_parallel)
    {
        this->preamble();

        mov(src, ptr[param + 0]);
        mov(diffdst, ptr[param + 8]);
        mov(workspace0, ptr[param + 16]);
        mov(workspace1, ptr[param + 24]);
        mov(diffsrc, ptr[param + 32]);

        W = J.W;
        HW = J.H*J.W;
        int LSB = this->use_h_parallelism ? W : HW;

        sub(t, BWD_RBC*BUFFER_BLOCK);
        mov(imm_addr64, float2int(this->nalphabeta));
        movq(xnalphabeta, imm_addr64);
        vbroadcastss(znalphabeta, xnalphabeta);

        // NOTE(review): +2 here vs -2 in the forward kernel for "last";
        // only -1/0/+1/3 are ever passed, so the branch is dead either way.
        is_first = J.version == -1 || J.version == -2;
        is_last  = J.version == +1 || J.version == +2;
        is_single = J.version == 3;

        // zero the border slots that have no real neighbor to read from
        if (is_first || is_single) {
            vxorps(xmm1, xmm1, xmm1);
            for(int irb = 0; irb < BWD_RBC; irb++) {
                vmovups(ptr[t + irb*BUFFER_BLOCK], xmm1);
            }
        }
        if (is_last || is_single) {
            vxorps(xmm1, xmm1, xmm1);
            for(int irb = 0; irb < BWD_RBC; irb++) {
                vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], xmm1);
            }
        }

        // split pixel count into a BWD_RBC-multiple main part and a tail
        int LSREST = LSB % BWD_RBC;
        int LS = LSB - LSREST;

        Label lrn_loop;

        if (LS > 0) {
            mov(hw, LS);

            L(lrn_loop);
            {
                compute_loop(BWD_RBC, 1, 1);

                add(src, BWD_RBC*vlen);
                add(diffsrc, BWD_RBC*vlen);
                add(diffdst, BWD_RBC*vlen);
                add(workspace0, BWD_RBC*vlen);
                add(workspace1, BWD_RBC*vlen);

                for(int irb = 0; irb < BWD_RBC; irb++)
                    dec(hw);
                cmp(hw, 0);
                jne(lrn_loop, T_NEAR);
            }
        }

        // tail: skip L2 prefetch when each call only covers one row
        compute_loop(LSREST, 1, this->use_h_parallelism ? 0 : 1);

        add(t, BWD_RBC*BUFFER_BLOCK);
        this->postamble();

        ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
                    this->getCode()));
    }

};
723
724 status_t jit_avx512_common_lrn_bwd_t::pd_t::init() {
725     using namespace prop_kind;
726     using namespace alg_kind;
727
728     assert(engine()->kind() == engine_kind::cpu);
729
730     if (!mayiuse(avx512_common)) return unimplemented;
731
732     const memory_desc_wrapper data_d(data_pd_.desc());
733     bool ok = true
734         && utils::one_of(desc()->prop_kind, backward, backward_data)
735         && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
736         && !has_zero_dim_memory()
737         && data_d.ndims() == 4
738         && data_d.dims()[1] % vsize == 0
739         && attr()->has_default_values();
740     if (!ok) return unimplemented;
741
742     memory_desc_t ws_d;
743     dims_t ws_dims = { MB(), C(), H(), 2*W() };
744     mkldnn_memory_desc_init(&ws_d, 4, ws_dims, data_type::f32,
745         memory_format::nChw16c);
746     ws_pd_ = cpu_memory_t::pd_t(engine_, &ws_d);
747
748     auto fwd_ws_d_ = hint_fwd_pd_->workspace_pd()->desc();
749     bool ws_ok = true
750         && fwd_ws_d_->ndims == ws_pd_.desc()->ndims
751         && fwd_ws_d_->format == ws_pd_.desc()->format
752         && fwd_ws_d_->data_type == ws_pd_.desc()->data_type;
753     if (!ws_ok) return unimplemented;
754
755     bool args_ok_across = true
756         && desc()->alg_kind == lrn_across_channels
757         && desc()->local_size == 5
758         && desc()->lrn_beta == 0.75
759         && data_d.format() == nChw16c;
760
761     return args_ok_across ? success : unimplemented;
762 }
763
764 jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_bwd_t(const pd_t *apd,
765         const input_vector &inputs, const output_vector &outputs)
766     : cpu_primitive_t(apd, inputs, outputs)
767     , use_h_parallelism(0),  ker_(nullptr), ker_first_(nullptr)
768     , ker_last_(nullptr) {
769     const int C = pd()->C();
770     const int H = pd()->H();
771     const int W = pd()->W();
772     const int ls = pd()->desc()->local_size;
773     const float alpha = pd()->desc()->lrn_alpha / ls;
774     const float beta = pd()->desc()->lrn_beta;
775
776     use_h_parallelism = H > 28 ? 1 : 0;
777
778     if (C / vsize == 1) {
779         ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3),
780         alpha, beta, use_h_parallelism);
781     } else {
782         ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0),
783             alpha, beta, use_h_parallelism);
784         ker_first_ = new jit_avx512_common_lrn_kernel_f32(
785             nChw16c_across(H, W, -1), alpha, beta, use_h_parallelism);
786         ker_last_ = new jit_avx512_common_lrn_kernel_f32(
787             nChw16c_across(H, W, +1), alpha, beta, use_h_parallelism);
788     }
789 }
790
791 jit_avx512_common_lrn_bwd_t::~jit_avx512_common_lrn_bwd_t()
792 { delete ker_; delete ker_first_; delete ker_last_; }
793
/* Dispatches the generated backward kernels over (mb, channel-block[, row])
 * work items, consuming src, diff_dst and the forward workspace. */
void jit_avx512_common_lrn_bwd_t::execute_backward() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto ws = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto diff_src = reinterpret_cast<data_t *>(this->memory(0));

    const int N = pd()->MB();
    const int C = pd()->C();
    const int H = pd()->H();
    const int W = pd()->W();
    const int C16 = C / vsize;  // number of 16-channel blocks
    const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16;

    parallel(0, work_amount, [&](const int ithr, const int nthr) {
        size_t start{0}, end{0};

        balance211(work_amount, nthr, ithr, start, end);
        if (use_h_parallelism) {
            int n{0}, c16{0}, h{0};
            // NOTE(review): iteration order is (n, h, c16) here, unlike the
            // forward pass's (n, c16, h); init and step agree with each
            // other, so the indexing stays self-consistent.
            nd_iterator_init(start, n, N,  h, H, c16, C16);
            for (size_t iwork = start; iwork < end; ++iwork) {
                // element offset of row h inside channel block c16 (nChw16c)
                auto offset = n*C*H*W + c16*H*W*vsize
                    + h*W*vsize;
                // workspace stores 2*W*vsize floats per row: ws0 then ws1
                auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize
                    + h*2*W*vsize;
                auto ws_offset1 = ws_offset0 + W*vsize;

                jit_args_bwd_t args;
                args.src = &src[offset];
                args.diff_dst = &diff_dst[offset];
                args.ws0 = &ws[ws_offset0];
                args.ws1 = &ws[ws_offset1];
                args.diff_src = &diff_src[offset];

                // pick the kernel variant matching this block's position
                if (C16 == 1)
                    (*ker_)(&args);
                else if (c16 == 0)
                    (*ker_first_)(&args);
                else if (c16 == C16 - 1)
                    (*ker_last_)(&args);
                else
                    (*ker_)(&args);
                nd_iterator_step(n, N, h, H, c16, C16);
            }
        } else {
            int n{0}, c16{0};
            nd_iterator_init(start, n, N, c16, C16);
            for (size_t iwork = start; iwork < end; ++iwork) {
                // whole-plane variant: each item covers all H*W pixels
                auto offset = n*C*H*W + c16*H*W*vsize;
                auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize;
                auto ws_offset1 = ws_offset0 + H*W*vsize;

                jit_args_bwd_t args;
                args.src = &src[offset];
                args.diff_dst = &diff_dst[offset];
                args.ws0 = &ws[ws_offset0];
                args.ws1 = &ws[ws_offset1];
                args.diff_src = &diff_src[offset];

                if (C16 == 1)
                    (*ker_)(&args);
                else if (c16 == 0)
                    (*ker_first_)(&args);
                else if (c16 == C16 - 1)
                    (*ker_last_)(&args);
                else
                    (*ker_)(&args);

                nd_iterator_step(n, N, c16, C16);
            }
        }
    });
}
867
868 }
869 }
870 }