1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
19 #include "c_types_map.hpp"
20 #include "jit_avx512_common_lrn.hpp"
21 #include "jit_generator.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "type_helpers.hpp"
29 #define XMM_SIZE (4*sizeof(float))
30 #define ZMM_SIZE (vlen)
31 #define BUFFER_BLOCK (XMM_SIZE + ZMM_SIZE + XMM_SIZE)
32 #define BUFFER_NEXT_OFFSET (XMM_SIZE + ZMM_SIZE)
33 #define SRC_PREV_OFFSET (vlen - XMM_SIZE)
35 #define IRB_LOOP(statement) for(int irb = 0; irb < loop_size; irb++) { \
43 using namespace mkldnn::impl::status;
44 using namespace mkldnn::impl::memory_format;
45 using namespace mkldnn::impl::utils;
47 using namespace Xbyak;
49 enum params { vsize = 16, vlen = 64};
53 float *dst, *ws0, *ws1;
57 const float *src, *diff_dst, *ws0, *ws1;
// Tag describing which 16-channel block a generated kernel handles. The
// visible comment fragments enumerate the "version" codes (listing is
// sampled; some member/comment lines are not visible here).
61 struct nChw16c_across {
64 * 1: channels C-16 .. C-1,
66 * 3: channels only for this kernel(without prev and next)
// H/W: spatial extent; version: block-position code consumed by the kernels.
69 nChw16c_across(int h, int w, int v) : H(h), W(w), version(v) {}
// Forward-pass JIT kernel (Xbyak). AVX-512 code is generated at construction
// time and exposed through the C function pointer `ker`. Only a subset of the
// member declarations is visible in this sampled listing.
72 struct jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_kernel_f32:
73 public jit_generator {
// Scratch GPR used to materialize float constants (alpha, k) before movq.
83 Reg64 imm_addr64 = rbx;
// First ABI argument register: pointer to the jit_args_fwd_t struct.
90 Reg64 param = abi_param1;
109 int use_h_parallelism;
113 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32)
// Generated entry point; invoke via operator().
115 void (*ker)(jit_args_fwd_t *);
116 void operator()(jit_args_fwd_t *arg) { ker(arg); }
// Prefetch distances in vlen units: prf0 used with prefetcht0 (near),
// prf2 with prefetcht2 (far).
119 prf0_offt = 1*FWD_RBC,
120 prf2_offt = 8*FWD_RBC
// Emits the forward LRN body for `loop_size_param` 16-channel vectors.
// Per vector: sum = alpha * (za^2 + zb^2 + zc^2 + zd^2 + ze^2) + k over the
// 5-channel window, dst = src / sum^0.75 (beta fixed at 0.75); in training
// the two workspace planes (scratch0/scratch1) are also written.
123 inline void compute_loop(int loop_size_param)
125 // loop_size - param for IRB_LOOP macro
126 int loop_size = FWD_RBC;
// Register maps: 3 xmm / 7 zmm registers per unrolled row `irb`.
128 auto xreg = [=](int irb, int i) {
129 return Xmm(irb*3 + i);
132 auto zreg = [=](int irb, int i) {
133 return Zmm(irb*7 + i);
// Prefetch src for the previous/current/next channel blocks (t0 near,
// t2 far), plus dst and, when training, both workspace streams.
136 if (!is_first && !is_single) {
137 IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt - HW)*vlen]));
138 IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt - HW)*vlen]));
140 IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(src, (irb + prf0_offt)*vlen)));
141 IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(src, (irb + prf2_offt)*vlen)));
142 if (!is_last && !is_single) {
143 IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt + HW)*vlen]));
144 IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt + HW)*vlen]));
146 if (pk != prop_kind::forward_inference) {
147 IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch0,
148 (irb + prf0_offt)*vlen)));
149 IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch0,
150 (irb + prf2_offt)*vlen)));
152 IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(dst, (irb + prf0_offt)*vlen)));
153 IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(dst, (irb + prf2_offt)*vlen)));
154 if (pk != prop_kind::forward_inference) {
155 IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch1,
156 (irb + prf0_offt) * vlen)));
157 IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch1,
158 (irb + prf2_offt) * vlen)));
161 loop_size = loop_size_param;
// Load the tail 4 channels of the previous block, the 16 current
// channels, and the head 4 channels of the next block.
164 if (!is_first && !is_single) {
165 IRB_LOOP(vmovups(xreg(irb, xsrc_prev),
166 ptr[src + (irb - HW) * vlen + SRC_PREV_OFFSET]));
168 IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src,irb*vlen)));
169 if (!is_last && !is_single) {
170 IRB_LOOP(vmovups(xreg(irb, xsrc_next),
171 ptr[src + (irb + HW) * vlen]));
// Stage prev|cur|next contiguously in the stack buffer `t` so that the
// float-shifted reloads below produce the c-2..c+2 channel windows.
174 if (!is_first && !is_single) {
175 IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK],
176 xreg(irb, xsrc_prev)));
178 IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE),
180 if (!is_last && !is_single) {
181 IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
182 xreg(irb, xsrc_next)));
// za..ze: the same 16 lanes shifted by -2, -1, +1, +2 channels.
185 IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK
186 + XMM_SIZE - 2*sizeof(float))));
187 IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK
188 + XMM_SIZE - sizeof(float))));
189 IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK
190 + XMM_SIZE + sizeof(float))));
191 IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK
192 + XMM_SIZE + 2*sizeof(float))));
// Sum of squares over the window, then zsum = zsum*alpha + k (fma).
195 IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zc), zreg(irb, zc)));
197 IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, za), zreg(irb, za)));
198 IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zb), zreg(irb, zb)));
199 IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zd), zreg(irb, zd)));
200 IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, ze), zreg(irb, ze)));
202 IRB_LOOP(vfmadd132ps(zreg(irb, zsum), zk, zalpha));
// Keep the raw sum as `base`; raise it to 0.75 via sqrt(sqrt(sum^3)).
204 IRB_LOOP(vmovaps(zreg(irb, zbase), zreg(irb, zsum)));
206 IRB_LOOP(vmulps(zreg(irb, zsum2), zreg(irb, zsum), zreg(irb, zsum)));
207 IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zsum), zreg(irb, zsum2)));
209 IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum)));
210 IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum)));
// Training: write ws0 (store operand not visible in this sampled
// listing -- presumably `base`; TODO confirm against full source).
212 if (pk != prop_kind::forward_inference) {
213 IRB_LOOP(vmovups(EVEX_compress_addr(scratch0, irb*vlen),
216 IRB_LOOP(vdivps(zreg(irb, zdst), zreg(irb, zsrc), zreg(irb, zsum)));
217 IRB_LOOP(vmovups(EVEX_compress_addr(dst, irb*vlen), zreg(irb, zdst)));
218 if (pk != prop_kind::forward_inference) {
219 /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */
220 IRB_LOOP(vdivps(zreg(irb, zsum), zreg(irb, zdst), zreg(irb, zbase)));
221 IRB_LOOP(vmovups(EVEX_compress_addr(scratch1, irb*vlen),
// Generates the forward kernel: loads the arg pointers, broadcasts alpha/k,
// zeroes the prev/next edge slots of the staging buffer when no neighbouring
// channel block exists, then unrolls compute_loop over the row in FWD_RBC
// chunks plus a remainder call.
226 jit_avx512_common_lrn_kernel_f32(
227 const struct nChw16c_across &J,
228 prop_kind_t prop_kind,
232 void *code_ptr = nullptr,
233 size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE)
234 : jit_generator(code_ptr, code_size)
236 , use_h_parallelism(use_h_parallel)
// ABI: arg struct lays out { src, dst, ws0, ws1 } (see jit_args_fwd_t).
242 mov(src, ptr[param + 0]);
243 mov(dst, ptr[param + 8]);
244 if (pk != prop_kind::forward_inference)
246 mov(scratch0, ptr[param + 16]);
247 mov(scratch1, ptr[param + 24]);
// NOTE(review): is_last tests version == -2 here while the backward kernel
// tests +2; versions +/-2 are never constructed anywhere visible in this
// file, so this looks harmless but asymmetric -- TODO confirm upstream.
249 is_first = J.version == -1 || J.version == -2;
250 is_last = J.version == +1 || J.version == -2;
251 is_single = J.version == 3;
// LSB = vectors handled per invocation: one row (W) when rows are split
// across threads, otherwise the whole H*W plane.
255 int LSB = use_h_parallelism ? W : HW;
257 sub(t, FWD_RBC*BUFFER_BLOCK);
258 mov(imm_addr64, float2int(this->alpha));
259 movq(xalpha, imm_addr64);
260 vbroadcastss(zalpha, xalpha);
262 mov(imm_addr64, float2int(this->k));
263 movq(xk, imm_addr64);
264 vbroadcastss(zk, xk);
// No previous channel block: zero each row's "prev" xmm slot once.
266 if (is_first || is_single) {
267 vxorps(xmm2, xmm2, xmm2);
268 for(int irb = 0; irb < FWD_RBC; irb++) {
269 vmovups(ptr[t + irb*BUFFER_BLOCK], xmm2);
// No next channel block: zero each row's "next" xmm slot once.
272 if (is_last || is_single) {
273 vxorps(xmm2, xmm2, xmm2);
274 for(int irb = 0; irb < FWD_RBC; irb++) {
275 vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
// Main loop covers LS = LSB - (LSB % FWD_RBC); remainder handled after.
280 int LSREST = LSB % FWD_RBC;
281 int LS = LSB - LSREST;
290 compute_loop(FWD_RBC);
292 add(src, FWD_RBC*vlen);
293 add(dst, FWD_RBC*vlen);
294 if (pk != prop_kind::forward_inference)
296 add(scratch0, FWD_RBC*vlen);
297 add(scratch1, FWD_RBC*vlen);
300 for(int irb = 0; irb < FWD_RBC; irb++)
303 jne(lrn_loop, T_NEAR);
307 compute_loop(LSREST);
309 add(t, FWD_RBC*BUFFER_BLOCK);
// Bind the generated code to the C-callable entry point `ker`.
312 ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
// Forward descriptor check: requires AVX-512, f32, 4D nChw16c data with
// C % 16 == 0, default attributes, and the JIT's fixed hyper-parameters
// (across-channel LRN, local_size == 5, beta == 0.75). Training also
// declares a workspace of shape {MB, C, H, 2*W}.
317 status_t jit_avx512_common_lrn_fwd_t::pd_t::init() {
318 using namespace prop_kind;
319 using namespace alg_kind;
321 assert(engine()->kind() == engine_kind::cpu);
323 if (!mayiuse(avx512_common)) return unimplemented;
325 const memory_desc_wrapper data_d(data_pd_.desc());
327 && one_of(desc()->prop_kind, forward_training, forward_inference)
328 && !has_zero_dim_memory()
329 && everyone_is(data_type::f32, desc()->data_desc.data_type)
330 && data_d.ndims() == 4
331 && data_d.dims()[1] % vsize == 0
332 && attr()->has_default_values();
333 if (!ok) return unimplemented;
335 if (desc()->prop_kind == forward_training) {
// Workspace doubles W: two f32 values (ws0/ws1) per output element.
337 dims_t ws_dims = { MB(), C(), H(), 2*W() };
338 mkldnn_memory_desc_init(&ws_d, 4, ws_dims, data_type::f32,
339 memory_format::nChw16c);
340 ws_pd_ = cpu_memory_t::pd_t(engine_, &ws_d);
343 bool args_ok_across = true
344 && desc()->alg_kind == lrn_across_channels
345 && desc()->local_size == 5
346 && desc()->lrn_beta == 0.75
347 && data_d.format() == nChw16c;
349 return args_ok_across ? success : unimplemented;
// Builds the JIT kernel variants. alpha is pre-divided by local_size (the
// kernel applies it to the raw sum of squares). A single 16-channel block
// uses one "single" kernel (version 3); otherwise a middle kernel plus
// dedicated first/last variants that skip the missing neighbour block.
352 jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_fwd_t(const pd_t *apd,
353 const input_vector &inputs, const output_vector &outputs)
354 : cpu_primitive_t(apd, inputs, outputs)
355 , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr)
356 , ker_last_(nullptr) {
357 using namespace alg_kind;
358 const int C = pd()->C();
359 const int H = pd()->H();
360 const int W = pd()->W();
361 const int ls = pd()->desc()->local_size;
362 const float alpha = pd()->desc()->lrn_alpha / ls;
363 const float k = pd()->desc()->lrn_k;
365 auto pk = pd()->desc()->prop_kind;
// Heuristic: parallelize over rows only for tall enough images.
367 use_h_parallelism = H > 28 ? 1 : 0;
369 if (C / vsize == 1) {
370 ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), pk,
371 use_h_parallelism, alpha, k);
373 ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), pk,
374 use_h_parallelism, alpha, k);
375 ker_first_ = new jit_avx512_common_lrn_kernel_f32(
376 nChw16c_across(H, W, -1), pk, use_h_parallelism, alpha, k);
377 ker_last_ = new jit_avx512_common_lrn_kernel_f32(
378 nChw16c_across(H, W, +1), pk, use_h_parallelism, alpha, k);
382 jit_avx512_common_lrn_fwd_t::~jit_avx512_common_lrn_fwd_t()
383 { delete ker_; delete ker_first_; delete ker_last_; }
// Runs the forward kernels in parallel over (N, C16[, H]) work items,
// computing nChw16c element offsets and dispatching the first/last
// channel-block variants at the c16 boundaries.
385 void jit_avx512_common_lrn_fwd_t::execute_forward() const {
386 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
387 auto dst = reinterpret_cast<data_t*>(this->memory(0));
388 auto ws = reinterpret_cast<data_t*>(this->memory(1));
390 const int N = pd()->MB();
391 const int C = pd()->C();
392 const int H = pd()->H();
393 const int W = pd()->W();
394 const int C16 = C / vsize;
395 const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16;
397 parallel(0, work_amount, [&](const int ithr, const int nthr) {
398 size_t start{0}, end{0};
400 balance211(work_amount, nthr, ithr, start, end);
401 if (use_h_parallelism) {
402 int n{0}, c16{0}, h{0};
403 nd_iterator_init(start, n, N, c16, C16, h, H);
404 for (size_t iwork = start; iwork < end; ++iwork) {
// Per-row offsets; ws is 2*W wide, so ws1 sits one W-row after ws0.
405 auto offset = n*C*H*W + c16*H*W*vsize
407 auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize
409 auto ws_offset1 = ws_offset0 + W*vsize;
412 args.src = &src[offset];
413 args.dst = &dst[offset];
414 args.ws0 = &ws[ws_offset0];
415 args.ws1 = &ws[ws_offset1];
// Boundary channel blocks use the variants lacking a prev/next block.
420 (*ker_first_)(&args);
421 else if (c16 == C16 - 1)
425 nd_iterator_step(n, N, c16, C16, h, H);
// Whole-plane variant: ws1 is a full H*W plane behind ws0.
429 nd_iterator_init(start, n, N, c16, C16);
430 for (size_t iwork = start; iwork < end; ++iwork) {
431 auto offset = n*C*H*W + c16*H*W*vsize;
432 auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize;
433 auto ws_offset1 = ws_offset0 + H*W*vsize;
436 args.src = &src[offset];
437 args.dst = &dst[offset];
438 args.ws0 = &ws[ws_offset0];
439 args.ws1 = &ws[ws_offset1];
444 (*ker_first_)(&args);
445 else if (c16 == C16 - 1)
450 nd_iterator_step(n, N, c16, C16);
// Backward-pass JIT kernel generator; mirrors the forward kernel's structure
// (staging buffer, prev/cur/next channel blocks) to compute diff_src. Only a
// subset of the member declarations is visible in this sampled listing.
456 struct jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_kernel_f32:
457 public jit_generator {
466 Reg64 workspace0 = rdx;
467 Reg64 workspace1 = rsi;
// Scratch GPR used to materialize the nalphabeta constant.
468 Reg64 imm_addr64 = rbx;
470 Zmm znalphabeta = zmm0;
471 Xmm xnalphabeta = xmm0;
// First ABI argument register: pointer to jit_args_bwd_t.
473 Reg64 param = abi_param1;
// xmm register slots (per unrolled row) for edge-channel fragments.
478 int xdiffdst_prev = 2;
486 int xdiffdst_next = 3;
496 int use_h_parallelism;
498 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32)
// Generated entry point filled in by the constructor; call via operator().
500 void (*ker)(jit_args_bwd_t *);
501 void operator()(jit_args_bwd_t *arg) { ker(arg); }
// Prefetch distances in vlen units (t0 near, t2 far).
504 prf0_offt = 1*BWD_RBC,
505 prf2_offt = 8*BWD_RBC
// Emits the backward LRN body: accumulates the five channel-shifted
// (ws1 * diff_dst) terms into zdiffsrc, multiplies by src * nalphabeta,
// adds diff_dst / ws0, and stores diff_src (non-temporally when aligned).
508 inline void compute_loop(int loop_size_param, int prefetchL1,
511 // loop_size - param for IRB_LOOP macro
512 int loop_size = loop_size_param;
// Register maps: 6 xmm / 6 zmm registers per unrolled row `irb`.
514 auto xreg = [=](int irb, int i) {
515 return Xmm(irb*6 + i);
518 auto zreg = [=](int irb, int i) {
519 return Zmm(irb*6 + i);
522 // ---- prefetching -------------------------------------------
523 if (!is_first && !is_single) {
525 IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt
528 IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt
533 IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt)*vlen]));
535 IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt)*vlen]));
538 IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt)*vlen]));
541 IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt)*vlen]));
543 if (!is_last && !is_single) {
545 IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt
548 IRB_LOOP(mic_prefetcht2(ptr[workspace1 + (irb + prf2_offt
552 IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt
555 IRB_LOOP(mic_prefetcht2(ptr[diffdst + (irb + prf2_offt
559 IRB_LOOP(mic_prefetcht0(ptr[workspace0 + (irb + prf0_offt)*vlen]));
561 IRB_LOOP(mic_prefetcht2(ptr[workspace0 + (irb + prf2_offt)*vlen]));
562 // -----------------------------------------------------------
564 if (loop_size_param == 0)
// Previous block's edge fragment: ws1 * diff_dst for its last channels.
// The ws tensor is 2*W wide (see ws_offset0 in execute_backward), so the
// previous block's ws1 sits 2*HW vectors back vs. HW for diff_dst.
567 if (!is_first && !is_single) {
568 IRB_LOOP(vmovups(xreg(irb, xws1_prev), ptr[workspace1 + (irb
569 - 2 * HW) * vlen + SRC_PREV_OFFSET]));
570 IRB_LOOP(vmovups(xreg(irb, xdiffdst_prev), ptr[diffdst + (irb
571 - HW) * vlen + SRC_PREV_OFFSET]));
572 IRB_LOOP(vmulps(xreg(irb, xdiffdst_prev), xreg(irb, xdiffdst_prev),
573 xreg(irb, xws1_prev)));
// Center block: zdiffsrc = ws1 * diff_dst for the 16 current channels.
576 IRB_LOOP(vmovups(zreg(irb, zws1),
577 EVEX_compress_addr(workspace1, irb*vlen)));
578 IRB_LOOP(vmovups(zreg(irb, zdiffdst),
579 EVEX_compress_addr(diffdst, irb*vlen)));
580 IRB_LOOP(vmulps(zreg(irb, zdiffsrc), zreg(irb, zdiffdst),
583 if (!is_last && !is_single) {
584 IRB_LOOP(vmovups(xreg(irb, xws1_next), ptr[workspace1 + (irb
586 IRB_LOOP(vmovups(xreg(irb, xdiffdst_next), ptr[diffdst + (irb
588 IRB_LOOP(vmulps(xreg(irb, xdiffdst_next), xreg(irb, xdiffdst_next),
589 xreg(irb, xws1_next)));
// Stage prev|cur|next in buffer `t`, then reload at +/-1, +/-2 float
// offsets to form the shifted channel windows summed into zdiffsrc.
592 if (!is_first && !is_single) {
593 IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK],
594 xreg(irb, xdiffdst_prev)));
596 IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE),
597 zreg(irb, zdiffsrc)));
598 if (!is_last && !is_single) {
599 IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
600 xreg(irb, xdiffdst_next)));
603 IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK
604 + XMM_SIZE - 2*sizeof(float))));
605 IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK
606 + XMM_SIZE - 1*sizeof(float))));
607 IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK
608 + XMM_SIZE + 1*sizeof(float))));
609 IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK
610 + XMM_SIZE + 2*sizeof(float))));
611 IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
614 IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src, irb*vlen)));
615 IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
617 IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
619 IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
// diff_src = zdiffsrc * (src * nalphabeta) + diff_dst / ws0 (fused fma).
621 IRB_LOOP(vmulps(zreg(irb, zsrc), zreg(irb, zsrc), znalphabeta));
623 IRB_LOOP(vmovups(zreg(irb, zws0),
624 EVEX_compress_addr(workspace0, irb*vlen)));
625 IRB_LOOP(vdivps(zreg(irb, zdiffdst), zreg(irb, zdiffdst),
627 IRB_LOOP(vfmadd213ps(zreg(irb, zdiffsrc), zreg(irb, zsrc),
628 zreg(irb, zdiffdst)));
// Non-temporal store when diffsrc is vlen-aligned, plain store otherwise.
630 Label unaligned_store, end_store;
631 test(diffsrc, vlen - 1);
632 jnz(unaligned_store, T_NEAR);
633 IRB_LOOP(uni_vmovntps(EVEX_compress_addr(diffsrc, irb*vlen),
634 zreg(irb, zdiffsrc)));
635 jmp(end_store, T_NEAR);
636 L(unaligned_store); {
637 IRB_LOOP(uni_vmovups(EVEX_compress_addr(diffsrc, irb*vlen),
638 zreg(irb, zdiffsrc)));
// Generates the backward kernel: loads the five arg pointers, broadcasts the
// precomputed `nalphabeta` constant (its defining expression is not visible
// in this sampled listing -- presumably -2*alpha*beta; TODO confirm), and
// drives compute_loop over the row in BWD_RBC chunks plus a remainder.
643 jit_avx512_common_lrn_kernel_f32(
644 const struct nChw16c_across &J,
648 void *code_ptr = nullptr,
649 size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE)
650 : jit_generator(code_ptr, code_size)
652 , use_h_parallelism(use_h_parallel)
// ABI: arg struct lays out { src, diff_dst, ws0, ws1, diff_src }.
656 mov(src, ptr[param + 0]);
657 mov(diffdst, ptr[param + 8]);
658 mov(workspace0, ptr[param + 16]);
659 mov(workspace1, ptr[param + 24]);
660 mov(diffsrc, ptr[param + 32]);
// LSB = vectors per invocation: one row (W) when rows are thread-split,
// otherwise the whole H*W plane.
664 int LSB = this->use_h_parallelism ? W : HW;
666 sub(t, BWD_RBC*BUFFER_BLOCK);
667 mov(imm_addr64, float2int(this->nalphabeta));
668 movq(xnalphabeta, imm_addr64);
669 vbroadcastss(znalphabeta, xnalphabeta);
671 is_first = J.version == -1 || J.version == -2;
672 is_last = J.version == +1 || J.version == +2;
673 is_single = J.version == 3;
// Zero the prev/next staging slots when no neighbouring block exists.
675 if (is_first || is_single) {
676 vxorps(xmm1, xmm1, xmm1);
677 for(int irb = 0; irb < BWD_RBC; irb++) {
678 vmovups(ptr[t + irb*BUFFER_BLOCK], xmm1);
681 if (is_last || is_single) {
682 vxorps(xmm1, xmm1, xmm1);
683 for(int irb = 0; irb < BWD_RBC; irb++) {
684 vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], xmm1);
// Main loop covers LS = LSB - (LSB % BWD_RBC); remainder handled after.
688 int LSREST = LSB % BWD_RBC;
689 int LS = LSB - LSREST;
698 compute_loop(BWD_RBC, 1, 1);
700 add(src, BWD_RBC*vlen);
701 add(diffsrc, BWD_RBC*vlen);
702 add(diffdst, BWD_RBC*vlen);
703 add(workspace0, BWD_RBC*vlen);
704 add(workspace1, BWD_RBC*vlen);
706 for(int irb = 0; irb < BWD_RBC; irb++)
709 jne(lrn_loop, T_NEAR);
// Remainder; the last flag drops the far prefetch when rows are split.
713 compute_loop(LSREST, 1, this->use_h_parallelism ? 0 : 1);
715 add(t, BWD_RBC*BUFFER_BLOCK);
// Bind the generated code to the C-callable entry point `ker`.
718 ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
// Backward descriptor check: mirrors the forward constraints (AVX-512, f32,
// 4D, C % 16 == 0, across-channel, local_size 5, beta 0.75, nChw16c) and
// additionally requires the forward hint's workspace to match in ndims,
// format, and data type.
724 status_t jit_avx512_common_lrn_bwd_t::pd_t::init() {
725 using namespace prop_kind;
726 using namespace alg_kind;
728 assert(engine()->kind() == engine_kind::cpu);
730 if (!mayiuse(avx512_common)) return unimplemented;
732 const memory_desc_wrapper data_d(data_pd_.desc());
734 && utils::one_of(desc()->prop_kind, backward, backward_data)
735 && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
736 && !has_zero_dim_memory()
737 && data_d.ndims() == 4
738 && data_d.dims()[1] % vsize == 0
739 && attr()->has_default_values();
740 if (!ok) return unimplemented;
// Same doubled-W workspace shape the forward pass produces.
743 dims_t ws_dims = { MB(), C(), H(), 2*W() };
744 mkldnn_memory_desc_init(&ws_d, 4, ws_dims, data_type::f32,
745 memory_format::nChw16c);
746 ws_pd_ = cpu_memory_t::pd_t(engine_, &ws_d);
// The forward pass must have produced a compatible workspace.
748 auto fwd_ws_d_ = hint_fwd_pd_->workspace_pd()->desc();
750 && fwd_ws_d_->ndims == ws_pd_.desc()->ndims
751 && fwd_ws_d_->format == ws_pd_.desc()->format
752 && fwd_ws_d_->data_type == ws_pd_.desc()->data_type;
753 if (!ws_ok) return unimplemented;
755 bool args_ok_across = true
756 && desc()->alg_kind == lrn_across_channels
757 && desc()->local_size == 5
758 && desc()->lrn_beta == 0.75
759 && data_d.format() == nChw16c;
761 return args_ok_across ? success : unimplemented;
// Builds the backward JIT kernels; alpha is pre-divided by local_size and the
// same single/middle/first/last variant selection is used as in the forward
// primitive.
764 jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_bwd_t(const pd_t *apd,
765 const input_vector &inputs, const output_vector &outputs)
766 : cpu_primitive_t(apd, inputs, outputs)
767 , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr)
768 , ker_last_(nullptr) {
769 const int C = pd()->C();
770 const int H = pd()->H();
771 const int W = pd()->W();
772 const int ls = pd()->desc()->local_size;
773 const float alpha = pd()->desc()->lrn_alpha / ls;
774 const float beta = pd()->desc()->lrn_beta;
// Same threshold as forward, so both passes partition work identically.
776 use_h_parallelism = H > 28 ? 1 : 0;
778 if (C / vsize == 1) {
779 ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3),
780 alpha, beta, use_h_parallelism);
782 ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0),
783 alpha, beta, use_h_parallelism);
784 ker_first_ = new jit_avx512_common_lrn_kernel_f32(
785 nChw16c_across(H, W, -1), alpha, beta, use_h_parallelism);
786 ker_last_ = new jit_avx512_common_lrn_kernel_f32(
787 nChw16c_across(H, W, +1), alpha, beta, use_h_parallelism);
791 jit_avx512_common_lrn_bwd_t::~jit_avx512_common_lrn_bwd_t()
792 { delete ker_; delete ker_first_; delete ker_last_; }
794 void jit_avx512_common_lrn_bwd_t::execute_backward() const {
795 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
796 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
797 auto ws = reinterpret_cast<const data_t *>(this->input_memory(2));
798 auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
800 const int N = pd()->MB();
801 const int C = pd()->C();
802 const int H = pd()->H();
803 const int W = pd()->W();
804 const int C16 = C / vsize;
805 const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16;
807 parallel(0, work_amount, [&](const int ithr, const int nthr) {
808 size_t start{0}, end{0};
810 balance211(work_amount, nthr, ithr, start, end);
811 if (use_h_parallelism) {
812 int n{0}, c16{0}, h{0};
813 nd_iterator_init(start, n, N, h, H, c16, C16);
814 for (size_t iwork = start; iwork < end; ++iwork) {
815 auto offset = n*C*H*W + c16*H*W*vsize
817 auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize
819 auto ws_offset1 = ws_offset0 + W*vsize;
822 args.src = &src[offset];
823 args.diff_dst = &diff_dst[offset];
824 args.ws0 = &ws[ws_offset0];
825 args.ws1 = &ws[ws_offset1];
826 args.diff_src = &diff_src[offset];
831 (*ker_first_)(&args);
832 else if (c16 == C16 - 1)
836 nd_iterator_step(n, N, h, H, c16, C16);
840 nd_iterator_init(start, n, N, c16, C16);
841 for (size_t iwork = start; iwork < end; ++iwork) {
842 auto offset = n*C*H*W + c16*H*W*vsize;
843 auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize;
844 auto ws_offset1 = ws_offset0 + H*W*vsize;
847 args.src = &src[offset];
848 args.diff_dst = &diff_dst[offset];
849 args.ws0 = &ws[ws_offset0];
850 args.ws1 = &ws[ws_offset1];
851 args.diff_src = &diff_src[offset];
856 (*ker_first_)(&args);
857 else if (c16 == C16 - 1)
862 nd_iterator_step(n, N, c16, C16);