updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_lrn_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "utils.hpp"
20 #include "jit_uni_lrn.hpp"
21
22 namespace mkldnn {
23 namespace impl {
24 namespace cpu {
25
26 using namespace Xbyak;
27
28 //////////////////////////////////////////////////////////////////////////////
29 // forward kernel
30 template<cpu_isa_t isa>
31 void jit_uni_lrn_fwd_kernel_f32<isa>::within_body(
32         int hoff, int Hoff, int woff, int Woff, int stride,
33         Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2,
34         prop_kind_t pk)
35 {
36     vxorps(ysum, ysum, ysum);
37     for (int i = hoff; i <= Hoff; ++i)
38     {
39         for (int j = woff; j <= Woff; ++j)
40         {
41             if (i == 0 && j == 0)
42             {
43                 vmovups(ydst, ptr[src]);
44                 vfmadd231ps(ysum, ydst, ydst);
45             }
46             else
47             {
48                 vmovups(ytmp, ptr[src + (i*stride + j)*VECTOR_LENGTH*4]);
49                 vfmadd231ps(ysum, ytmp, ytmp);
50             }
51         }
52     }
53     vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk
54     vmovaps(ytmp, ysum);
55     if (pk != prop_kind::forward_inference)
56         vmovups(ptr[scratch], ytmp);
57     vmulps(ysum2, ysum, ysum);
58     vmulps(ysum, ysum, ysum2); // ysum = (ysum*yalpha+yk)^3;
59     vsqrtps(ysum, ysum);
60     vsqrtps(ysum, ysum); // ysum = (ysum*yalpha+yk)^0.75
61     vdivps(ydst, ydst, ysum); // ydst <- ydst / ysum
62     vmovups(ptr[dst], ydst);
63     add(src, 32);
64     add(dst, 32);
65     if (pk != prop_kind::forward_inference)
66         add(scratch, 32);
67 }
68
69 template<cpu_isa_t isa>
70 void jit_uni_lrn_fwd_kernel_f32<isa>::within_body_sse42(
71     int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk)
72 {
73     Xbyak::Xmm xtmp_lo = xmm12;
74     Xbyak::Xmm xtmp_hi = xmm13;
75     Xbyak::Xmm xsum_lo = xmm8;
76     Xbyak::Xmm xsum_hi = xmm9;
77     Xbyak::Xmm xdst_lo = xmm10;
78     Xbyak::Xmm xdst_hi = xmm11;
79     Xbyak::Xmm xsum2_lo = xmm14;
80     Xbyak::Xmm xsum2_hi = xmm15;
81
82     xorps(xsum_lo, xsum_lo);
83     xorps(xsum_hi, xsum_hi);
84     for (int i = hoff; i <= Hoff; ++i)
85     {
86         for (int j = woff; j <= Woff; ++j)
87         {
88             if (i == 0 && j == 0)
89             {
90                 movups(xdst_lo, ptr[src]);
91                 movups(xdst_hi, ptr[src + 4 * sizeof(float)]);
92                 mulps(xdst_lo, xdst_lo);
93                 mulps(xdst_hi, xdst_hi);
94                 addps(xsum_lo, xdst_lo);
95                 addps(xsum_hi, xdst_hi);
96             }
97             else
98             {
99                 movups(xtmp_lo, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4]);
100                 movups(xtmp_hi, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4 + 4 * sizeof(float)]);
101                 mulps(xtmp_lo, xtmp_lo);
102                 mulps(xtmp_hi, xtmp_hi);
103                 addps(xsum_lo, xtmp_lo);
104                 addps(xsum_hi, xtmp_hi);
105             }
106         }
107     }
108     mulps(xsum_lo, xalpha);
109     mulps(xsum_hi, xalpha);
110     addps(xsum_lo, xk);
111     addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk
112     movaps(xtmp_lo, xsum_lo);
113     movaps(xtmp_hi, xsum_hi);
114     if (pk != prop_kind::forward_inference) {
115         movups(ptr[scratch], xtmp_lo);
116         movups(ptr[scratch + 4 * sizeof(float)], xtmp_hi);
117     }
118     movaps(xsum2_lo, xsum_lo);
119     movaps(xsum2_hi, xsum_hi);
120     mulps(xsum2_lo, xsum_lo);
121     mulps(xsum2_hi, xsum_hi);
122     mulps(xsum_lo, xsum2_lo);
123     mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha+xk)^3;
124
125     sqrtps(xsum_lo, xsum_lo);
126     sqrtps(xsum_hi, xsum_hi);
127     sqrtps(xsum_lo, xsum_lo);
128     sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha+xk)^0.75
129
130     movups(xdst_lo, ptr[src]);
131     movups(xdst_hi, ptr[src + 4 * sizeof(float)]);
132     divps(xdst_lo, xsum_lo);
133     divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum
134
135     movups(ptr[dst], xdst_lo);
136     movups(ptr[dst + 4 * sizeof(float)], xdst_hi);
137     add(src, 32);
138     add(dst, 32);
139     if (pk != prop_kind::forward_inference)
140         add(scratch, 32);
141 }
142
143 template <cpu_isa_t isa>
144 jit_uni_lrn_fwd_kernel_f32<isa>::jit_uni_lrn_fwd_kernel_f32(
145         const struct nchw8c_within &J,
146         float A,
147         float K,
148         prop_kind_t pk,
149         void *code_ptr,
150         size_t code_size)
151         : jit_generator(code_ptr, code_size)
152         , alpha(A), k(K)
153 {
154     Xbyak::Reg64 h = r9;
155     Xbyak::Reg64 w = r10;
156     Vmm ysum = Vmm(isa == avx2 ? 9 : 9);
157     Vmm ysum2 = Vmm(isa == avx2 ? 10 : 10);
158     Vmm ydst = Vmm(isa == avx2 ? 11 : 11);
159     Vmm ytmp = Vmm(isa == avx2 ? 12 : 12);
160
161     this->preamble();
162
163     mov(src, ptr[this->param1 + 0]);
164     mov(dst, ptr[this->param1 + 8]);
165     if (pk != prop_kind::forward_inference)
166         mov(scratch, ptr[this->param1 + 16]);
167
168     mov(imm_addr64, float2int(this->alpha));
169     movq(xalpha, imm_addr64);
170     if (isa == avx2) {
171         vbroadcastss(yalpha, xalpha);
172     } else {
173         shufps(xalpha, xalpha, 0);
174     }
175
176     mov(imm_addr64, float2int(this->k));
177     movq(xk, imm_addr64);
178     if (isa == avx2) {
179         vbroadcastss(yk, xk);
180     } else {
181         shufps(xk, xk, 0);
182     }
183
184     int s2 = (J.size - 1) / 2, S2 = J.size - s2 - 1;
185
186     for (int i = 0; i < s2; ++i)
187     {
188         Label label_t;
189         for (int j = 0; j < s2; ++j) {
190             if (isa == avx2) {
191                 within_body(-i, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
192             }
193             else {
194                 within_body_sse42(-i, S2, -j, S2, J.W, pk);
195             }
196         }
197         mov(w, J.W - J.size + 1);
198         L(label_t);
199         if (isa == avx2) {
200             within_body(-i, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
201         } else {
202             within_body_sse42(-i, S2, -s2, S2, J.W, pk);
203         }
204         dec(w);
205         cmp(w, 0);
206         jne(label_t, T_NEAR);
207         for (int j = J.W - S2; j < J.W; ++j) {
208             if (isa == avx2) {
209                 within_body(-i, S2, -s2, J.W - 1 - j, J.W,
210                     ysum, ydst, ytmp, ysum2, pk);
211             } else {
212                 within_body_sse42(-i, S2, -s2, J.W - 1 - j, J.W, pk);
213             }
214         }
215     }
216
217     mov(h, J.H - J.size + 1);
218     Label lrn_loop_h;
219     L(lrn_loop_h);
220     for (int j = 0; j < s2; ++j) {
221         if (isa == avx2) {
222             within_body(-s2, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
223         } else {
224             within_body_sse42(-s2, S2, -j, S2, J.W, pk);
225         }
226     }
227     mov(w, J.W - J.size + 1);
228     Label lrn_loop_w;
229     L(lrn_loop_w);
230     if (isa == avx2) {
231         within_body(-s2, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
232     } else {
233         within_body_sse42(-s2, S2, -s2, S2, J.W, pk);
234     }
235     dec(w);
236     cmp(w, 0);
237     jne(lrn_loop_w, T_NEAR);
238     for (int j = J.W - S2; j < J.W; ++j) {
239         if (isa == avx2) {
240             within_body(-s2, S2, -s2, J.W - 1 - j, J.W,
241                 ysum, ydst, ytmp, ysum2, pk);
242         } else {
243             within_body_sse42(-s2, S2, -s2, J.W - 1 - j, J.W, pk);
244         }
245     }
246     dec(h);
247     cmp(h, 0);
248     jne(lrn_loop_h, T_NEAR);
249
250     for (int i = J.H - S2; i < J.H; ++i)
251     {
252         for (int j = 0; j < s2; ++j) {
253             if (isa == avx2) {
254                 within_body(-s2, J.H - 1 - i, -j, S2, J.W,
255                     ysum, ydst, ytmp, ysum2, pk);
256             } else {
257                 within_body_sse42(-s2, J.H - 1 - i, -j, S2, J.W, pk);
258             }
259         }
260
261         mov(w, J.W - J.size + 1);
262         Label label_b;
263         L(label_b);
264         if (isa == avx2) {
265             within_body(-s2, J.H - 1 - i, -s2, S2, J.W,
266                 ysum, ydst, ytmp, ysum2, pk);
267         } else {
268             within_body_sse42(-s2, J.H - 1 - i, -s2, S2, J.W, pk);
269         }
270         dec(w);
271         cmp(w, 0);
272         jne(label_b, T_NEAR);
273
274         for (int j = J.W - S2; j < J.W; ++j) {
275             if (isa == avx2) {
276                 within_body(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W,
277                     ysum, ydst, ytmp, ysum2, pk);
278             } else {
279                 within_body_sse42(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, pk);
280             }
281         }
282     }
283
284     this->postamble();
285
286     ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
287                 this->getCode()));
288 }
289
290 template<>
291 jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
292         const struct nchw8c_across &J,
293         float A,
294         float K,
295         prop_kind_t pk,
296         void *code_ptr,
297         size_t code_size)
298         : jit_generator(code_ptr, code_size)
299         , alpha(A), k(K)
300 {
301     Xbyak::Reg64 t = rsp;
302     Xbyak::Reg64 hw = r9;
303     Xbyak::Xmm xsrc_prev = xmm2;
304     Xbyak::Ymm ysrc = ymm3;
305     Xbyak::Ymm yc = ymm3;
306     Xbyak::Xmm xsrc_next = xmm4;
307     Xbyak::Ymm ya = ymm5;
308     Xbyak::Ymm yb = ymm6;
309     Xbyak::Ymm yd = ymm7;
310     Xbyak::Ymm ye = ymm8;
311     Xbyak::Ymm ysum = ymm9;
312     Xbyak::Ymm ysum2 = ymm10;
313     Xbyak::Ymm ydst = ymm11;
314     Xbyak::Ymm ybase = ymm12;
315
316     this->preamble();
317
318     mov(src, ptr[this->param1 + 0]);
319     mov(dst, ptr[this->param1 + 8]);
320     if (pk != prop_kind::forward_inference)
321         mov(scratch, ptr[this->param1 + 16]);
322     sub(t, 64);
323     mov(imm_addr64, float2int(this->alpha));
324     movq(xalpha, imm_addr64);
325     vbroadcastss(yalpha, xalpha);
326
327     mov(imm_addr64, float2int(this->k));
328     movq(xk, imm_addr64);
329     vbroadcastss(yk, xk);
330
331     if (J.version == -1)
332     {
333         vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
334         vmovups(ptr[t + 0], xsrc_prev);
335     }
336     if (J.version == +1)
337     {
338         vxorps(xsrc_next, xsrc_next, xsrc_next);
339         vmovups(ptr[t + 48], xsrc_next);
340     }
341
342     mov(hw, J.H*J.W);
343
344     Label lrn_loop;
345     L(lrn_loop);
346
347     if (J.version != -1) vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
348     vmovups(ysrc, ptr[src]);
349     if (J.version != +1) vmovups(xsrc_next, ptr[src + J.H*J.W * 32]);
350
351     if (J.version != -1) vmovups(ptr[t + 0], xsrc_prev);
352     vmovups(ptr[t + 16], ysrc);
353     if (J.version != +1) vmovups(ptr[t + 48], xsrc_next);
354
355     vmovups(ya, ptr[t + 16 - 8]);
356     vmovups(yb, ptr[t + 16 - 4]);
357     vmovups(yd, ptr[t + 16 + 4]);
358     vmovups(ye, ptr[t + 16 + 8]);
359     vmulps(ysum, yc, yc);
360     vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya
361     vfmadd231ps(ysum, yb, yb);
362     vfmadd231ps(ysum, yd, yd);
363     vfmadd231ps(ysum, ye, ye);
364     vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk
365
366     vmovaps(ybase, ysum);
367     if (pk != prop_kind::forward_inference)
368         vmovups(ptr[scratch], ybase);
369     vmulps(ysum2, ysum, ysum);
370     vmulps(ysum, ysum, ysum2); // ysum = ybase^3;
371     vsqrtps(ysum, ysum);
372     vsqrtps(ysum, ysum); // ysum = ybase^0.75
373     vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum
374     vmovups(ptr[dst], ydst);
375
376     add(src, 32);
377     add(dst, 32);
378     if (pk != prop_kind::forward_inference)
379         add(scratch, 32);
380     dec(hw);
381     cmp(hw, 0);
382     jne(lrn_loop, T_NEAR);
383
384     add(t, 64);
385     this->postamble();
386
387     ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
388                 this->getCode()));
389 }
390
391 template<>
392 jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
393     const struct nchw8c_across &J,
394     float A,
395     float K,
396     prop_kind_t pk,
397     void *code_ptr,
398     size_t code_size)
399     : jit_generator(code_ptr, code_size)
400     , alpha(A), k(K)
401 {
402     Xbyak::Reg64 t = rsp;
403     Xbyak::Reg64 hw = r9;
404
405     Xbyak::Xmm xsrc_lo = xmm2;
406     Xbyak::Xmm xsrc_hi = xmm3;
407     Xbyak::Xmm xc_lo = xmm4;
408     Xbyak::Xmm xc_hi = xmm5;
409     Xbyak::Xmm xsum_lo = xc_lo;
410     Xbyak::Xmm xsum_hi = xc_hi;
411     Xbyak::Xmm xsrc_prev = xmm6;
412     Xbyak::Xmm xsrc_next = xmm7;
413     Xbyak::Xmm xa_lo = xmm8;
414     Xbyak::Xmm xa_hi = xmm9;
415     Xbyak::Xmm xb_lo = xmm10;
416     Xbyak::Xmm xb_hi = xmm11;
417     Xbyak::Xmm xd_lo = xmm12;
418     Xbyak::Xmm xd_hi = xmm13;
419     Xbyak::Xmm xe_lo = xmm14;
420     Xbyak::Xmm xe_hi = xmm15;
421     Xbyak::Xmm xbase_lo = xmm14;
422     Xbyak::Xmm xbase_hi = xmm15;
423
424     this->preamble();
425
426     mov(src, ptr[this->param1 + 0]);
427     mov(dst, ptr[this->param1 + 8]);
428     if (pk != prop_kind::forward_inference)
429         mov(scratch, ptr[this->param1 + 16]);
430     sub(t, 64);
431     mov(imm_addr64, float2int(this->alpha));
432     movq(xalpha, imm_addr64);
433     shufps(xalpha, xalpha, 0);
434
435     mov(imm_addr64, float2int(this->k));
436     movq(xk, imm_addr64);
437     shufps(xk, xk, 0);
438
439     if (J.version == -1)
440     {
441         xorps(xsrc_prev, xsrc_prev);
442         movups(ptr[t + 0], xsrc_prev);
443     }
444     if (J.version == +1)
445     {
446         xorps(xsrc_next, xsrc_next);
447         movups(ptr[t + 48], xsrc_next);
448     }
449
450     mov(hw, J.H*J.W);
451     Label lrn_loop;
452     L(lrn_loop);
453
454     if (J.version != -1) movups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
455     movups(xsrc_lo, ptr[src]);
456     movups(xsrc_hi, ptr[src + 4 * sizeof(float)]);
457     if (J.version != +1) movups(xsrc_next, ptr[src + J.H*J.W * 32]);
458
459     if (J.version != -1) movups(ptr[t + 0], xsrc_prev);
460     movups(ptr[t + 16], xsrc_lo);
461     movups(ptr[t + 16 + 4 * sizeof(float)], xsrc_hi);
462     if (J.version != +1) movups(ptr[t + 48], xsrc_next);
463
464     movups(xa_lo, ptr[t + 16 - 8]);
465     movups(xa_hi, ptr[t + 16 - 8 + 4 * sizeof(float)]);
466     movups(xb_lo, ptr[t + 16 - 4]);
467     movups(xb_hi, ptr[t + 16 - 4 + 4 * sizeof(float)]);
468     movups(xd_lo, ptr[t + 16 + 4]);
469     movups(xd_hi, ptr[t + 16 + 4 + 4 * sizeof(float)]);
470     movups(xe_lo, ptr[t + 16 + 8]);
471     movups(xe_hi, ptr[t + 16 + 8 + 4 * sizeof(float)]);
472     movaps(xc_lo, xsrc_lo);
473     movaps(xc_hi, xsrc_hi);
474     mulps(xsum_lo, xc_lo);
475     mulps(xsum_hi, xc_hi);
476     mulps(xa_lo, xa_lo);
477     mulps(xa_hi, xa_hi);
478     addps(xsum_lo, xa_lo);
479     addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa
480     mulps(xb_lo, xb_lo);
481     mulps(xb_hi, xb_hi);
482     addps(xsum_lo, xb_lo);
483     addps(xsum_hi, xb_hi);
484     mulps(xd_lo, xd_lo);
485     mulps(xd_hi, xd_hi);
486     addps(xsum_lo, xd_lo);
487     addps(xsum_hi, xd_hi);
488     mulps(xe_lo, xe_lo);
489     mulps(xe_hi, xe_hi);
490     addps(xsum_lo, xe_lo);
491     addps(xsum_hi, xe_hi);
492
493     mulps(xsum_lo, xalpha);
494     mulps(xsum_hi, xalpha);
495     addps(xsum_lo, xk);
496     addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk
497
498     movaps(xbase_lo, xsum_lo);
499     movaps(xbase_hi, xsum_hi);
500     if (pk != prop_kind::forward_inference) {
501         movups(ptr[scratch], xbase_lo);
502         movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
503     }
504     mulps(xsum_lo, xsum_lo);
505     mulps(xsum_hi, xsum_hi);
506     mulps(xsum_lo, xbase_lo);
507     mulps(xsum_hi, xbase_hi); // xsum = xbase^3;
508     sqrtps(xsum_lo, xsum_lo);
509     sqrtps(xsum_hi, xsum_hi);
510     sqrtps(xsum_lo, xsum_lo);
511     sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75
512     divps(xsrc_lo, xsum_lo);
513     divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum
514     movups(ptr[dst], xsrc_lo);
515     movups(ptr[dst + 4 * sizeof(float)], xsrc_hi);
516
517     add(src, 32);
518     add(dst, 32);
519     if (pk != prop_kind::forward_inference)
520         add(scratch, 32);
521     dec(hw);
522     cmp(hw, 0);
523     jne(lrn_loop, T_NEAR);
524
525     add(t, 64);
526     this->postamble();
527
528     ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
529         this->getCode()));
530 }
531
532 template<>
533 jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
534     const struct nhwc_across &J,
535     float A,
536     float K,
537     prop_kind_t pk,
538     void *code_ptr,
539     size_t code_size)
540     : jit_generator(code_ptr, code_size)
541     , alpha(A), k(K)
542 {
543     static const uint32_t mask[] = {
544         0, 0, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
545         0x80000000, 0x80000000, 0x80000000, 0, 0
546     };
547
548     Xbyak::Reg64 c = r9;
549     Xbyak::Ymm ya = ymm2;
550     Xbyak::Ymm yb = ymm3;
551     Xbyak::Ymm yc = ymm4;
552     Xbyak::Ymm yd = ymm5;
553     Xbyak::Ymm ye = ymm6;
554     Xbyak::Ymm ysum = ymm7;
555     Xbyak::Ymm ydst = ymm8;
556     Xbyak::Ymm ybase = ymm9;
557     Xbyak::Ymm ymask = ymm10;
558
559     this->preamble();
560
561     mov(src, ptr[this->param1 + 0]);
562     mov(dst, ptr[this->param1 + 8]);
563     if (pk != prop_kind::forward_inference)
564         mov(scratch, ptr[this->param1 + 16]);
565     mov(imm_addr64, float2int(this->alpha));
566     movq(xalpha, imm_addr64);
567     vbroadcastss(yalpha, xalpha);
568
569     mov(imm_addr64, float2int(this->k));
570     movq(xk, imm_addr64);
571     vbroadcastss(yk, xk);
572
573     vxorps(ysum, ysum, ysum);
574
575     mov(imm_addr64, reinterpret_cast<size_t>(&mask[0]));
576     vmovups(ymask, ptr[imm_addr64]);
577     vmaskmovps(ya, ymask, ptr[src - 8]);
578     vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
579
580     mov(imm_addr64, reinterpret_cast<size_t>(&mask[1]));
581     vmovups(ymask, ptr[imm_addr64]);
582     vmaskmovps(yb, ymask, ptr[src - 4]);
583     vfmadd231ps(ysum, yb, yb);
584
585     mov(c, J.C / 8 - 1);
586     Label lrn_loop;
587     L(lrn_loop);
588
589     vmovups(yc, ptr[src]);
590     vmovups(yd, ptr[src + 4]);
591     vmovups(ye, ptr[src + 8]);
592     vfmadd231ps(ysum, yc, yc);
593     vfmadd231ps(ysum, yd, yd);
594     vfmadd231ps(ysum, ye, ye);
595
596     vmovups(ydst, ysum);
597     vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
598
599     vmovaps(ybase, ydst);
600     if (pk != prop_kind::forward_inference)
601         vmovups(ptr[scratch], ybase);
602     vmulps(ydst, ydst, ydst);
603     vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
604     vsqrtps(ydst, ydst);
605     vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
606
607     vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
608     vmovups(ptr[dst], ydst);
609
610     vxorps(ysum, ysum, ysum);
611
612     add(src, 32);
613     add(dst, 32);
614     if (pk != prop_kind::forward_inference)
615         add(scratch, 32);
616
617     vmovups(ya, ptr[src - 8]);
618     vfmadd231ps(ysum, ya, ya);
619     vmovups(yb, ptr[src - 4]);
620     vfmadd231ps(ysum, yb, yb);
621
622     dec(c);
623     cmp(c, 0);
624     jne(lrn_loop, T_NEAR);
625
626     vmovups(yc, ptr[src]);
627     vfmadd231ps(ysum, yc, yc);
628
629     mov(imm_addr64, reinterpret_cast<size_t>(&mask[2]));
630     vmovups(ymask, ptr[imm_addr64]);
631     vmaskmovps(yd, ymask, ptr[src + 4]);
632     vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
633
634     mov(imm_addr64, reinterpret_cast<size_t>(&mask[3]));
635     vmovups(ymask, ptr[imm_addr64]);
636     vmaskmovps(ye, ymask, ptr[src + 8]);
637     vfmadd231ps(ysum, ye, ye);
638
639     vmovups(ydst, ysum);
640     vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
641
642     vmovaps(ybase, ydst);
643     if (pk != prop_kind::forward_inference)
644         vmovups(ptr[scratch], ybase);
645     vmulps(ydst, ydst, ydst);
646     vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
647     vsqrtps(ydst, ydst);
648     vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
649     vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
650
651     vmovups(ptr[dst], ydst);
652
653     this->postamble();
654
655     ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
656                 this->getCode()));
657 }
658
659 template<>
660 jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
661     const struct nhwc_across &J,
662     float A,
663     float K,
664     prop_kind_t pk,
665     void *code_ptr,
666     size_t code_size)
667     : jit_generator(code_ptr, code_size)
668     , alpha(A), k(K)
669 {
670     static const uint32_t mask[] = {
671         0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
672         0xffffffff, 0xffffffff, 0xffffffff, 0, 0
673     };
674
675     static uint32_t store[] = {
676         0, 0, 0, 0, 0, 0, 0, 0, 0, 0
677     };
678     Xbyak::Reg64 c = r9;
679
680     Xbyak::Xmm xdst_lo = xmm0;
681     Xbyak::Xmm xdst_hi = xmm1;
682     Xbyak::Xmm xa_lo = xmm2;
683     Xbyak::Xmm xa_hi = xmm3;
684     Xbyak::Xmm xb_lo = xmm2;
685     Xbyak::Xmm xb_hi = xmm3;
686     Xbyak::Xmm xc_lo = xmm4;
687     Xbyak::Xmm xc_hi = xmm5;
688     Xbyak::Xmm xd_lo = xmm6;
689     Xbyak::Xmm xd_hi = xmm7;
690     Xbyak::Xmm xe_lo = xmm8;
691     Xbyak::Xmm xe_hi = xmm9;
692     Xbyak::Xmm xsum_lo = xmm10;
693     Xbyak::Xmm xsum_hi = xmm11;
694     Xbyak::Xmm xmask_lo = xmm12;
695     Xbyak::Xmm xmask_hi = xmm13;
696     Xbyak::Xmm xbase_lo = xmm14;
697     Xbyak::Xmm xbase_hi = xmm15;
698
699     this->preamble();
700
701     mov(src, ptr[this->param1 + 0]);
702     mov(dst, ptr[this->param1 + 8]);
703     if (pk != prop_kind::forward_inference)
704         mov(scratch, ptr[this->param1 + 16]);
705     mov(imm_addr64, float2int(this->alpha));
706     movq(xalpha, imm_addr64);
707     shufps(xalpha, xalpha, 0);
708
709     mov(imm_addr64, float2int(this->k));
710     movq(xk, imm_addr64);
711     shufps(xk, xk, 0);
712
713     mov(store_addr, reinterpret_cast<size_t>(&store[0]));
714     and_(store_addr, -15);
715     movups(ptr[store_addr], xalpha);
716     movups(ptr[store_addr + 4 * sizeof(float)], xk);
717
718     xorps(xsum_lo, xsum_lo);
719     xorps(xsum_hi, xsum_hi);
720
721     mov(imm_addr64, reinterpret_cast<size_t>(&mask[0]));
722     movups(xmask_lo, ptr[imm_addr64]);
723     movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
724     movups(xa_lo, ptr[src - 8]);
725     movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]);
726     andps(xa_lo, xmask_lo);
727     andps(xa_hi, xmask_hi);
728     mulps(xa_lo, xa_lo);
729     mulps(xa_hi, xa_hi);
730     addps(xsum_lo, xa_lo);
731     addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
732
733     mov(imm_addr64, reinterpret_cast<size_t>(&mask[1]));
734     movups(xmask_lo, ptr[imm_addr64]);
735     movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
736     movups(xb_lo, ptr[src - 4]);
737     movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]);
738     andps(xb_lo, xmask_lo);
739     andps(xb_hi, xmask_hi);
740     mulps(xb_lo, xb_lo);
741     mulps(xb_hi, xb_hi);
742     addps(xsum_lo, xb_lo);
743     addps(xsum_hi, xb_hi);
744
745     mov(c, J.C / 8 - 1);
746     Label lrn_loop;
747     L(lrn_loop);
748
749     movups(xc_lo, ptr[src]);
750     movups(xc_hi, ptr[src + 4 * sizeof(float)]);
751     movups(xd_lo, ptr[src + 4]);
752     movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]);
753     movups(xe_lo, ptr[src + 8]);
754     movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]);
755     mulps(xc_lo, xc_lo);
756     mulps(xc_hi, xc_hi);
757     addps(xsum_lo, xc_lo);
758     addps(xsum_hi, xc_hi);
759     mulps(xd_lo, xd_lo);
760     mulps(xd_hi, xd_hi);
761     addps(xsum_lo, xd_lo);
762     addps(xsum_hi, xd_hi);
763     mulps(xe_lo, xe_lo);
764     mulps(xe_hi, xe_hi);
765     addps(xsum_lo, xe_lo);
766     addps(xsum_hi, xe_hi);
767
768     movaps(xdst_lo, xsum_lo);
769     movaps(xdst_hi, xsum_hi);
770     // xdst <- xsum*xalpha+xk
771     mulps(xdst_lo, ptr[store_addr]);
772     mulps(xdst_hi, ptr[store_addr]);
773     addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]);
774     addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]);
775
776     movaps(xbase_lo, xdst_lo);
777     movaps(xbase_hi, xdst_hi);
778     if (pk != prop_kind::forward_inference) {
779         movups(ptr[scratch], xbase_lo);
780         movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
781     }
782     mulps(xdst_lo, xdst_lo);
783     mulps(xdst_hi, xdst_hi);
784     mulps(xdst_lo, xbase_lo);
785     mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
786     sqrtps(xdst_lo, xdst_lo);
787     sqrtps(xdst_hi, xdst_hi);
788     sqrtps(xdst_lo, xdst_lo);
789     sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
790
791     movups(xc_lo, ptr[src]);
792     movups(xc_hi, ptr[src + 4 * sizeof(float)]);
793     divps(xc_lo, xdst_lo);
794     divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
795     movups(ptr[dst], xc_lo);
796     movups(ptr[dst + 4 * sizeof(float)], xc_hi);
797
798     xorps(xsum_lo, xsum_lo);
799     xorps(xsum_hi, xsum_hi);
800
801     add(src, 32);
802     add(dst, 32);
803     if (pk != prop_kind::forward_inference)
804         add(scratch, 32);
805
806     movups(xa_lo, ptr[src - 8]);
807     movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]);
808     mulps(xa_lo, xa_lo);
809     mulps(xa_hi, xa_hi);
810     addps(xsum_lo, xa_lo);
811     addps(xsum_hi, xa_hi);
812     movups(xb_lo, ptr[src - 4]);
813     movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]);
814     mulps(xb_lo, xb_lo);
815     mulps(xb_hi, xb_hi);
816     addps(xsum_lo, xb_lo);
817     addps(xsum_hi, xb_hi);
818
819     dec(c);
820     cmp(c, 0);
821     jne(lrn_loop, T_NEAR);
822
823     movups(xc_lo, ptr[src]);
824     movups(xc_hi, ptr[src + 4 * sizeof(float)]);
825     mulps(xc_lo, xc_lo);
826     mulps(xc_hi, xc_hi);
827     addps(xsum_lo, xc_lo);
828     addps(xsum_hi, xc_hi);
829
830     mov(imm_addr64, reinterpret_cast<size_t>(&mask[2]));
831     movups(xmask_lo, ptr[imm_addr64]);
832     movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
833     movups(xd_lo, ptr[src + 4]);
834     movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]);
835     andps(xd_lo, xmask_lo);
836     andps(xd_hi, xmask_hi);
837     mulps(xd_lo, xd_lo);
838     mulps(xd_hi, xd_hi);
839     addps(xsum_lo, xd_lo);
840     addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
841
842     mov(imm_addr64, reinterpret_cast<size_t>(&mask[3]));
843     movups(xmask_lo, ptr[imm_addr64]);
844     movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
845     movups(xe_lo, ptr[src + 8]);
846     movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]);
847     andps(xe_lo, xmask_lo);
848     andps(xe_hi, xmask_hi);
849     mulps(xe_lo, xe_lo);
850     mulps(xe_hi, xe_hi);
851     addps(xsum_lo, xe_lo);
852     addps(xsum_hi, xe_hi);
853
854     movups(xdst_lo, xsum_lo);
855     movups(xdst_hi, xsum_hi);
856     // xdst <- xsum*xalpha+xk
857     mulps(xdst_lo, ptr[store_addr]);
858     mulps(xdst_hi, ptr[store_addr]);
859     addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]);
860     addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]);
861
862     movaps(xbase_lo, xdst_lo);
863     movaps(xbase_hi, xdst_hi);
864     if (pk != prop_kind::forward_inference) {
865         movups(ptr[scratch], xbase_lo);
866         movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
867     }
868     mulps(xdst_lo, xdst_lo);
869     mulps(xdst_hi, xdst_hi);
870     mulps(xdst_lo, xbase_lo);
871     mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
872     sqrtps(xdst_lo, xdst_lo);
873     sqrtps(xdst_hi, xdst_hi);
874     sqrtps(xdst_lo, xdst_lo);
875     sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
876     movups(xc_lo, ptr[src]);
877     movups(xc_hi, ptr[src + 4 * sizeof(float)]);
878     divps(xc_lo, xdst_lo);
879     divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
880
881     movups(ptr[dst], xc_lo);
882     movups(ptr[dst + 4 * sizeof(float)], xc_hi);
883
884     this->postamble();
885
886     ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
887         this->getCode()));
888 }
889
890 template<>
891 void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body(
892     int tail, int HW, prop_kind_t pk,
893     Xbyak::Ymm ymask,
894     Xbyak::Ymm ya,
895     Xbyak::Ymm yb,
896     Xbyak::Ymm yc,
897     Xbyak::Ymm yd,
898     Xbyak::Ymm ye,
899     Xbyak::Ymm ysum) {}
900
901 template<>
902 void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body(
903     int tail, int HW, prop_kind_t pk,
904     Xbyak::Ymm ymask,
905     Xbyak::Ymm ya,
906     Xbyak::Ymm yb,
907     Xbyak::Ymm yc,
908     Xbyak::Ymm yd,
909     Xbyak::Ymm ye,
910     Xbyak::Ymm ysum)
911 {
912     Xbyak::Ymm ydst = ymm14;
913     Xbyak::Ymm ybase = ymm15;
914
915     vfmadd231ps(ysum, ye, ye);
916
917     vmovups(ydst, ysum);
918     vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
919
920     vmovaps(ybase, ydst);
921     if (pk != prop_kind::forward_inference)
922     {
923         if (tail != 0)
924             vmaskmovps(ptr[scratch], ymask, ybase);
925         else
926             vmovups(ptr[scratch], ybase);
927     }
928     vmulps(ydst, ydst, ydst);
929     vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
930     vsqrtps(ydst, ydst);
931     vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
932     vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
933
934     if (tail != 0)
935         vmaskmovps(ptr[dst], ymask, ydst);
936     else
937         vmovups(ptr[dst], ydst);
938
939
940     vfnmadd231ps(ysum, ya, ya);
941     vmovups(ya, yb);
942     vmovups(yb, yc);
943     vmovups(yc, yd);
944     vmovups(yd, ye);
945 }
946
947 template<>
948 void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_tail_sse42(
949     int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi)
950 {}
951
952 template<>
953 void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_tail_sse42(
954     int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi)
955 {
956     Xbyak::Xmm xmm_tmp = xmm10;
957     movaps(xmm_tmp, xtail_lo);
958     size_t offset = 0;
959
960     if (tail > 4) {
961         movups(ptr[reg_dst], xtail_lo);
962         movaps(xmm_tmp, xtail_hi);
963         offset += 4 * sizeof(float);
964         tail -= 4;
965     }
966     movss(ptr[reg_dst + offset], xmm_tmp);
967     for (int i = 1; i < tail; i++)
968     {
969         psrldq(xmm_tmp, 4);
970         movss(ptr[reg_dst + offset + i * sizeof(float)], xmm_tmp);
971     }
972 }
973
974 template<>
975 void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body_sse42(
976     int tail, int HW, prop_kind_t pk,
977     Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
978     Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi,
979     Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi)
980 {
981     Xbyak::Xmm xdst_lo = xmm0;
982     Xbyak::Xmm xdst_hi = xmm1;
983     Xbyak::Xmm xbase_lo = xmm6;
984     Xbyak::Xmm xbase_hi = xmm7;
985     Xbyak::Xmm xtmp_lo = xmm8;
986     Xbyak::Xmm xtmp_hi = xmm9;
987     Xbyak::Xmm xa_lo = xmm6;
988     Xbyak::Xmm xa_hi = xmm7;
989     Xbyak::Xmm xb_lo = xmm8;
990     Xbyak::Xmm xb_hi = xmm9;
991     Xbyak::Xmm xc_lo = xmm10;
992     Xbyak::Xmm xc_hi = xmm11;
993     Xbyak::Xmm xd_lo = xmm12;
994     Xbyak::Xmm xd_hi = xmm13;
995
996     // store xe
997     movaps(ptr[store_addr + 10 * 4 * sizeof(float)], xe_lo);
998     movaps(ptr[store_addr + 11 * 4 * sizeof(float)], xe_hi);
999
1000     mulps(xe_lo, xe_lo);
1001     mulps(xe_hi, xe_hi);
1002     addps(xsum_lo, xe_lo);
1003     addps(xsum_hi, xe_hi);
1004
1005     // xdst <- xsum*xalpha+xk
1006     movaps(xdst_lo, xsum_lo);
1007     movaps(xdst_hi, xsum_hi);
1008     mulps(xdst_lo, ptr[store_addr + 0 * 4 * sizeof(float)]);
1009     mulps(xdst_hi, ptr[store_addr + 0 * 4 * sizeof(float)]);
1010     addps(xdst_lo, ptr[store_addr + 1 * 4 * sizeof(float)]);
1011     addps(xdst_hi, ptr[store_addr + 1 * 4 * sizeof(float)]);
1012
1013     movaps(xbase_lo, xdst_lo);
1014     movaps(xbase_hi, xdst_hi);
1015     if (pk != prop_kind::forward_inference)
1016     {
1017         if (tail != 0) {
1018             nchw_tail_sse42(tail, scratch, xbase_lo, xbase_hi);
1019         }
1020         else {
1021             movups(ptr[scratch], xbase_lo);
1022             movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
1023         }
1024     }
1025     mulps(xdst_lo, xdst_lo);
1026     mulps(xdst_hi, xdst_hi);
1027     mulps(xdst_lo, xbase_lo);
1028     mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
1029     sqrtps(xdst_lo, xdst_lo);
1030     sqrtps(xdst_hi, xdst_hi);
1031     sqrtps(xdst_lo, xdst_lo);
1032     sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
1033     movaps(xtmp_lo, ptr[store_addr + 6 * 4 * sizeof(float)]);
1034     movaps(xtmp_hi, ptr[store_addr + 7 * 4 * sizeof(float)]);
1035     divps(xtmp_lo, xdst_lo);
1036     divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
1037     movaps(xdst_lo, xtmp_lo);
1038     movaps(xdst_hi, xtmp_hi);
1039
1040     if (tail != 0) {
1041         nchw_tail_sse42(tail, dst, xdst_lo, xdst_hi);
1042     }
1043     else {
1044         movups(ptr[dst], xdst_lo);
1045         movups(ptr[dst + 4 * sizeof(float)], xdst_hi);
1046     }
1047
1048     movaps(xa_lo, ptr[store_addr + 2 * 4 * sizeof(float)]);
1049     movaps(xa_hi, ptr[store_addr + 3 * 4 * sizeof(float)]);
1050     mulps(xa_lo, xa_lo);
1051     mulps(xa_hi, xa_hi);
1052     subps(xsum_lo, xa_lo);
1053     subps(xsum_hi, xa_hi);
1054
1055     // xa <- xb
1056     movaps(xb_lo, ptr[store_addr + 4 * 4 * sizeof(float)]);
1057     movaps(xb_hi, ptr[store_addr + 5 * 4 * sizeof(float)]);
1058     movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xb_lo);
1059     movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xb_hi);
1060
1061     // xb <- xc
1062     movaps(xc_lo, ptr[store_addr + 6 * 4 * sizeof(float)]);
1063     movaps(xc_hi, ptr[store_addr + 7 * 4 * sizeof(float)]);
1064     movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xc_lo);
1065     movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xc_hi);
1066
1067     // xc <- xd
1068     movaps(xd_lo, ptr[store_addr + 8 * 4 * sizeof(float)]);
1069     movaps(xd_hi, ptr[store_addr + 9 * 4 * sizeof(float)]);
1070     movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xd_lo);
1071     movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xd_hi);
1072
1073     // xd <- xe
1074     movaps(xe_lo, ptr[store_addr + 10 * 4 * sizeof(float)]);
1075     movaps(xe_hi, ptr[store_addr + 11 * 4 * sizeof(float)]);
1076     movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xe_lo);
1077     movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xe_hi);
1078 }
1079
1080 template<>
1081 void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body_sse42(
1082     int tail, int HW, prop_kind_t pk,
1083     Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
1084     Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi,
1085     Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) {}
1086
1087 template<>
1088 jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
1089     struct nchw_across J,
1090     float A,
1091     float K,
1092     prop_kind_t pk,
1093     void* code_ptr,
1094     size_t code_size)
1095     : jit_generator(code_ptr, code_size)
1096     , alpha(A), k(K)
1097 {
1098     static const uint32_t mask[] = {
1099         0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
1100         0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0
1101     };
1102     Xbyak::Reg64 c = r10;
1103     Xbyak::Ymm ymask = ymm2;
1104     Xbyak::Ymm ye = ymm3;
1105     Xbyak::Ymm ya = ymm4;
1106     Xbyak::Ymm yb = ymm5;
1107     Xbyak::Ymm yc = ymm6;
1108     Xbyak::Ymm yd = ymm7;
1109     Xbyak::Ymm ysum = ymm8;
1110
1111     this->preamble();
1112
1113     if (J.tail != 0)
1114     {
1115         mov(imm_addr64, reinterpret_cast<size_t>(&mask[7 - J.tail]));
1116         vmovups(ymask, ptr[imm_addr64]);
1117     }
1118     mov(imm_addr64, float2int(this->alpha));
1119     movq(xalpha, imm_addr64);
1120     vbroadcastss(yalpha, xalpha);
1121
1122     mov(imm_addr64, float2int(this->k));
1123     movq(xk, imm_addr64);
1124     vbroadcastss(yk, xk);
1125
1126     mov(src, ptr[this->param1 + 0]);
1127     mov(dst, ptr[this->param1 + 8]);
1128     if (pk != prop_kind::forward_inference)
1129         mov(scratch, ptr[this->param1 + 16]);
1130
1131     vxorps(ya, ya, ya);
1132     vxorps(yb, yb, yb);
1133     if (J.tail != 0)
1134         vmaskmovps(yc, ymask, ptr[src + J.HW * 0]);
1135     else
1136         vmovups(yc, ptr[src + J.HW * 0]);
1137     if (J.tail != 0)
1138         vmaskmovps(yd, ymask, ptr[src + J.HW * 4]);
1139     else
1140         vmovups(yd, ptr[src + J.HW * 4]);
1141
1142     vxorps(ysum, ysum, ysum);
1143     vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
1144     vfmadd231ps(ysum, yd, yd);
1145
1146     mov(c, J.C - 2);
1147     Label lrn_loop;
1148     L(lrn_loop);
1149
1150     if (J.tail != 0)
1151         vmaskmovps(ye, ymask, ptr[src + J.HW * 8]);
1152     else
1153         vmovups(ye, ptr[src + J.HW * 8]);
1154
1155     nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
1156
1157     add(src, J.HW * 4);
1158     add(dst, J.HW * 4);
1159     if (pk != prop_kind::forward_inference)
1160         add(scratch, J.HW * 4);
1161     dec(c);
1162     cmp(c, 0);
1163     jne(lrn_loop, T_NEAR);
1164
1165     vxorps(ye, ye, ye);
1166
1167     nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
1168     add(src, J.HW * 4);
1169     add(dst, J.HW * 4);
1170     if (pk != prop_kind::forward_inference)
1171         add(scratch, J.HW * 4);
1172
1173     nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
1174
1175     this->postamble();
1176
1177     ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
1178                 this->getCode()));
1179 }
1180
1181 template<>
1182 jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
1183     struct nchw_across J,
1184     float A,
1185     float K,
1186     prop_kind_t pk,
1187     void* code_ptr,
1188     size_t code_size)
1189     : jit_generator(code_ptr, code_size)
1190     , alpha(A), k(K)
1191 {
1192     static const uint32_t mask[] = {
1193         0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
1194         0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0
1195     };
1196
1197     Xbyak::Reg64 c = r10;
1198
1199     Xbyak::Xmm xmask_lo = xmm2;
1200     Xbyak::Xmm xmask_hi = xmm3;
1201     Xbyak::Xmm xsum_lo = xmm4;
1202     Xbyak::Xmm xsum_hi = xmm5;
1203     Xbyak::Xmm xa_lo = xmm6;
1204     Xbyak::Xmm xa_hi = xmm7;
1205     Xbyak::Xmm xb_lo = xmm8;
1206     Xbyak::Xmm xb_hi = xmm9;
1207     Xbyak::Xmm xc_lo = xmm10;
1208     Xbyak::Xmm xc_hi = xmm11;
1209     Xbyak::Xmm xd_lo = xmm12;
1210     Xbyak::Xmm xd_hi = xmm13;
1211     Xbyak::Xmm xe_lo = xmm14;
1212     Xbyak::Xmm xe_hi = xmm15;
1213
1214     this->preamble();
1215
1216     mov(src, ptr[this->param1 + 0]);
1217     mov(dst, ptr[this->param1 + 8]);
1218     if (pk != prop_kind::forward_inference)
1219         mov(scratch, ptr[this->param1 + 16]);
1220
1221     sub(rsp, stack_space_needed);
1222     mov(store_addr, rsp);
1223     and_(store_addr, -15);
1224
1225     mov(imm_addr64, float2int(this->alpha));
1226     movq(xalpha, imm_addr64);
1227     shufps(xalpha, xalpha, 0);
1228
1229     mov(imm_addr64, float2int(this->k));
1230     movq(xk, imm_addr64);
1231     shufps(xk, xk, 0);
1232
1233     // put alpha and k into store (free up regs)
1234     movaps(ptr[store_addr + 0 * 4 * sizeof(float)], xalpha);
1235     movaps(ptr[store_addr + 1 * 4 * sizeof(float)], xk);
1236
1237     if (J.tail != 0)
1238     {
1239         mov(imm_addr64, reinterpret_cast<size_t>(&mask[7 - J.tail]));
1240         movups(xmask_lo, ptr[imm_addr64]);
1241         movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
1242     }
1243     // init xa, xb
1244     xorps(xa_lo, xa_lo);
1245     xorps(xa_hi, xa_hi);
1246     xorps(xb_lo, xb_lo);
1247     xorps(xb_hi, xb_hi);
1248
1249     // read xc, xd
1250     if (J.tail != 0) {
1251         movups(xc_lo, ptr[src + J.HW * 0]);
1252         movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]);
1253         andps(xc_lo, xmask_lo);
1254         andps(xc_hi, xmask_hi);
1255     }
1256     else {
1257         movups(xc_lo, ptr[src + J.HW * 0]);
1258         movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]);
1259     }
1260     if (J.tail != 0) {
1261         movups(xd_lo, ptr[src + J.HW * 4]);
1262         movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]);
1263         andps(xd_lo, xmask_lo);
1264         andps(xd_hi, xmask_hi);
1265     }
1266     else {
1267         movups(xd_lo, ptr[src + J.HW * 4]);
1268         movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]);
1269     }
1270
1271     // put xa, xb, xc, xd into store to free-up regs
1272     movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xa_lo);
1273     movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xa_hi);
1274     movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xb_lo);
1275     movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xb_hi);
1276     movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xc_lo);
1277     movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xc_hi);
1278     movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xd_lo);
1279     movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xd_hi);
1280
1281     xorps(xsum_lo, xsum_lo);
1282     xorps(xsum_hi, xsum_hi);
1283     mulps(xc_lo, xc_lo);
1284     mulps(xc_hi, xc_hi);
1285     addps(xsum_lo, xc_lo);
1286     addps(xsum_hi, xc_hi);
1287     mulps(xd_lo, xd_lo);
1288     mulps(xd_hi, xd_hi);
1289     addps(xsum_lo, xd_lo);
1290     addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
1291
1292     mov(c, J.C - 2);
1293     Label lrn_loop;
1294     L(lrn_loop);
1295
1296     if (J.tail != 0) {
1297         movups(xe_lo, ptr[src + J.HW * 8]);
1298         movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]);
1299         andps(xe_lo, xmask_lo);
1300         andps(xe_hi, xmask_hi);
1301     }
1302     else {
1303         movups(xe_lo, ptr[src + J.HW * 8]);
1304         movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]);
1305     }
1306
1307     nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
1308         xe_lo, xe_hi,
1309         xsum_lo, xsum_hi);
1310
1311     add(src, J.HW * 4);
1312     add(dst, J.HW * 4);
1313     if (pk != prop_kind::forward_inference)
1314         add(scratch, J.HW * 4);
1315     dec(c);
1316     cmp(c, 0);
1317     jne(lrn_loop, T_NEAR);
1318
1319     xorps(xe_lo, xe_lo);
1320     xorps(xe_hi, xe_hi);
1321
1322     nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
1323         xe_lo, xe_hi,
1324         xsum_lo, xsum_hi);
1325     add(src, J.HW * 4);
1326     add(dst, J.HW * 4);
1327     if (pk != prop_kind::forward_inference)
1328         add(scratch, J.HW * 4);
1329
1330     nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
1331         xe_lo, xe_hi,
1332         xsum_lo, xsum_hi);
1333
1334     add(rsp, stack_space_needed);
1335
1336     this->postamble();
1337
1338     ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
1339         this->getCode()));
1340 }
1341
1342 //////////////////////////////////////////////////////////////////////////////
1343 // backward kernel
1344 template <cpu_isa_t isa>
1345 jit_uni_lrn_bwd_kernel_f32<isa>::jit_uni_lrn_bwd_kernel_f32(
1346     const struct nchw8c_across &J,
1347     float A,
1348     float B,
1349     int use_h_parallel,
1350     void *code_ptr,
1351     size_t code_size)
1352     : jit_generator(code_ptr, code_size)
1353     , nalphabeta(-2 * A*B)
1354     , use_h_parallelizm(use_h_parallel)
1355 {
1356     Xbyak::Reg64 t = rsp;
1357     Xbyak::Reg64 hw = r10;
1358
1359     Xbyak::Xmm xsrc_prev = xmm1;
1360     Xbyak::Xmm xws_prev = xmm2;
1361     Xbyak::Xmm xdiffdst_prev = xmm3;
1362     Xbyak::Ymm ysrc = ymm4;
1363     Xbyak::Ymm yws = ymm5;
1364     Xbyak::Ymm ydiffdst = ymm6;
1365     Xbyak::Xmm xsrc_next = xmm7;
1366     Xbyak::Xmm xws_next = xmm8;
1367     Xbyak::Xmm xdiffdst_next = xmm9;
1368     Xbyak::Ymm ya = ymm10;
1369     Xbyak::Xmm xa = xmm10;
1370     Xbyak::Ymm yb = ymm11;
1371     Xbyak::Ymm yd = ymm12;
1372     Xbyak::Ymm ye = ymm13;
1373     Xbyak::Ymm ysum = ymm14;
1374     Xbyak::Ymm ydiffsrc = ymm15;
1375
1376     this->preamble();
1377
1378     mov(src, ptr[this->param1 + 0]);
1379     mov(diffdst, ptr[this->param1 + 8]);
1380     mov(workspace, ptr[this->param1 + 16]);
1381     mov(diffsrc, ptr[this->param1 + 24]);
1382
1383     sub(t, 64);
1384     mov(imm_addr64, float2int(this->nalphabeta));
1385     movq(xnalphabeta, imm_addr64);
1386     vbroadcastss(ynalphabeta, xnalphabeta);
1387
1388     bool is_single = J.version == 3;
1389     bool is_first = J.version == -1 || J.version == -2;
1390     bool is_last = J.version == +1 || J.version == -2;
1391
1392     if (is_first || is_single) {
1393         vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
1394         vmovups(ptr[t + 0], xsrc_prev);
1395     }
1396     if (is_last || is_single) {
1397         vxorps(xsrc_next, xsrc_next, xsrc_next);
1398         vmovups(ptr[t + 48], xsrc_next);
1399     }
1400     mov(hw, this->use_h_parallelizm ? J.W : J.H*J.W);
1401     Label lrn_loop;
1402     L(lrn_loop);
1403     {
1404         if (!is_first && !is_single) {
1405             vmovups(xws_prev, ptr[workspace - J.H*J.W * 32 + 16]);
1406             vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
1407             vmovups(xdiffdst_prev, ptr[diffdst - J.H*J.W * 32 + 16]);
1408             vmulps(xa, xws_prev, xws_prev);
1409             vmulps(xa, xa, xws_prev);
1410             vsqrtps(xa, xa);
1411             vsqrtps(xa, xa);
1412             vmulps(xa, xa, xws_prev);
1413             vdivps(xsrc_prev, xsrc_prev, xa);
1414             vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev);
1415         }
1416
1417         vmovups(ysrc, ptr[src]);
1418         vmovups(yws, ptr[workspace]);
1419         vmovups(ydiffdst, ptr[diffdst]);
1420         vmulps(ya, yws, yws);
1421         vmulps(ya, ya, yws);
1422         vsqrtps(ya, ya);
1423         vsqrtps(ya, ya);
1424         vdivps(ydiffsrc, ydiffdst, ya);
1425         vdivps(ysum, ydiffsrc, yws);
1426         vmulps(ysum, ysum, ysrc);
1427
1428         if (!is_last && !is_single) {
1429             vmovups(xws_next, ptr[workspace + J.H*J.W * 32]);
1430             vmovups(xsrc_next, ptr[src + J.H*J.W * 32]);
1431             vmovups(xdiffdst_next, ptr[diffdst + J.H*J.W * 32]);
1432             vmulps(xa, xws_next, xws_next);
1433             vmulps(xa, xa, xws_next);
1434             vsqrtps(xa, xa);
1435             vsqrtps(xa, xa);
1436             vmulps(xa, xa, xws_next);
1437             vdivps(xsrc_next, xsrc_next, xa);
1438             vdivps(xsrc_next, xsrc_next, xws_next);
1439             vmulps(xdiffdst_next, xdiffdst_next, xsrc_next);
1440         }
1441
1442         if (!is_first && !is_single) vmovups(ptr[t + 0], xdiffdst_prev);
1443         vmovups(ptr[t + 16], ysum);
1444         if (!is_last && !is_single) vmovups(ptr[t + 48], xdiffdst_next);
1445
1446         vmovups(ya, ptr[t + 16 - 8]);
1447         vmovups(yb, ptr[t + 16 - 4]);
1448         vaddps(ysum, ysum, ya);
1449         vmulps(ysrc, ysrc, ynalphabeta);
1450         vaddps(ysum, ysum, yb);
1451
1452         vmovups(yd, ptr[t + 16 + 4]);
1453         vmovups(ye, ptr[t + 16 + 8]);
1454         vaddps(ysum, ysum, yd);
1455         vaddps(ysum, ysum, ye);
1456
1457         vfmadd231ps(ydiffsrc, ysum, ysrc);
1458
1459         vmovups(ptr[diffsrc], ydiffsrc);
1460
1461         add(src, 32);
1462         add(diffsrc, 32);
1463         add(diffdst, 32);
1464         add(workspace, 32);
1465
1466         dec(hw);
1467         cmp(hw, 0);
1468         jne(lrn_loop, T_NEAR);
1469     }
1470
1471     add(t, 64);
1472     this->postamble();
1473
1474     ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
1475         this->getCode()));
1476 }
1477
1478 template struct jit_uni_lrn_fwd_kernel_f32<sse42>;
1479 template struct jit_uni_lrn_fwd_kernel_f32<avx2>;
1480 template struct jit_uni_lrn_bwd_kernel_f32<avx2>;
1481
1482 }
1483 }
1484 }
1485
1486 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s