1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
20 #include "jit_uni_lrn.hpp"
26 using namespace Xbyak;
28 //////////////////////////////////////////////////////////////////////////////
30 template<cpu_isa_t isa>
// Emits JIT code computing one 8-float vector of "within-channel" LRN forward:
// accumulates squared neighbours over window rows [hoff..Hoff] and columns
// [woff..Woff] (row pitch = stride vectors), forms base = sum*alpha + k,
// then writes dst = src / base^0.75 (cube followed by sqrt twice; the first
// vsqrtps sits on an elided line of this chunk).
// NOTE(review): several source lines are elided here (loop braces, isa
// dispatch, and the copy that fills ytmp with the base value before the
// scratch store) -- confirm details against the full file.
31 void jit_uni_lrn_fwd_kernel_f32<isa>::within_body(
32 int hoff, int Hoff, int woff, int Woff, int stride,
33 Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2,
// ysum accumulates the sum of squares over the whole window.
36 vxorps(ysum, ysum, ysum);
37 for (int i = hoff; i <= Hoff; ++i)
39 for (int j = woff; j <= Woff; ++j)
// Center element: kept in ydst, reused below as the division numerator.
43 vmovups(ydst, ptr[src]);
44 vfmadd231ps(ysum, ydst, ydst); // ysum += center^2
// Off-center element at window offset (i, j).
48 vmovups(ytmp, ptr[src + (i*stride + j)*VECTOR_LENGTH*4]);
49 vfmadd231ps(ysum, ytmp, ytmp); // ysum += neighbour^2
53 vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk
// Training keeps the base value in scratch for the backward pass.
// NOTE(review): ytmp is presumably loaded with the base on an elided line.
55 if (pk != prop_kind::forward_inference)
56 vmovups(ptr[scratch], ytmp);
57 vmulps(ysum2, ysum, ysum);
58 vmulps(ysum, ysum, ysum2); // ysum = (ysum*yalpha+yk)^3;
60 vsqrtps(ysum, ysum); // ysum = (ysum*yalpha+yk)^0.75
61 vdivps(ydst, ydst, ysum); // ydst <- ydst / ysum
62 vmovups(ptr[dst], ydst);
// Pointer advance for scratch happens only when it was written above.
65 if (pk != prop_kind::forward_inference)
69 template<cpu_isa_t isa>
// SSE4.2 variant of within_body: the 8-float vector is processed as two
// 4-float halves (xmm *_lo / *_hi) because SSE has no 256-bit registers and
// no FMA -- each fused multiply-add is spelled as mulps + addps.
// NOTE(review): lines are elided in this chunk (loop braces, isa dispatch,
// and `addps(xsum_lo, xk)` on the elided original line 110) -- verify
// against the full file.
70 void jit_uni_lrn_fwd_kernel_f32<isa>::within_body_sse42(
71 int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk)
// Fixed register allocation for this helper; xmm0..xmm7 are left for
// values set up by the caller (alpha/k/pointers).
73 Xbyak::Xmm xtmp_lo = xmm12;
74 Xbyak::Xmm xtmp_hi = xmm13;
75 Xbyak::Xmm xsum_lo = xmm8;
76 Xbyak::Xmm xsum_hi = xmm9;
77 Xbyak::Xmm xdst_lo = xmm10;
78 Xbyak::Xmm xdst_hi = xmm11;
79 Xbyak::Xmm xsum2_lo = xmm14;
80 Xbyak::Xmm xsum2_hi = xmm15;
// Zero both accumulator halves, then sum x^2 over the window.
82 xorps(xsum_lo, xsum_lo);
83 xorps(xsum_hi, xsum_hi);
84 for (int i = hoff; i <= Hoff; ++i)
86 for (int j = woff; j <= Woff; ++j)
// Center element (both halves).
90 movups(xdst_lo, ptr[src]);
91 movups(xdst_hi, ptr[src + 4 * sizeof(float)]);
92 mulps(xdst_lo, xdst_lo);
93 mulps(xdst_hi, xdst_hi);
94 addps(xsum_lo, xdst_lo);
95 addps(xsum_hi, xdst_hi);
// Off-center element at window offset (i, j).
99 movups(xtmp_lo, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4]);
100 movups(xtmp_hi, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4 + 4 * sizeof(float)]);
101 mulps(xtmp_lo, xtmp_lo);
102 mulps(xtmp_hi, xtmp_hi);
103 addps(xsum_lo, xtmp_lo);
104 addps(xsum_hi, xtmp_hi);
// base = sum*alpha + k (the _lo addps is on an elided line).
108 mulps(xsum_lo, xalpha);
109 mulps(xsum_hi, xalpha);
111 addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk
112 movaps(xtmp_lo, xsum_lo);
113 movaps(xtmp_hi, xsum_hi);
// Training: stash the base value for the backward pass.
114 if (pk != prop_kind::forward_inference) {
115 movups(ptr[scratch], xtmp_lo);
116 movups(ptr[scratch + 4 * sizeof(float)], xtmp_hi);
// Cube the base: sum2 = base^2, sum = base^3.
118 movaps(xsum2_lo, xsum_lo);
119 movaps(xsum2_hi, xsum_hi);
120 mulps(xsum2_lo, xsum_lo);
121 mulps(xsum2_hi, xsum_hi);
122 mulps(xsum_lo, xsum2_lo);
123 mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha+xk)^3;
// Two square roots: base^3 -> base^1.5 -> base^0.75.
125 sqrtps(xsum_lo, xsum_lo);
126 sqrtps(xsum_hi, xsum_hi);
127 sqrtps(xsum_lo, xsum_lo);
128 sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha+xk)^0.75
// Reload the source (xdst was clobbered above) and divide.
130 movups(xdst_lo, ptr[src]);
131 movups(xdst_hi, ptr[src + 4 * sizeof(float)]);
132 divps(xdst_lo, xsum_lo);
133 divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum
135 movups(ptr[dst], xdst_lo);
136 movups(ptr[dst + 4 * sizeof(float)], xdst_hi);
// Scratch pointer advances only in training mode.
139 if (pk != prop_kind::forward_inference)
143 template <cpu_isa_t isa>
// Constructor/JIT generator for the forward LRN kernel, nChw8c layout,
// "within channel" algorithm. Emits a 3x3 tiling of within_body /
// within_body_sse42 calls: special-cased top rows, middle rows (runtime
// loop over h), and bottom rows; each row is split into left-border
// columns, a runtime loop over full-window columns, and right-border
// columns. Window half-sizes: s2 rows/cols before the center, S2 after.
// NOTE(review): loop labels (label_t, lrn_loop_w, lrn_loop_h, label_b),
// counter decrements, and the isa dispatch between the avx2/sse42 body
// helpers sit on elided lines of this chunk.
144 jit_uni_lrn_fwd_kernel_f32<isa>::jit_uni_lrn_fwd_kernel_f32(
145 const struct nchw8c_within &J,
151 : jit_generator(code_ptr, code_size)
155 Xbyak::Reg64 w = r10;
// NOTE(review): both ternary arms are identical (9/10/11/12 for either
// isa); the `isa == avx2 ? N : N` form looks like a leftover -- harmless
// but worth simplifying upstream.
156 Vmm ysum = Vmm(isa == avx2 ? 9 : 9);
157 Vmm ysum2 = Vmm(isa == avx2 ? 10 : 10);
158 Vmm ydst = Vmm(isa == avx2 ? 11 : 11);
159 Vmm ytmp = Vmm(isa == avx2 ? 12 : 12);
// Kernel arguments arrive via a single args struct: src, dst, scratch.
163 mov(src, ptr[this->param1 + 0]);
164 mov(dst, ptr[this->param1 + 8]);
165 if (pk != prop_kind::forward_inference)
166 mov(scratch, ptr[this->param1 + 16]);
// Broadcast alpha to all lanes (vbroadcastss on avx2, shufps on sse42).
168 mov(imm_addr64, float2int(this->alpha));
169 movq(xalpha, imm_addr64);
171 vbroadcastss(yalpha, xalpha);
173 shufps(xalpha, xalpha, 0);
// Broadcast k the same way.
176 mov(imm_addr64, float2int(this->k));
177 movq(xk, imm_addr64);
179 vbroadcastss(yk, xk);
// s2 = elements before the window center, S2 = elements after it.
184 int s2 = (J.size - 1) / 2, S2 = J.size - s2 - 1;
// --- top border rows: window clipped at the top (row offset -i). ---
186 for (int i = 0; i < s2; ++i)
189 for (int j = 0; j < s2; ++j) {
191 within_body(-i, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
194 within_body_sse42(-i, S2, -j, S2, J.W, pk);
// Runtime loop over the columns where the full window fits.
197 mov(w, J.W - J.size + 1);
200 within_body(-i, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
202 within_body_sse42(-i, S2, -s2, S2, J.W, pk);
206 jne(label_t, T_NEAR);
// Right border columns: window clipped on the right.
207 for (int j = J.W - S2; j < J.W; ++j) {
209 within_body(-i, S2, -s2, J.W - 1 - j, J.W,
210 ysum, ydst, ytmp, ysum2, pk);
212 within_body_sse42(-i, S2, -s2, J.W - 1 - j, J.W, pk);
// --- middle rows: runtime loop over h, full window vertically. ---
217 mov(h, J.H - J.size + 1);
220 for (int j = 0; j < s2; ++j) {
222 within_body(-s2, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
224 within_body_sse42(-s2, S2, -j, S2, J.W, pk);
226 mov(w, J.W - J.size + 1);
227 mov(w, J.W - J.size + 1);
231 within_body(-s2, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
233 within_body_sse42(-s2, S2, -s2, S2, J.W, pk);
237 jne(lrn_loop_w, T_NEAR);
238 for (int j = J.W - S2; j < J.W; ++j) {
240 within_body(-s2, S2, -s2, J.W - 1 - j, J.W,
241 ysum, ydst, ytmp, ysum2, pk);
243 within_body_sse42(-s2, S2, -s2, J.W - 1 - j, J.W, pk);
248 jne(lrn_loop_h, T_NEAR);
// --- bottom border rows: window clipped at the bottom. ---
250 for (int i = J.H - S2; i < J.H; ++i)
252 for (int j = 0; j < s2; ++j) {
254 within_body(-s2, J.H - 1 - i, -j, S2, J.W,
255 ysum, ydst, ytmp, ysum2, pk);
257 within_body_sse42(-s2, J.H - 1 - i, -j, S2, J.W, pk);
261 mov(w, J.W - J.size + 1);
265 within_body(-s2, J.H - 1 - i, -s2, S2, J.W,
266 ysum, ydst, ytmp, ysum2, pk);
268 within_body_sse42(-s2, J.H - 1 - i, -s2, S2, J.W, pk);
272 jne(label_b, T_NEAR);
274 for (int j = J.W - S2; j < J.W; ++j) {
276 within_body(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W,
277 ysum, ydst, ytmp, ysum2, pk);
279 within_body_sse42(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, pk);
// Publish the generated code as the callable kernel entry point.
286 ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
// AVX2 forward LRN kernel, nChw8c layout, "across channels" algorithm.
// The 5-element cross-channel window is built on a 64-byte stack staging
// area `t`: [t+0..15] = tail of the previous channel block, [t+16..47] =
// the current 8-channel vector, [t+48..63] = head of the next block.
// Unaligned loads at t+16±4 / t+16±8 then yield the a,b,(c),d,e neighbours.
// J.version encodes the block position: -1 = first block (no previous),
// +1 = last block (no next).
291 jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
292 const struct nchw8c_across &J,
298 : jit_generator(code_ptr, code_size)
// `t` addresses the staging area directly through rsp.
301 Xbyak::Reg64 t = rsp;
302 Xbyak::Reg64 hw = r9;
303 Xbyak::Xmm xsrc_prev = xmm2;
304 Xbyak::Ymm ysrc = ymm3;
// yc aliases ysrc: the center element IS the current source vector.
305 Xbyak::Ymm yc = ymm3;
306 Xbyak::Xmm xsrc_next = xmm4;
307 Xbyak::Ymm ya = ymm5;
308 Xbyak::Ymm yb = ymm6;
309 Xbyak::Ymm yd = ymm7;
310 Xbyak::Ymm ye = ymm8;
311 Xbyak::Ymm ysum = ymm9;
312 Xbyak::Ymm ysum2 = ymm10;
313 Xbyak::Ymm ydst = ymm11;
314 Xbyak::Ymm ybase = ymm12;
// Load kernel arguments: src, dst, scratch (scratch only for training).
318 mov(src, ptr[this->param1 + 0]);
319 mov(dst, ptr[this->param1 + 8]);
320 if (pk != prop_kind::forward_inference)
321 mov(scratch, ptr[this->param1 + 16]);
// Broadcast alpha and k across all ymm lanes.
323 mov(imm_addr64, float2int(this->alpha));
324 movq(xalpha, imm_addr64);
325 vbroadcastss(yalpha, xalpha);
327 mov(imm_addr64, float2int(this->k));
328 movq(xk, imm_addr64);
329 vbroadcastss(yk, xk);
// First block: zero-pad the "previous" staging slot.
333 vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
334 vmovups(ptr[t + 0], xsrc_prev);
// Last block: zero-pad the "next" staging slot.
338 vxorps(xsrc_next, xsrc_next, xsrc_next);
339 vmovups(ptr[t + 48], xsrc_next);
// --- per-pixel loop body (label on an elided line) ---
// Pull in the 4 trailing floats of the previous channel block and the
// 4 leading floats of the next one; blocks are 8 ch * H*W * 4 B = HW*32 B
// apart.
347 if (J.version != -1) vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
348 vmovups(ysrc, ptr[src]);
349 if (J.version != +1) vmovups(xsrc_next, ptr[src + J.H*J.W * 32]);
351 if (J.version != -1) vmovups(ptr[t + 0], xsrc_prev);
352 vmovups(ptr[t + 16], ysrc);
353 if (J.version != +1) vmovups(ptr[t + 48], xsrc_next);
// Shifted reloads give the +-1 and +-2 channel neighbours.
355 vmovups(ya, ptr[t + 16 - 8]);
356 vmovups(yb, ptr[t + 16 - 4]);
357 vmovups(yd, ptr[t + 16 + 4]);
358 vmovups(ye, ptr[t + 16 + 8]);
// Sum of squares of the 5-wide window, then base = sum*alpha + k.
359 vmulps(ysum, yc, yc);
360 vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya
361 vfmadd231ps(ysum, yb, yb);
362 vfmadd231ps(ysum, yd, yd);
363 vfmadd231ps(ysum, ye, ye);
364 vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk
366 vmovaps(ybase, ysum);
367 if (pk != prop_kind::forward_inference)
368 vmovups(ptr[scratch], ybase);
// base^0.75 via cube + two sqrts (first vsqrtps on an elided line).
369 vmulps(ysum2, ysum, ysum);
370 vmulps(ysum, ysum, ysum2); // ysum = ybase^3;
372 vsqrtps(ysum, ysum); // ysum = ybase^0.75
373 vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum
374 vmovups(ptr[dst], ydst);
// Advance pointers; scratch only moves in training mode.
378 if (pk != prop_kind::forward_inference)
382 jne(lrn_loop, T_NEAR);
// Publish the generated code as the callable kernel entry point.
387 ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
// SSE4.2 forward LRN kernel, nChw8c layout, "across channels" algorithm.
// Same staging-area scheme as the AVX2 version, but every 8-float vector
// is handled as two xmm halves and FMA is spelled as mulps + addps.
392 jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
393 const struct nchw8c_across &J,
399 : jit_generator(code_ptr, code_size)
402 Xbyak::Reg64 t = rsp;
403 Xbyak::Reg64 hw = r9;
405 Xbyak::Xmm xsrc_lo = xmm2;
406 Xbyak::Xmm xsrc_hi = xmm3;
407 Xbyak::Xmm xc_lo = xmm4;
408 Xbyak::Xmm xc_hi = xmm5;
// xsum aliases xc: the center copy is squared in place to start the sum.
409 Xbyak::Xmm xsum_lo = xc_lo;
410 Xbyak::Xmm xsum_hi = xc_hi;
411 Xbyak::Xmm xsrc_prev = xmm6;
412 Xbyak::Xmm xsrc_next = xmm7;
413 Xbyak::Xmm xa_lo = xmm8;
414 Xbyak::Xmm xa_hi = xmm9;
415 Xbyak::Xmm xb_lo = xmm10;
416 Xbyak::Xmm xb_hi = xmm11;
417 Xbyak::Xmm xd_lo = xmm12;
418 Xbyak::Xmm xd_hi = xmm13;
// xe_* and xbase_* deliberately share xmm14/15: xe is fully consumed
// into the sum before xbase is written.
419 Xbyak::Xmm xe_lo = xmm14;
420 Xbyak::Xmm xe_hi = xmm15;
421 Xbyak::Xmm xbase_lo = xmm14;
422 Xbyak::Xmm xbase_hi = xmm15;
// Kernel arguments: src, dst, scratch (training only).
426 mov(src, ptr[this->param1 + 0]);
427 mov(dst, ptr[this->param1 + 8]);
428 if (pk != prop_kind::forward_inference)
429 mov(scratch, ptr[this->param1 + 16]);
// Broadcast alpha (and k, whose shufps is on an elided line).
431 mov(imm_addr64, float2int(this->alpha));
432 movq(xalpha, imm_addr64);
433 shufps(xalpha, xalpha, 0);
435 mov(imm_addr64, float2int(this->k));
436 movq(xk, imm_addr64);
// First/last channel block: zero-pad the staging slots.
441 xorps(xsrc_prev, xsrc_prev);
442 movups(ptr[t + 0], xsrc_prev);
446 xorps(xsrc_next, xsrc_next);
447 movups(ptr[t + 48], xsrc_next);
// --- per-pixel loop body (label on an elided line) ---
454 if (J.version != -1) movups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
455 movups(xsrc_lo, ptr[src]);
456 movups(xsrc_hi, ptr[src + 4 * sizeof(float)]);
457 if (J.version != +1) movups(xsrc_next, ptr[src + J.H*J.W * 32]);
459 if (J.version != -1) movups(ptr[t + 0], xsrc_prev);
460 movups(ptr[t + 16], xsrc_lo);
461 movups(ptr[t + 16 + 4 * sizeof(float)], xsrc_hi);
462 if (J.version != +1) movups(ptr[t + 48], xsrc_next);
// Shifted reloads give the +-1 / +-2 channel neighbours, in halves.
464 movups(xa_lo, ptr[t + 16 - 8]);
465 movups(xa_hi, ptr[t + 16 - 8 + 4 * sizeof(float)]);
466 movups(xb_lo, ptr[t + 16 - 4]);
467 movups(xb_hi, ptr[t + 16 - 4 + 4 * sizeof(float)]);
468 movups(xd_lo, ptr[t + 16 + 4]);
469 movups(xd_hi, ptr[t + 16 + 4 + 4 * sizeof(float)]);
470 movups(xe_lo, ptr[t + 16 + 8]);
471 movups(xe_hi, ptr[t + 16 + 8 + 4 * sizeof(float)]);
// xsum starts as center^2 (xsum aliases xc).
472 movaps(xc_lo, xsrc_lo);
473 movaps(xc_hi, xsrc_hi);
474 mulps(xsum_lo, xc_lo);
475 mulps(xsum_hi, xc_hi);
// Accumulate neighbour squares (each mulps pair is on elided lines).
478 addps(xsum_lo, xa_lo);
479 addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa
482 addps(xsum_lo, xb_lo);
483 addps(xsum_hi, xb_hi);
486 addps(xsum_lo, xd_lo);
487 addps(xsum_hi, xd_hi);
490 addps(xsum_lo, xe_lo);
491 addps(xsum_hi, xe_hi);
// base = sum*alpha + k (the _lo addps is on an elided line).
493 mulps(xsum_lo, xalpha);
494 mulps(xsum_hi, xalpha);
496 addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk
// Keep the base: stored to scratch for the backward pass when training.
498 movaps(xbase_lo, xsum_lo);
499 movaps(xbase_hi, xsum_hi);
500 if (pk != prop_kind::forward_inference) {
501 movups(ptr[scratch], xbase_lo);
502 movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
// base^0.75 = sqrt(sqrt(base^3)).
504 mulps(xsum_lo, xsum_lo);
505 mulps(xsum_hi, xsum_hi);
506 mulps(xsum_lo, xbase_lo);
507 mulps(xsum_hi, xbase_hi); // xsum = xbase^3;
508 sqrtps(xsum_lo, xsum_lo);
509 sqrtps(xsum_hi, xsum_hi);
510 sqrtps(xsum_lo, xsum_lo);
511 sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75
512 divps(xsrc_lo, xsum_lo);
513 divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum
514 movups(ptr[dst], xsrc_lo);
515 movups(ptr[dst + 4 * sizeof(float)], xsrc_hi);
// Advance pointers; scratch only moves in training mode.
519 if (pk != prop_kind::forward_inference)
523 jne(lrn_loop, T_NEAR);
// Publish the generated code as the callable kernel entry point.
528 ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
// AVX2 forward LRN kernel, NHWC layout, "across channels" algorithm.
// Channels are contiguous, so the +-1/+-2 neighbours are plain shifted
// loads; the channel-range edges are handled with vmaskmovps and a sliding
// sign-bit mask table (first/last two lanes masked out as appropriate).
533 jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
534 const struct nhwc_across &J,
540 : jit_generator(code_ptr, code_size)
// Sliding mask window: &mask[i] selects which of 8 lanes are loaded.
// vmaskmovps uses only the sign bit of each dword.
543 static const uint32_t mask[] = {
544 0, 0, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
545 0x80000000, 0x80000000, 0x80000000, 0, 0
549 Xbyak::Ymm ya = ymm2;
550 Xbyak::Ymm yb = ymm3;
551 Xbyak::Ymm yc = ymm4;
552 Xbyak::Ymm yd = ymm5;
553 Xbyak::Ymm ye = ymm6;
554 Xbyak::Ymm ysum = ymm7;
555 Xbyak::Ymm ydst = ymm8;
556 Xbyak::Ymm ybase = ymm9;
557 Xbyak::Ymm ymask = ymm10;
// Kernel arguments and lane-broadcast alpha / k.
561 mov(src, ptr[this->param1 + 0]);
562 mov(dst, ptr[this->param1 + 8]);
563 if (pk != prop_kind::forward_inference)
564 mov(scratch, ptr[this->param1 + 16]);
565 mov(imm_addr64, float2int(this->alpha));
566 movq(xalpha, imm_addr64);
567 vbroadcastss(yalpha, xalpha);
569 mov(imm_addr64, float2int(this->k));
570 movq(xk, imm_addr64);
571 vbroadcastss(yk, xk);
// --- first 8 channels: mask out the nonexistent c-2 / c-1 lanes. ---
573 vxorps(ysum, ysum, ysum);
575 mov(imm_addr64, reinterpret_cast<size_t>(&mask[0]));
576 vmovups(ymask, ptr[imm_addr64]);
577 vmaskmovps(ya, ymask, ptr[src - 8]);
578 vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
580 mov(imm_addr64, reinterpret_cast<size_t>(&mask[1]));
581 vmovups(ymask, ptr[imm_addr64]);
582 vmaskmovps(yb, ymask, ptr[src - 4]);
583 vfmadd231ps(ysum, yb, yb);
// --- main loop over interior channel groups (label on elided line). ---
589 vmovups(yc, ptr[src]);
590 vmovups(yd, ptr[src + 4]);
591 vmovups(ye, ptr[src + 8]);
592 vfmadd231ps(ysum, yc, yc);
593 vfmadd231ps(ysum, yd, yd);
594 vfmadd231ps(ysum, ye, ye);
// NOTE(review): ydst is presumably copied from ysum on the elided line
// 596 before this fused multiply-add -- confirm against the full file.
597 vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
599 vmovaps(ybase, ydst);
600 if (pk != prop_kind::forward_inference)
601 vmovups(ptr[scratch], ybase);
// base^0.75 via cube + two sqrts (first vsqrtps on an elided line).
602 vmulps(ydst, ydst, ydst);
603 vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
605 vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
607 vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
608 vmovups(ptr[dst], ydst);
// Restart the partial sum for the next group: re-square the two
// overlapping neighbours behind the new window position.
610 vxorps(ysum, ysum, ysum);
614 if (pk != prop_kind::forward_inference)
617 vmovups(ya, ptr[src - 8]);
618 vfmadd231ps(ysum, ya, ya);
619 vmovups(yb, ptr[src - 4]);
620 vfmadd231ps(ysum, yb, yb);
624 jne(lrn_loop, T_NEAR);
// --- last 8 channels: mask out the nonexistent c+1 / c+2 lanes. ---
626 vmovups(yc, ptr[src]);
627 vfmadd231ps(ysum, yc, yc);
629 mov(imm_addr64, reinterpret_cast<size_t>(&mask[2]));
630 vmovups(ymask, ptr[imm_addr64]);
631 vmaskmovps(yd, ymask, ptr[src + 4]);
632 vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
634 mov(imm_addr64, reinterpret_cast<size_t>(&mask[3]));
635 vmovups(ymask, ptr[imm_addr64]);
636 vmaskmovps(ye, ymask, ptr[src + 8]);
637 vfmadd231ps(ysum, ye, ye);
// Same normalize-and-store tail as the loop body, for the last group.
640 vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
642 vmovaps(ybase, ydst);
643 if (pk != prop_kind::forward_inference)
644 vmovups(ptr[scratch], ybase);
645 vmulps(ydst, ydst, ydst);
646 vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
648 vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
649 vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
651 vmovups(ptr[dst], ydst);
// Publish the generated code as the callable kernel entry point.
655 ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
// SSE4.2 forward LRN kernel, NHWC layout, "across channels" algorithm.
// Mirrors the AVX2 NHWC kernel but: (a) masking is done with andps against
// all-ones/zero masks instead of vmaskmovps; (b) alpha and k are parked in
// a 16-byte-aligned spill area (`store`) so their registers can be reused.
660 jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
661 const struct nhwc_across &J,
667 : jit_generator(code_ptr, code_size)
// Sliding AND-mask window (full-dword masks, unlike the avx2 sign-bit
// masks, because andps needs all bits set to keep a lane).
670 static const uint32_t mask[] = {
671 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
672 0xffffffff, 0xffffffff, 0xffffffff, 0, 0
// Static spill area for alpha/k; extra slots allow the runtime 16-byte
// alignment fixup below.
675 static uint32_t store[] = {
676 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
680 Xbyak::Xmm xdst_lo = xmm0;
681 Xbyak::Xmm xdst_hi = xmm1;
// xa and xb share registers: xa is consumed before xb is loaded.
682 Xbyak::Xmm xa_lo = xmm2;
683 Xbyak::Xmm xa_hi = xmm3;
684 Xbyak::Xmm xb_lo = xmm2;
685 Xbyak::Xmm xb_hi = xmm3;
686 Xbyak::Xmm xc_lo = xmm4;
687 Xbyak::Xmm xc_hi = xmm5;
688 Xbyak::Xmm xd_lo = xmm6;
689 Xbyak::Xmm xd_hi = xmm7;
690 Xbyak::Xmm xe_lo = xmm8;
691 Xbyak::Xmm xe_hi = xmm9;
692 Xbyak::Xmm xsum_lo = xmm10;
693 Xbyak::Xmm xsum_hi = xmm11;
694 Xbyak::Xmm xmask_lo = xmm12;
695 Xbyak::Xmm xmask_hi = xmm13;
696 Xbyak::Xmm xbase_lo = xmm14;
697 Xbyak::Xmm xbase_hi = xmm15;
// Kernel arguments and lane-broadcast alpha / k.
701 mov(src, ptr[this->param1 + 0]);
702 mov(dst, ptr[this->param1 + 8]);
703 if (pk != prop_kind::forward_inference)
704 mov(scratch, ptr[this->param1 + 16]);
705 mov(imm_addr64, float2int(this->alpha));
706 movq(xalpha, imm_addr64);
707 shufps(xalpha, xalpha, 0);
709 mov(imm_addr64, float2int(this->k));
710 movq(xk, imm_addr64);
// Align the spill pointer and park alpha (slot 0) and k (slot 1).
// NOTE(review): `and_(store_addr, -15)` is an unconventional align mask
// (-16 is the usual 16-byte mask; -15 leaves bit 0 intact). It works here
// only because `store` is at least 4-byte aligned, so bit 0 is already 0.
713 mov(store_addr, reinterpret_cast<size_t>(&store[0]));
714 and_(store_addr, -15);
715 movups(ptr[store_addr], xalpha);
716 movups(ptr[store_addr + 4 * sizeof(float)], xk);
// --- first 8 channels: AND away the nonexistent c-2 / c-1 lanes. ---
718 xorps(xsum_lo, xsum_lo);
719 xorps(xsum_hi, xsum_hi);
721 mov(imm_addr64, reinterpret_cast<size_t>(&mask[0]));
722 movups(xmask_lo, ptr[imm_addr64]);
723 movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
724 movups(xa_lo, ptr[src - 8]);
725 movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]);
726 andps(xa_lo, xmask_lo);
727 andps(xa_hi, xmask_hi);
// (each neighbour's mulps pair is on elided lines)
730 addps(xsum_lo, xa_lo);
731 addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
733 mov(imm_addr64, reinterpret_cast<size_t>(&mask[1]));
734 movups(xmask_lo, ptr[imm_addr64]);
735 movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
736 movups(xb_lo, ptr[src - 4]);
737 movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]);
738 andps(xb_lo, xmask_lo);
739 andps(xb_hi, xmask_hi);
742 addps(xsum_lo, xb_lo);
743 addps(xsum_hi, xb_hi);
// --- main loop over interior channel groups (label on elided line). ---
749 movups(xc_lo, ptr[src]);
750 movups(xc_hi, ptr[src + 4 * sizeof(float)]);
751 movups(xd_lo, ptr[src + 4]);
752 movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]);
753 movups(xe_lo, ptr[src + 8]);
754 movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]);
757 addps(xsum_lo, xc_lo);
758 addps(xsum_hi, xc_hi);
761 addps(xsum_lo, xd_lo);
762 addps(xsum_hi, xd_hi);
765 addps(xsum_lo, xe_lo);
766 addps(xsum_hi, xe_hi);
// base = sum*alpha + k, with alpha/k read back from the spill area.
768 movaps(xdst_lo, xsum_lo);
769 movaps(xdst_hi, xsum_hi);
770 // xdst <- xsum*xalpha+xk
771 mulps(xdst_lo, ptr[store_addr]);
772 mulps(xdst_hi, ptr[store_addr]);
773 addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]);
774 addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]);
// Keep the base for the backward pass (training only).
776 movaps(xbase_lo, xdst_lo);
777 movaps(xbase_hi, xdst_hi);
778 if (pk != prop_kind::forward_inference) {
779 movups(ptr[scratch], xbase_lo);
780 movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
// base^0.75 = sqrt(sqrt(base^3)).
782 mulps(xdst_lo, xdst_lo);
783 mulps(xdst_hi, xdst_hi);
784 mulps(xdst_lo, xbase_lo);
785 mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
786 sqrtps(xdst_lo, xdst_lo);
787 sqrtps(xdst_hi, xdst_hi);
788 sqrtps(xdst_lo, xdst_lo);
789 sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
// Reload the center and normalize.
791 movups(xc_lo, ptr[src]);
792 movups(xc_hi, ptr[src + 4 * sizeof(float)]);
793 divps(xc_lo, xdst_lo);
794 divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
795 movups(ptr[dst], xc_lo);
796 movups(ptr[dst + 4 * sizeof(float)], xc_hi);
// Restart the partial sum with the two overlapping neighbours.
798 xorps(xsum_lo, xsum_lo);
799 xorps(xsum_hi, xsum_hi);
803 if (pk != prop_kind::forward_inference)
806 movups(xa_lo, ptr[src - 8]);
807 movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]);
810 addps(xsum_lo, xa_lo);
811 addps(xsum_hi, xa_hi);
812 movups(xb_lo, ptr[src - 4]);
813 movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]);
816 addps(xsum_lo, xb_lo);
817 addps(xsum_hi, xb_hi);
821 jne(lrn_loop, T_NEAR);
// --- last 8 channels: AND away the nonexistent c+1 / c+2 lanes. ---
823 movups(xc_lo, ptr[src]);
824 movups(xc_hi, ptr[src + 4 * sizeof(float)]);
827 addps(xsum_lo, xc_lo);
828 addps(xsum_hi, xc_hi);
830 mov(imm_addr64, reinterpret_cast<size_t>(&mask[2]));
831 movups(xmask_lo, ptr[imm_addr64]);
832 movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
833 movups(xd_lo, ptr[src + 4]);
834 movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]);
835 andps(xd_lo, xmask_lo);
836 andps(xd_hi, xmask_hi);
839 addps(xsum_lo, xd_lo);
840 addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
842 mov(imm_addr64, reinterpret_cast<size_t>(&mask[3]));
843 movups(xmask_lo, ptr[imm_addr64]);
844 movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
845 movups(xe_lo, ptr[src + 8]);
846 movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]);
847 andps(xe_lo, xmask_lo);
848 andps(xe_hi, xmask_hi);
851 addps(xsum_lo, xe_lo);
852 addps(xsum_hi, xe_hi);
// Same normalize-and-store tail as the loop body, for the last group.
854 movups(xdst_lo, xsum_lo);
855 movups(xdst_hi, xsum_hi);
856 // xdst <- xsum*xalpha+xk
857 mulps(xdst_lo, ptr[store_addr]);
858 mulps(xdst_hi, ptr[store_addr]);
859 addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]);
860 addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]);
862 movaps(xbase_lo, xdst_lo);
863 movaps(xbase_hi, xdst_hi);
864 if (pk != prop_kind::forward_inference) {
865 movups(ptr[scratch], xbase_lo);
866 movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
868 mulps(xdst_lo, xdst_lo);
869 mulps(xdst_hi, xdst_hi);
870 mulps(xdst_lo, xbase_lo);
871 mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
872 sqrtps(xdst_lo, xdst_lo);
873 sqrtps(xdst_hi, xdst_hi);
874 sqrtps(xdst_lo, xdst_lo);
875 sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
876 movups(xc_lo, ptr[src]);
877 movups(xc_hi, ptr[src + 4 * sizeof(float)]);
878 divps(xc_lo, xdst_lo);
879 divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
881 movups(ptr[dst], xc_lo);
882 movups(ptr[dst + 4 * sizeof(float)], xc_hi);
// Publish the generated code as the callable kernel entry point.
886 ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
// SSE4.2 specialization of nchw_body. Its body lies on lines elided from
// this chunk -- presumably an empty stub, since the sse42 NCHW path uses
// nchw_body_sse42 instead; confirm against the full file.
891 void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body(
892 int tail, int HW, prop_kind_t pk,
// AVX2 NCHW "across" loop body: given the running sum of squares in ysum
// (ya..yd already folded in by the caller/previous iteration), folds in
// ye, normalizes, stores (masked for the channel tail), then subtracts
// ya^2 so ysum is ready to slide one channel forward.
902 void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body(
903 int tail, int HW, prop_kind_t pk,
912 Xbyak::Ymm ydst = ymm14;
913 Xbyak::Ymm ybase = ymm15;
// Fold the newest neighbour into the window sum.
915 vfmadd231ps(ysum, ye, ye);
// NOTE(review): ydst is presumably copied from ysum on an elided line
// before this fused multiply-add.
918 vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
920 vmovaps(ybase, ydst);
// Training: store the base; masked store for the tail channel group.
921 if (pk != prop_kind::forward_inference)
924 vmaskmovps(ptr[scratch], ymask, ybase);
926 vmovups(ptr[scratch], ybase);
// base^0.75 via cube + two sqrts (first vsqrtps on an elided line).
928 vmulps(ydst, ydst, ydst);
929 vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
931 vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
932 vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
// Store the result; masked store for the tail channel group.
935 vmaskmovps(ptr[dst], ymask, ydst);
937 vmovups(ptr[dst], ydst);
// Slide the window: ysum -= ya^2 (negated fused multiply-add).
940 vfnmadd231ps(ysum, ya, ya);
// AVX2 specialization of nchw_tail_sse42. Its body lies on lines elided
// from this chunk -- presumably an empty stub, since the avx2 path never
// takes the sse42 tail; confirm against the full file.
948 void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_tail_sse42(
949 int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi)
// SSE4.2 partial store: writes only `tail` floats of the 8-float result
// (xtail_lo/xtail_hi) to reg_dst, element by element via movss, so the
// last channel group never writes past the buffer end.
// NOTE(review): the per-element lane-shuffle between movss stores and the
// `tail > 4` branch sit on elided lines of this chunk.
953 void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_tail_sse42(
954 int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi)
956 Xbyak::Xmm xmm_tmp = xmm10;
957 movaps(xmm_tmp, xtail_lo);
// Whole low half fits: store it in one shot, continue with the high half.
961 movups(ptr[reg_dst], xtail_lo);
962 movaps(xmm_tmp, xtail_hi);
963 offset += 4 * sizeof(float);
// Store the remaining lanes one float at a time.
966 movss(ptr[reg_dst + offset], xmm_tmp);
967 for (int i = 1; i < tail; i++)
970 movss(ptr[reg_dst + offset + i * sizeof(float)], xmm_tmp);
// SSE4.2 NCHW "across" loop body. Register pressure forces a spill scheme:
// a 12-slot (xmm-sized) area at store_addr holds
//   slot 0/1: alpha, k;  slots 2..11: the lo/hi halves of xa, xb, xc, xd, xe.
// This routine folds xe into the running sum, normalizes and stores the
// result for xc (masked tail via nchw_tail_sse42), subtracts xa^2, and
// finally shifts every spilled neighbour down one position (b->a, c->b,
// d->c, e->d) to slide the window one channel forward.
975 void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body_sse42(
976 int tail, int HW, prop_kind_t pk,
977 Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
978 Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi,
979 Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi)
981 Xbyak::Xmm xdst_lo = xmm0;
982 Xbyak::Xmm xdst_hi = xmm1;
// xbase/xtmp/xa..xd alias each other -- each pair's lifetimes are disjoint.
983 Xbyak::Xmm xbase_lo = xmm6;
984 Xbyak::Xmm xbase_hi = xmm7;
985 Xbyak::Xmm xtmp_lo = xmm8;
986 Xbyak::Xmm xtmp_hi = xmm9;
987 Xbyak::Xmm xa_lo = xmm6;
988 Xbyak::Xmm xa_hi = xmm7;
989 Xbyak::Xmm xb_lo = xmm8;
990 Xbyak::Xmm xb_hi = xmm9;
991 Xbyak::Xmm xc_lo = xmm10;
992 Xbyak::Xmm xc_hi = xmm11;
993 Xbyak::Xmm xd_lo = xmm12;
994 Xbyak::Xmm xd_hi = xmm13;
// Spill the incoming xe (slots 10/11), then fold xe^2 into the sum.
997 movaps(ptr[store_addr + 10 * 4 * sizeof(float)], xe_lo);
998 movaps(ptr[store_addr + 11 * 4 * sizeof(float)], xe_hi);
1000 mulps(xe_lo, xe_lo);
1001 mulps(xe_hi, xe_hi);
1002 addps(xsum_lo, xe_lo);
1003 addps(xsum_hi, xe_hi);
1005 // xdst <- xsum*xalpha+xk
1006 movaps(xdst_lo, xsum_lo);
1007 movaps(xdst_hi, xsum_hi);
1008 mulps(xdst_lo, ptr[store_addr + 0 * 4 * sizeof(float)]);
1009 mulps(xdst_hi, ptr[store_addr + 0 * 4 * sizeof(float)]);
1010 addps(xdst_lo, ptr[store_addr + 1 * 4 * sizeof(float)]);
1011 addps(xdst_hi, ptr[store_addr + 1 * 4 * sizeof(float)]);
// Training: store the base, partially for the tail channel group.
1013 movaps(xbase_lo, xdst_lo);
1014 movaps(xbase_hi, xdst_hi);
1015 if (pk != prop_kind::forward_inference)
1018 nchw_tail_sse42(tail, scratch, xbase_lo, xbase_hi);
1021 movups(ptr[scratch], xbase_lo);
1022 movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
// base^0.75 = sqrt(sqrt(base^3)).
1025 mulps(xdst_lo, xdst_lo);
1026 mulps(xdst_hi, xdst_hi);
1027 mulps(xdst_lo, xbase_lo);
1028 mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
1029 sqrtps(xdst_lo, xdst_lo);
1030 sqrtps(xdst_hi, xdst_hi);
1031 sqrtps(xdst_lo, xdst_lo);
1032 sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
// Numerator xc comes back from its spill slots (6/7).
1033 movaps(xtmp_lo, ptr[store_addr + 6 * 4 * sizeof(float)]);
1034 movaps(xtmp_hi, ptr[store_addr + 7 * 4 * sizeof(float)]);
1035 divps(xtmp_lo, xdst_lo);
1036 divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
1037 movaps(xdst_lo, xtmp_lo);
1038 movaps(xdst_hi, xtmp_hi);
// Store the result, partially for the tail channel group.
1041 nchw_tail_sse42(tail, dst, xdst_lo, xdst_hi);
1044 movups(ptr[dst], xdst_lo);
1045 movups(ptr[dst + 4 * sizeof(float)], xdst_hi);
// Slide the window: xsum -= xa^2 (xa reloaded from slots 2/3).
1048 movaps(xa_lo, ptr[store_addr + 2 * 4 * sizeof(float)]);
1049 movaps(xa_hi, ptr[store_addr + 3 * 4 * sizeof(float)]);
1050 mulps(xa_lo, xa_lo);
1051 mulps(xa_hi, xa_hi);
1052 subps(xsum_lo, xa_lo);
1053 subps(xsum_hi, xa_hi);
// Shift spilled neighbours down one slot pair: b->a, c->b, d->c, e->d.
1056 movaps(xb_lo, ptr[store_addr + 4 * 4 * sizeof(float)]);
1057 movaps(xb_hi, ptr[store_addr + 5 * 4 * sizeof(float)]);
1058 movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xb_lo);
1059 movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xb_hi);
1062 movaps(xc_lo, ptr[store_addr + 6 * 4 * sizeof(float)]);
1063 movaps(xc_hi, ptr[store_addr + 7 * 4 * sizeof(float)]);
1064 movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xc_lo);
1065 movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xc_hi);
1068 movaps(xd_lo, ptr[store_addr + 8 * 4 * sizeof(float)]);
1069 movaps(xd_hi, ptr[store_addr + 9 * 4 * sizeof(float)]);
1070 movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xd_lo);
1071 movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xd_hi);
1074 movaps(xe_lo, ptr[store_addr + 10 * 4 * sizeof(float)]);
1075 movaps(xe_hi, ptr[store_addr + 11 * 4 * sizeof(float)]);
1076 movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xe_lo);
1077 movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xe_hi);
// AVX2 specialization of nchw_body_sse42: intentionally empty -- the avx2
// NCHW path emits through nchw_body instead, and the dispatch keeps both
// specializations defined so the shared constructor template links.
1081 void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body_sse42(
1082 int tail, int HW, prop_kind_t pk,
1083 Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
1084 Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi,
1085 Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) {}
// AVX2 forward LRN kernel, NCHW layout, "across channels" algorithm.
// Channels are HW floats apart, so neighbours are strided loads at
// src + HW*{0,4,8,...} bytes per channel step. J.tail = C % 8 leftover
// rows handled with a vmaskmovps sign-bit mask selected as &mask[7-tail].
1088 jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
1089 struct nchw_across J,
1095 : jit_generator(code_ptr, code_size)
// Sliding sign-bit mask table; &mask[7 - tail] keeps exactly `tail` lanes.
1098 static const uint32_t mask[] = {
1099 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
1100 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0
1102 Xbyak::Reg64 c = r10;
1103 Xbyak::Ymm ymask = ymm2;
1104 Xbyak::Ymm ye = ymm3;
1105 Xbyak::Ymm ya = ymm4;
1106 Xbyak::Ymm yb = ymm5;
1107 Xbyak::Ymm yc = ymm6;
1108 Xbyak::Ymm yd = ymm7;
1109 Xbyak::Ymm ysum = ymm8;
// Load the tail mask once (used by nchw_body for masked load/stores).
1115 mov(imm_addr64, reinterpret_cast<size_t>(&mask[7 - J.tail]));
1116 vmovups(ymask, ptr[imm_addr64]);
// Broadcast alpha and k; load kernel arguments.
1118 mov(imm_addr64, float2int(this->alpha));
1119 movq(xalpha, imm_addr64);
1120 vbroadcastss(yalpha, xalpha);
1122 mov(imm_addr64, float2int(this->k));
1123 movq(xk, imm_addr64);
1124 vbroadcastss(yk, xk);
1126 mov(src, ptr[this->param1 + 0]);
1127 mov(dst, ptr[this->param1 + 8]);
1128 if (pk != prop_kind::forward_inference)
1129 mov(scratch, ptr[this->param1 + 16]);
// Prime the window: load channels c and c+1 (masked when tail != 0).
1134 vmaskmovps(yc, ymask, ptr[src + J.HW * 0]);
1136 vmovups(yc, ptr[src + J.HW * 0]);
1138 vmaskmovps(yd, ymask, ptr[src + J.HW * 4]);
1140 vmovups(yd, ptr[src + J.HW * 4]);
1142 vxorps(ysum, ysum, ysum);
1143 vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
1144 vfmadd231ps(ysum, yd, yd);
// --- main channel loop (label on an elided line): load c+2 = ye,
// then nchw_body folds it in, normalizes, and slides the window. ---
1151 vmaskmovps(ye, ymask, ptr[src + J.HW * 8]);
1153 vmovups(ye, ptr[src + J.HW * 8]);
1155 nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
1159 if (pk != prop_kind::forward_inference)
1160 add(scratch, J.HW * 4);
1163 jne(lrn_loop, T_NEAR);
// Epilogue: the last two channels have no c+1/c+2 neighbour; ye is
// zeroed (on elided lines) and nchw_body runs twice more.
1167 nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
1170 if (pk != prop_kind::forward_inference)
1171 add(scratch, J.HW * 4);
1173 nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
// Publish the generated code as the callable kernel entry point.
1177 ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
// SSE4.2 forward LRN kernel, NCHW layout, "across channels" algorithm.
// Mirrors the AVX2 NCHW kernel with lo/hi xmm halves, AND-masks for the
// channel tail, and a stack spill area (see nchw_body_sse42) holding
// alpha, k, and the five sliding neighbour vectors.
1182 jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
1183 struct nchw_across J,
1189 : jit_generator(code_ptr, code_size)
// Full-dword AND-mask table; &mask[7 - tail] keeps exactly `tail` lanes.
1192 static const uint32_t mask[] = {
1193 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
1194 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0
1197 Xbyak::Reg64 c = r10;
1199 Xbyak::Xmm xmask_lo = xmm2;
1200 Xbyak::Xmm xmask_hi = xmm3;
1201 Xbyak::Xmm xsum_lo = xmm4;
1202 Xbyak::Xmm xsum_hi = xmm5;
1203 Xbyak::Xmm xa_lo = xmm6;
1204 Xbyak::Xmm xa_hi = xmm7;
1205 Xbyak::Xmm xb_lo = xmm8;
1206 Xbyak::Xmm xb_hi = xmm9;
1207 Xbyak::Xmm xc_lo = xmm10;
1208 Xbyak::Xmm xc_hi = xmm11;
1209 Xbyak::Xmm xd_lo = xmm12;
1210 Xbyak::Xmm xd_hi = xmm13;
1211 Xbyak::Xmm xe_lo = xmm14;
1212 Xbyak::Xmm xe_hi = xmm15;
// Kernel arguments: src, dst, scratch (training only).
1216 mov(src, ptr[this->param1 + 0]);
1217 mov(dst, ptr[this->param1 + 8]);
1218 if (pk != prop_kind::forward_inference)
1219 mov(scratch, ptr[this->param1 + 16]);
// Carve an aligned spill area out of the stack.
// NOTE(review): `and_(store_addr, -15)` is an unconventional 16-byte
// align mask (usual form is -16); it works because rsp keeps bit 0 clear.
1221 sub(rsp, stack_space_needed);
1222 mov(store_addr, rsp);
1223 and_(store_addr, -15);
// Broadcast alpha/k across lanes (xk's shufps is on an elided line).
1225 mov(imm_addr64, float2int(this->alpha));
1226 movq(xalpha, imm_addr64);
1227 shufps(xalpha, xalpha, 0);
1229 mov(imm_addr64, float2int(this->k));
1230 movq(xk, imm_addr64);
1233 // put alpha and k into store (free up regs)
1234 movaps(ptr[store_addr + 0 * 4 * sizeof(float)], xalpha);
1235 movaps(ptr[store_addr + 1 * 4 * sizeof(float)], xk);
// Load the tail mask halves once.
1239 mov(imm_addr64, reinterpret_cast<size_t>(&mask[7 - J.tail]));
1240 movups(xmask_lo, ptr[imm_addr64]);
1241 movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
// Channels c-2 / c-1 don't exist at the start: xa, xb begin as zero.
1244 xorps(xa_lo, xa_lo);
1245 xorps(xa_hi, xa_hi);
1246 xorps(xb_lo, xb_lo);
1247 xorps(xb_hi, xb_hi);
// Prime the window: channel c (masked when tail != 0).
1251 movups(xc_lo, ptr[src + J.HW * 0]);
1252 movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]);
1253 andps(xc_lo, xmask_lo);
1254 andps(xc_hi, xmask_hi);
1257 movups(xc_lo, ptr[src + J.HW * 0]);
1258 movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]);
// Channel c+1 (masked when tail != 0).
1261 movups(xd_lo, ptr[src + J.HW * 4]);
1262 movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]);
1263 andps(xd_lo, xmask_lo);
1264 andps(xd_hi, xmask_hi);
1267 movups(xd_lo, ptr[src + J.HW * 4]);
1268 movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]);
1271 // put xa, xb, xc, xd into store to free-up regs
1272 movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xa_lo);
1273 movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xa_hi);
1274 movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xb_lo);
1275 movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xb_hi);
1276 movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xc_lo);
1277 movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xc_hi);
1278 movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xd_lo);
1279 movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xd_hi);
// Initial sum of squares = xc^2 + xd^2 (xa = xb = 0 here).
1281 xorps(xsum_lo, xsum_lo);
1282 xorps(xsum_hi, xsum_hi);
1283 mulps(xc_lo, xc_lo);
1284 mulps(xc_hi, xc_hi);
1285 addps(xsum_lo, xc_lo);
1286 addps(xsum_hi, xc_hi);
1287 mulps(xd_lo, xd_lo);
1288 mulps(xd_hi, xd_hi);
1289 addps(xsum_lo, xd_lo);
1290 addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
// --- main channel loop (label on an elided line): load c+2 = xe
// (masked for the tail), then nchw_body_sse42 does the rest. ---
1297 movups(xe_lo, ptr[src + J.HW * 8]);
1298 movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]);
1299 andps(xe_lo, xmask_lo);
1300 andps(xe_hi, xmask_hi);
1303 movups(xe_lo, ptr[src + J.HW * 8]);
1304 movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]);
1307 nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
1313 if (pk != prop_kind::forward_inference)
1314 add(scratch, J.HW * 4);
1317 jne(lrn_loop, T_NEAR);
// Epilogue: last two channels -- no c+2 neighbour, so xe is zero.
1319 xorps(xe_lo, xe_lo);
1320 xorps(xe_hi, xe_hi);
1322 nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
1327 if (pk != prop_kind::forward_inference)
1328 add(scratch, J.HW * 4);
1330 nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
// Release the spill area before the emitted ret.
1334 add(rsp, stack_space_needed);
// Publish the generated code as the callable kernel entry point.
1338 ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
1342 //////////////////////////////////////////////////////////////////////////////
1344 template <cpu_isa_t isa>
// Backward LRN kernel, nChw8c layout, "across channels". Computes
// diffsrc = diffdst / base^0.75
//           + (-2*alpha*beta) * src/base^1.75 * sum_window(diffdst*src/ws)
// using the workspace (ws = base values saved by the forward pass) and the
// same 64-byte stack staging area as the forward nchw8c_across kernel.
// J.version: -1/-2 = first block (no previous), +1 = last; 3 = single
// block. When use_h_parallelizm is set, one invocation covers a single
// row (W pixels) instead of the whole H*W plane.
1345 jit_uni_lrn_bwd_kernel_f32<isa>::jit_uni_lrn_bwd_kernel_f32(
1346 const struct nchw8c_across &J,
1352 : jit_generator(code_ptr, code_size)
1353 , nalphabeta(-2 * A*B)
1354 , use_h_parallelizm(use_h_parallel)
1356 Xbyak::Reg64 t = rsp;
1357 Xbyak::Reg64 hw = r10;
1359 Xbyak::Xmm xsrc_prev = xmm1;
1360 Xbyak::Xmm xws_prev = xmm2;
1361 Xbyak::Xmm xdiffdst_prev = xmm3;
1362 Xbyak::Ymm ysrc = ymm4;
1363 Xbyak::Ymm yws = ymm5;
1364 Xbyak::Ymm ydiffdst = ymm6;
1365 Xbyak::Xmm xsrc_next = xmm7;
1366 Xbyak::Xmm xws_next = xmm8;
1367 Xbyak::Xmm xdiffdst_next = xmm9;
// ya/xa are the same register viewed as ymm/xmm.
1368 Xbyak::Ymm ya = ymm10;
1369 Xbyak::Xmm xa = xmm10;
1370 Xbyak::Ymm yb = ymm11;
1371 Xbyak::Ymm yd = ymm12;
1372 Xbyak::Ymm ye = ymm13;
1373 Xbyak::Ymm ysum = ymm14;
1374 Xbyak::Ymm ydiffsrc = ymm15;
// Kernel arguments: src, diffdst, workspace, diffsrc.
1378 mov(src, ptr[this->param1 + 0]);
1379 mov(diffdst, ptr[this->param1 + 8]);
1380 mov(workspace, ptr[this->param1 + 16]);
1381 mov(diffsrc, ptr[this->param1 + 24]);
// Broadcast the combined constant -2*alpha*beta.
1384 mov(imm_addr64, float2int(this->nalphabeta));
1385 movq(xnalphabeta, imm_addr64);
1386 vbroadcastss(ynalphabeta, xnalphabeta);
1388 bool is_single = J.version == 3;
// NOTE(review): is_last treats version -2 like +1; with is_first also
// matching -2 this makes -2 behave as "both edges" -- verify the version
// encoding against the full file (a +2 variant may exist on elided lines).
1389 bool is_first = J.version == -1 || J.version == -2;
1390 bool is_last = J.version == +1 || J.version == -2;
// Zero-pad the staging slots that have no neighbouring channel block.
1392 if (is_first || is_single) {
1393 vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
1394 vmovups(ptr[t + 0], xsrc_prev);
1396 if (is_last || is_single) {
1397 vxorps(xsrc_next, xsrc_next, xsrc_next);
1398 vmovups(ptr[t + 48], xsrc_next);
// Pixel count: one row when rows are parallelized, else the whole plane.
1400 mov(hw, this->use_h_parallelizm ? J.W : J.H*J.W);
// --- per-pixel loop body (label on an elided line) ---
// Previous channel block tail: compute diffdst*src/ws^1.75 for it.
// (ws^1.75 = ws^2 * sqrt(sqrt(ws)) -- the sqrt steps are on elided lines.)
1404 if (!is_first && !is_single) {
1405 vmovups(xws_prev, ptr[workspace - J.H*J.W * 32 + 16]);
1406 vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
1407 vmovups(xdiffdst_prev, ptr[diffdst - J.H*J.W * 32 + 16]);
1408 vmulps(xa, xws_prev, xws_prev);
1409 vmulps(xa, xa, xws_prev);
1412 vmulps(xa, xa, xws_prev);
1413 vdivps(xsrc_prev, xsrc_prev, xa);
1414 vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev);
// Current block: ydiffsrc = diffdst/ws^0.75; ysum = diffdst*src/ws^1.75.
1417 vmovups(ysrc, ptr[src]);
1418 vmovups(yws, ptr[workspace]);
1419 vmovups(ydiffdst, ptr[diffdst]);
1420 vmulps(ya, yws, yws);
1421 vmulps(ya, ya, yws);
1424 vdivps(ydiffsrc, ydiffdst, ya);
1425 vdivps(ysum, ydiffsrc, yws);
1426 vmulps(ysum, ysum, ysrc);
// Next channel block head: same per-element term.
1428 if (!is_last && !is_single) {
1429 vmovups(xws_next, ptr[workspace + J.H*J.W * 32]);
1430 vmovups(xsrc_next, ptr[src + J.H*J.W * 32]);
1431 vmovups(xdiffdst_next, ptr[diffdst + J.H*J.W * 32]);
1432 vmulps(xa, xws_next, xws_next);
1433 vmulps(xa, xa, xws_next);
1436 vmulps(xa, xa, xws_next);
1437 vdivps(xsrc_next, xsrc_next, xa);
1438 vdivps(xsrc_next, xsrc_next, xws_next);
1439 vmulps(xdiffdst_next, xdiffdst_next, xsrc_next);
// Stage prev/current/next terms, then sum the 5-wide channel window via
// shifted reloads (same trick as the forward kernel).
1442 if (!is_first && !is_single) vmovups(ptr[t + 0], xdiffdst_prev);
1443 vmovups(ptr[t + 16], ysum);
1444 if (!is_last && !is_single) vmovups(ptr[t + 48], xdiffdst_next);
1446 vmovups(ya, ptr[t + 16 - 8]);
1447 vmovups(yb, ptr[t + 16 - 4]);
1448 vaddps(ysum, ysum, ya);
1449 vmulps(ysrc, ysrc, ynalphabeta); // ysrc *= -2*alpha*beta
1450 vaddps(ysum, ysum, yb);
1452 vmovups(yd, ptr[t + 16 + 4]);
1453 vmovups(ye, ptr[t + 16 + 8]);
1454 vaddps(ysum, ysum, yd);
1455 vaddps(ysum, ysum, ye);
// diffsrc = diffdst/ws^0.75 + (-2*alpha*beta*src) * window_sum.
1457 vfmadd231ps(ydiffsrc, ysum, ysrc);
1459 vmovups(ptr[diffsrc], ydiffsrc);
1468 jne(lrn_loop, T_NEAR);
// Publish the generated code as the callable kernel entry point.
1474 ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
// Explicit instantiations: the forward kernel for both ISAs, the backward
// kernel for avx2 only (no sse42 backward implementation in this file).
1478 template struct jit_uni_lrn_fwd_kernel_f32<sse42>;
1479 template struct jit_uni_lrn_fwd_kernel_f32<avx2>;
1480 template struct jit_uni_lrn_bwd_kernel_f32<avx2>;
1486 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s