1 /*******************************************************************************
2 * Copyright 2017-2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <mkldnn_types.h>
18 #include "mkldnn_types.h"
19 #include "mkldnn_thread.hpp"
22 #include "jit_generator.hpp"
24 #include "jit_uni_eltwise.hpp"
25 #include "jit_avx512_core_bf16cvt.hpp"
27 #define GET_OFF(field) offsetof(jit_args, field)
33 using namespace Xbyak;
// Decides which SIMD (Vmm) registers the injected eltwise code is allowed to
// clobber and spills them to the stack so the surrounding JIT kernel's state
// survives.  Registers inside [start_idx, end_idx) hold live kernel data and
// are avoided; any shortfall is taken from the tail starting at start_idx.
// NOTE(review): this listing is a partial dump — the embedded line numbers
// have gaps, so the second signature parameter, some loop/if headers and
// closing braces are not visible here.
35 template <cpu_isa_t isa>
36 void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
38 preserved_vecs_count = 0;
39 vecs_to_preserve = (size_t)aux_vecs_count(alg_);
40 start_idx_tail = start_idx;
42 // For sse42 mask register has to be Xmm(0)
43 if (isa == sse42 && vecs_to_preserve > 0) {
45 assert(idx < start_idx);
46 preserved_vec_idxs[preserved_vecs_count++] = idx;
// Prefer registers outside the caller's live range [start_idx, end_idx).
49 for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) {
50 if (preserved_vecs_count >= vecs_to_preserve) break;
51 if (start_idx <= idx && idx < end_idx) continue;
53 preserved_vec_idxs[preserved_vecs_count++] = idx;
// Not enough free registers: take the remainder from the tail of the
// caller's range; compute_vector_range() later processes that tail separately.
56 size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
57 for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
58 preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++;
61 assert(preserved_vecs_count == vecs_to_preserve);
// Spill the chosen registers onto the stack (restored in injector_postamble).
66 if (preserved_vecs_count)
67 h->sub(h->rsp, preserved_vecs_count * vlen);
69 for (size_t i = 0; i < preserved_vecs_count; ++i)
70 h->uni_vmovups(h->ptr[h->rsp + i * vlen],
71 Vmm(preserved_vec_idxs[i]));
// Re-arranges the preserved-register set before processing the "tail" vectors
// (the caller registers that injector_preamble had to borrow as scratch):
// reloads the borrowed registers from the stack, shifts their scratch indices
// past the tail, and spills the new scratch set.
// NOTE(review): gappy listing — closing braces and some blank/comment lines
// are missing from view.
79 template <cpu_isa_t isa>
80 void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx)
82 size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
83 if (tail_vecs_to_preserve == 0) return;
85 const int idx_off = vecs_to_preserve - tail_vecs_to_preserve;
// Skip past the slots that stay spilled; only the tail slots are touched.
89 h->add(h->rsp, idx_off * vlen);
// Restore the caller's tail registers that were used as scratch.
91 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
92 h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
93 h->ptr[h->rsp + i * vlen]);
// Move the scratch indices beyond the tail range so the tail data is safe.
96 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
97 preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
// Spill the re-assigned scratch registers into the freed stack slots.
100 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
101 h->uni_vmovups(h->ptr[h->rsp + i * vlen],
102 Vmm(preserved_vec_idxs[idx_off + i]));
105 h->sub(h->rsp, idx_off * vlen);
// Restores the SIMD registers spilled by injector_preamble and releases the
// stack space.  No-op when save_state_ is false (caller opted out of
// state preservation).
111 template <cpu_isa_t isa>
112 void jit_uni_eltwise_injector_f32<isa>::injector_postamble() {
113 if (!save_state_) return;
115 for (size_t i = 0; i < preserved_vecs_count; ++i)
116 h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
117 h->ptr[h->rsp + i * vlen]);
119 if (preserved_vecs_count)
120 h->add(h->rsp, preserved_vecs_count * vlen);
// Binds the named scratch registers to the preserved indices chosen by
// injector_preamble.  vmm_mask aliases vmm_aux0 (both use slot 0) — for the
// sse42 path the preamble guarantees slot 0 is Xmm(0), which blendvps needs.
125 template <cpu_isa_t isa>
126 void jit_uni_eltwise_injector_f32<isa>::assign_regs() {
127 vmm_mask = Vmm(preserved_vec_idxs[0]);
128 vmm_aux0 = Vmm(preserved_vec_idxs[0]);
129 vmm_aux1 = Vmm(preserved_vec_idxs[1]);
130 vmm_aux2 = Vmm(preserved_vec_idxs[2]);
131 vmm_aux3 = Vmm(preserved_vec_idxs[3]);
132 vmm_aux4 = Vmm(preserved_vec_idxs[4]);
// In-place exp(x) on vmm_src using the classic range-reduction scheme:
// clamp x to [min_logf, max_logf], split x = fx*ln2 + q with
// fx = floor(x*log2e + 0.5), evaluate a degree-5 polynomial for exp(q)
// (coefficients come from the constant table, see elu_prepare_table),
// then scale by 2^fx built via integer exponent manipulation.
// Clobbers vmm_aux0/vmm_aux1.
135 template <cpu_isa_t isa>
136 void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
// Saturate the argument: table[10] = max logf, table[11] = min logf.
137 h->uni_vminps(vmm_src, vmm_src, table_val(10));
138 h->uni_vmaxps(vmm_src, vmm_src, table_val(11));
139 h->uni_vmovups(vmm_aux0, vmm_src);
141 // fx = x * log2ef + 0.5
142 h->uni_vmulps(vmm_src, vmm_src, table_val(2));
143 h->uni_vaddps(vmm_src, vmm_src, table_val(1));
146 h->uni_vroundps(vmm_aux1, vmm_src, _op_floor);
148 //keep fx for further computations
149 h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx
// Reduced argument: vmm_aux0 = x - fx * ln2  (table[3] = ln2f).
152 h->uni_vfnmadd231ps(vmm_aux1, vmm_aux1, table_val(3));
155 h->uni_vcvtps2dq(vmm_aux1, vmm_src);
156 h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
157 h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx
// Horner evaluation of the exp polynomial, coefficients table[9..5] + table[0].
160 h->uni_vmovups(vmm_src, table_val(9));
162 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8));
164 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7));
166 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6));
168 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0));
170 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5)); //exp(q)
// Final scaling: exp(x) = exp(q) * 2^fx.
172 h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
// In-place leaky ReLU: dst = x > 0 ? x : alpha * x, with a per-ISA blend.
// table[alpha_off] broadcasts alpha, table[zero_off] is 0.
// NOTE(review): gappy listing — the `if (isa == sse42)` branch header before
// the movups/mulps lines and the closing braces are not visible here.
175 template <cpu_isa_t isa>
176 void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(const Vmm &vmm_src)
178 const int alpha_off = 0, zero_off = 1;
180 h->uni_vmovups(vmm_aux1, vmm_src);
// sse42 path: blendvps implicitly uses Xmm(0) as the mask (= vmm_mask here).
182 h->movups(vmm_mask, vmm_src);
183 h->mulps(vmm_src, table_val(alpha_off));
184 h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us);
185 h->blendvps(vmm_src, vmm_aux1);
186 } else if (isa == avx2) {
187 h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
188 h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off));
189 h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
190 } else if (isa == avx512_common) {
// avx512 path: comparison result lives in an opmask register.
191 h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
192 h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us);
193 h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
// Fast path for ReLU with alpha == 0: dst = max(x, 0) in a single instruction
// (table slot 1 holds the zero vector).
197 template <cpu_isa_t isa>
198 void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector(
199 const Vmm &vmm_src) {
200 const int zero_off = 1;
201 h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off));
// In-place ELU: dst = x > 0 ? x : alpha * (exp(x) - 1).
// Saves the original input in vmm_aux2, runs exp_compute_vector on vmm_src,
// then blends the positive lanes back in (per-ISA blend, like relu above).
// NOTE(review): gappy listing — the sse42 branch header and closing braces
// are not visible here.
204 template <cpu_isa_t isa>
205 void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) {
206 const int alpha_off = 23, zero_off = 24;
// Keep the original x for the final positive-lane blend.
209 h->uni_vmovups(vmm_aux2, vmm_src);
210 exp_compute_vector(vmm_src);
212 // alpha * (exp(x) - 1)
213 h->uni_vsubps(vmm_src, vmm_src, table_val(0));
214 h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off));
// sse42 path: mask must be Xmm(0) for blendvps.
218 h->pxor(vmm_mask, vmm_mask);
219 h->cmpps(vmm_mask, vmm_aux2, _cmp_le_os);
220 h->blendvps(vmm_src, vmm_aux2);
221 } else if (isa == avx2) {
222 h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off));
223 h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask);
224 } else if (isa == avx512_common) {
225 h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us);
226 h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2);
// In-place tanh(x) with three accuracy regimes and early exits:
//   |x| <  linear_sat_point (table[13])  -> tanh(x) ~= x
//   |x| <  exp_bound_point  (table[14])  -> odd polynomial (table[18..22])
//   otherwise                            -> 1 - 2/(1 + exp(2x)), via exp
//   |x| >= one_sat_point    (table[15])  -> saturate to 1
// Sign is stripped up front (tanh is odd) and re-applied at the end.
// Scratch usage is documented in the original comments below.
// NOTE(review): gappy listing — lambda closing braces, `else` lines and a few
// statements are missing from view; code kept byte-identical.
230 template <cpu_isa_t isa>
231 void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(const Vmm &vmm_src)
233 // # comes from Taylor expansion error bound
234 // > linear_sat_point = single(sqrt(3) * 1b-12);
235 // # comes from the exp formula cancellation
236 // > exp_bound_point = (single(log(3)/2));
237 // # comes from rounding accuracy in float
238 // > one_sat_point = round(atanh(1 - 1b-25), single, RU);
239 // > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |],
240 // [linear_sat_point, exp_bound_point], relative, floating);
241 // > err_bound = D(sup(supnorm(P, tanh(x),
242 // [linear_sat_point, exp_bound_point], relative, theta)));
243 // 0x1.fffd6f00b9539p-25
245 // x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 *
246 // (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5
247 // + x^0x1p1 * 0x1.09fa1p-6))))
250 // vmm_src contains input
251 // vmm_aux0 contains mask of currently valid results.
252 // 1 is need computation, 0 is already computed
253 // vmm_aux1 contains current output
254 // vmm_aux2, vmm_aux3 contains auxiliary values
255 // vmm_aux4 contains the original sign of inputs
257 Label end_tanh_label;
// Helper: jump to the end label if every lane is below `threshold`
// (i.e. every lane already has its final value in vmm_aux1).
259 auto test_exit =[&](Xbyak::Address threshold){
260 // is not necessary for >AVX, but should not matter on perf
261 h->uni_vmovups(vmm_aux0, vmm_src);
262 if (isa == avx512_common){
263 h->vcmpps(k_mask, vmm_aux0, threshold, 0x5);
264 h->kortestw(k_mask, k_mask);
266 h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold);
267 h->uni_vtestps(vmm_aux0, vmm_aux0);
269 h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR);
// Helper: merge freshly computed lanes into vmm_aux1 under the current mask.
272 auto blend_results=[&](Vmm vmm_partial_res){
273 if (isa == avx512_common)
274 h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res);
276 h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0);
279 // because tanh(x) = -tanh(-x), we extract sign to make x postive
280 // and reapply sign at the end
281 // mov is not necessary for >AVX, but should not matter for performance
282 h->uni_vmovups(vmm_aux4, vmm_src);
283 h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12));
284 h->uni_vandps(vmm_src, vmm_src, table_val(17));
286 // if x < linear_sat_point for all inputs, we just return the input
287 h->uni_vmovups(vmm_aux1, vmm_src);
288 test_exit(table_val(13));
290 // if one of the mask is one, we have to compute an better approx
291 h->uni_vmovups(vmm_aux2, vmm_src);
292 h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2);
293 h->uni_vmovups(vmm_aux3, table_val(22));
294 h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21));
295 h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20));
296 h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19));
297 h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18));
298 h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src);
300 // we blend only the result that need update
301 blend_results(vmm_aux3);
303 // if x < exp_bound_point, we go to return point
304 test_exit(table_val(14));
306 // if not we use a better approx 1 - 2 / (1 + exp(2x))
308 h->uni_vmovups(vmm_aux3, vmm_src);
309 h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3);
312 // We need to save kmask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them
313 // vmm_src is not more read afterwards, so we do not have to save it
314 auto stack_size = 3 * vlen + (isa == avx512_common) * 4;
315 h->sub(h->rsp, stack_size);
316 h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0);
317 h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1);
318 h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src);
319 if (isa == avx512_common)
320 h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask);
322 exp_compute_vector(vmm_aux3);
324 h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]);
325 h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]);
326 h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]);
327 if (isa == avx512_common)
328 h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]);
329 h->add(h->rsp, stack_size);
// vmm_aux3 = 1 + exp(2x)  (table[0] = 1.0f).
332 h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0));
334 // 1 - 2 / (1 + exp(2x))
335 h->uni_vmovups(vmm_aux2, table_val(16));
336 h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3);
337 h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0));
339 // we blend only the result that need update
340 blend_results(vmm_aux2);
342 // finally, we saturate to 1 if needed
343 // TODO: maybe move that up if most inputs saturate in practice
344 if (isa == avx512_common)
345 h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5);
347 h->uni_vmovups(vmm_aux0, vmm_src);
348 h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15));
350 h->uni_vmovups(vmm_aux2, table_val(0));
351 blend_results(vmm_aux2);
353 h->L(end_tanh_label);
355 // we apply the sign of x to the result and we are done
356 h->uni_vmovups(vmm_src, vmm_aux1);
357 h->uni_vpxor(vmm_src, vmm_src, vmm_aux4);
// In-place square: dst = x * x.
361 template <cpu_isa_t isa>
362 void jit_uni_eltwise_injector_f32<isa>::square_compute_vector(
363 const Vmm &vmm_src) {
364 h->uni_vmulps(vmm_src, vmm_src, vmm_src);
// In-place abs: clears the sign bit by AND-ing with table[0] = 0x7fffffff
// (emitted by abs_prepare_table).
367 template <cpu_isa_t isa>
368 void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) {
369 // compute abs(x) = _mm_and_ps(x, 01111..111));
370 h->uni_vandps(vmm_src, vmm_src, table_val(0));
// In-place sqrt that forces non-positive lanes to 0: computes sqrt(x), then
// blends it in only where x > 0 (table[0] = 0 from sqrt_prepare_table),
// avoiding NaN propagation from negative inputs.
// NOTE(review): gappy listing — the `else` line and closing braces are not
// visible here.
373 template <cpu_isa_t isa>
374 void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(const Vmm &vmm_src)
376 if (isa == avx512_common) {
377 h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us);
378 h->uni_vsqrtps(vmm_aux1, vmm_src);
379 h->uni_vmovups(vmm_src, table_val(0));
380 h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
382 h->uni_vmovups(vmm_mask, vmm_src);
383 h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0));
384 h->uni_vsqrtps(vmm_aux1, vmm_src);
385 h->uni_vmovups(vmm_src, table_val(0));
386 h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
// In-place affine transform: dst = alpha * x + beta, with alpha in table[0]
// and beta in table[1] (emitted by linear_prepare_table).
390 template <cpu_isa_t isa>
391 void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector(
392 const Vmm &vmm_src) {
393 // compute x = alpha * x + beta;
394 h->uni_vmovups(vmm_aux0, table_val(0));
395 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1));
// In-place bounded ReLU: dst = min(max(x, 0), alpha); table[0] = alpha,
// table[1] = 0 (emitted by bounded_relu_prepare_table).
398 template <cpu_isa_t isa>
399 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector(
400 const Vmm &vmm_src) {
401 // compute bounded relu */
402 h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
403 h->uni_vminps(vmm_src, vmm_src, table_val(0));
// In-place soft-relu: dst = log(1 + exp(x)), falling back to dst = x for
// large x (x > max logf, table[24]) where log1p(exp(x)) ~= x.
// Pipeline: exp-style range reduction (fx = floor(x*log2e + 0.5)),
// polynomial for exp of the reduced argument (table[17..22]), then a
// mantissa/exponent split and a log polynomial (table[8..16]).
// Constant indices refer to soft_relu_prepare_table.
// NOTE(review): gappy listing — some statements/braces are not visible.
406 template <cpu_isa_t isa>
407 void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
408 const Vmm &vmm_src) {
// Keep the unclamped input for the final large-x blend.
410 h->uni_vmovups(vmm_aux2, vmm_src);
// Clamp to [min logf, max logf] (table[25], table[24]).
412 h->uni_vminps(vmm_src, vmm_src, table_val(24));
413 h->uni_vmaxps(vmm_src, vmm_src, table_val(25));
414 h->uni_vmovups(vmm_aux1, vmm_src);
416 // fx = x * log2ef + 0.5
417 h->uni_vmulps(vmm_src, vmm_src, table_val(2));
418 h->uni_vaddps(vmm_src, vmm_src, table_val(1));
421 h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);
423 // keep fx for further computations
424 h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
425 // calculation fx * ln2
426 h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3));
// Reduced argument: vmm_aux1 = x - fx * ln2.
428 h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
// Horner polynomial for exp of the reduced argument (table[22]..[17]).
430 h->uni_vmovups(vmm_aux3, table_val(22));
432 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21));
434 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20));
436 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19));
438 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0));
440 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17));
// Build 2^-fx as an integer exponent; avx512 negates fx via table[23] = -1.
443 if (isa == avx512_common) {
444 h->vmulps(vmm_aux1, vmm_src, table_val(23));
445 h->vcvtps2dq(vmm_aux1, vmm_aux1);
447 h->uni_vcvtps2dq(vmm_aux1, vmm_src);
448 h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23));
451 h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
452 h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
453 // calculate ln(1 + y)
454 h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
455 // x = y; y is free; keep x for further computations
456 h->uni_vmovups(vmm_src, vmm_aux3);
// Extract the biased exponent bits to recover n in y = 2^n * m.
458 h->uni_vpsrld(vmm_src, vmm_src, 23);
459 h->uni_vcvtdq2ps(vmm_src, vmm_src);
460 // got n. where n is x = 2^n * y. y = 0.5 .. 1
461 h->uni_vsubps(vmm_src, vmm_src, table_val(5));
// Keep only the mantissa and rebias it into [0.5, 1) (table[6]/[7]).
463 h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6));
464 // got y. (mantisa) 0.5 < y < 1
465 h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7));
467 h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0));
// Horner polynomial for ln(1 + y), coefficients table[16]..[8].
469 h->uni_vmovups(vmm_aux1, table_val(16));
471 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15));
473 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14));
475 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13));
477 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12));
479 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11));
481 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10));
483 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9));
484 // y = y * x + p0 ; p0 = 0
485 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8));
486 //calculate ln(2) * n
487 h->uni_vmulps(vmm_src, vmm_src, table_val(3));
488 h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
489 h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);
491 // get vmm_mask = src > max logf
492 h->uni_vmovups(vmm_mask, vmm_aux2);
493 if (isa == avx512_common) {
494 // y = (x < max log f) ? soft_relu(x) : x
495 h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us);
496 h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
498 // y = (x < max log f) ? soft_relu(x) : x
499 h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24));
500 h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
503 h->uni_vmovups(vmm_src, vmm_aux1);
// In-place logistic/sigmoid: forces x negative (sign saved in vmm_aux2),
// computes exp(x)/(exp(x)+1), then uses sigmoid(-x) = 1 - sigmoid(x) to
// restore the result for originally-positive lanes.
// NOTE(review): gappy listing — the `else` line before the sse4.2 blend path
// and closing braces are not visible here.
506 template <cpu_isa_t isa>
507 void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector(
508 const Vmm &vmm_src) {
509 // we store the original sign and make x negative
510 // IMPORTANT: we assume vmm_aux0 to be xmm0, as for sse4.2 path it is required
511 // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it.
// table[12] = 0x80000000 (sign-bit mask).
512 h->uni_vmovups(vmm_aux2, vmm_src);
513 h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12));
514 h->uni_vorps(vmm_src, vmm_src, table_val(12));
516 exp_compute_vector(vmm_src);
518 h->uni_vmovups(vmm_aux1, vmm_src);
// vmm_aux1 = exp(x) + 1  (table[0] = 1.0f).
520 h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0));
521 // y = exp(x) / (exp(x) + 1)
522 h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
524 // Now we have to apply the "symmetry" based on original sign
525 h->uni_vmovups(vmm_aux3, table_val(0));
526 h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src);
527 if (isa == avx512_common) {
528 h->vptestmd(k_mask, vmm_aux2, vmm_aux2);
529 h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src);
531 h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2
532 h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0);
534 h->uni_vmovups(vmm_src, vmm_aux3);
// In-place clamp: dst = min(max(x, beta), alpha); table[0] = alpha,
// table[1] = beta (emitted by clamp_prepare_table).
537 template <cpu_isa_t isa>
538 void jit_uni_eltwise_injector_f32<isa>::clamp_compute_vector(
539 const Vmm &vmm_src) {
541 h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
542 h->uni_vminps(vmm_src, vmm_src, table_val(0));
// Emits the ReLU constant table: one vlen-wide row of broadcast alpha
// (slot 0) followed by a row of zeros (slot 1), read via table_val(i).
545 template <cpu_isa_t isa>
546 void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
547 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
548 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
// Emits the shared constant table for exp/elu/tanh/logistic: each cvals entry
// becomes one vlen-wide broadcast row (slot index = array index, matching the
// table_val(i) indices used in the compute functions above), followed by
// broadcast alpha (slot 23) and zero (slot 24).
551 template <cpu_isa_t isa>
552 void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() {
553 const unsigned int cvals[] = {
554 0x3f800000, // [0] 1.0f
555 0x3f000000, // [1] 0.5f
556 0x3fb8aa3b, // [2] log2ef = 1.44269502f
557 0x3f317218, // [3] ln2f = 0.69314718f
558 0x0000007f, // [4] 0x7f
560 0x3f800001, // [5] p0 = 1.0000001f
561 0x3efffe85, // [6] p2 = 0.4999887f
562 0x3e2aaa3e, // [7] p3 = 0.16666505f
563 0x3d2bb1b1, // [8] p4 = 0.041917507f
564 0x3c091ec1, // [9] p5 = 0.008369149f
565 0x42b0c0a5, //[10] max logf = 88.3762589f
566 0xc1766666, //[11] min logf = -14.5f
567 // tanh(x) constants,
568 0x80000000, //[12] mask to extract sign
569 0x39ddb3d7, //[13] arg below which tanh(x) = x
570 0x3f0c9f54, //[14] arg below which pol approx is valid
571 0x41102cb4, //[15] arg after which tanh(x) = 1
572 0xc0000000, //[16] -2.0f
573 0x7fffffff, //[17] mask to make positive
575 0x3f7fffff, //[18] p0
576 0xbeaaa9cf, //[19] p1
577 0x3e085f1f, //[20] p2
578 0xbd572bda, //[21] p3
579 0x3c84fd08, //[22] p4
582 for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
583 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]);
// Trailing parameter rows: [23] = alpha, [24] = 0 (used by elu blend).
586 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
587 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
// Emits the soft-relu constant table: exp range-reduction constants,
// ln(1+x) polynomial coefficients and the exp polynomial, one vlen-wide
// broadcast row per entry.  Slot indices match soft_relu_compute_vector.
590 template <cpu_isa_t isa>
591 void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() {
592 const unsigned int cvals[] = {
593 0x3f800000, // [0] 1.0f
594 0x3f000000, // [1] 0.5f
595 0x3fb8aa3b, // [2] log2ef = 1.44269502f
596 0x3f317218, // [3] ln2f = 0.69314718f
597 0x0000007f, // [4] 0x7f
598 0x42fc0000, // [5] 126
599 0x807fffff, // [6] and with (to get 0.5 * mantissa)
600 0x3f000000, // [7] or with (to get 0.5 * mantissa)
601 // ln(1 + x) polynomial
602 0xb2b4637d, // [8] p0 = 0.0000000244f
603 0x3f7fff8e, // [9] p1 = 0.9999976971f
604 0xbf001759, //[10] p2 = -0.5002478215f
605 0x3ea70608, //[11] p3 = 0.3272714505f
606 0xbea3d7bf, //[12] p4 = -0.3153830071f
607 0xbe361d04, //[13] p5 = -0.1701777461f
608 0xbfa8f1e6, //[14] p6 = -1.3254635147f
609 0xbfe1e812, //[15] p7 = -1.7971917960f
610 0xbfc4d30e, //[16] p8 = -1.5652673123f
612 0x3f800001, //[17] p0 = 1.0000001f
613 0x3f800000, //[18] p1 = 1.0f
614 0x3efffe85, //[19] p2 = 0.4999887f
615 0x3e2aaa3e, //[20] p3 = 0.16666505f
616 0x3d2bb1b1, //[21] p4 = 0.041917507f
617 0x3c091ec1, //[22] p5 = 0.008369149f
618 0xbf800000, //[23] is required for sign changing
619 0x42b0c0a5, //[24] max logf = 88.3762589f
620 0xc1766666 //[25] min logf = -14.5f
623 for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
624 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
// Emits one vlen-wide row of 0x7fffffff (sign-bit clear mask) for abs.
630 template <cpu_isa_t isa>
631 void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() {
632 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff);
// Emits one vlen-wide row of zeros used by sqrt's x > 0 compare/blend.
635 template <cpu_isa_t isa>
636 void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() {
637 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
// Emits broadcast alpha (slot 0) and beta (slot 1) for the linear op.
640 template <cpu_isa_t isa>
641 void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() {
642 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
643 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
// Emits broadcast alpha (slot 0, the upper bound) and zero (slot 1) for
// bounded relu.
646 template <cpu_isa_t isa>
647 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() {
648 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
649 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
// Emits broadcast alpha (slot 0, upper bound) and beta (slot 1, lower bound)
// for clamp.
652 template <cpu_isa_t isa>
653 void jit_uni_eltwise_injector_f32<isa>::clamp_prepare_table() {
654 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
655 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
// Returns how many auxiliary Vmm scratch registers each eltwise algorithm
// needs (consumed by injector_preamble / assign_regs above).
// NOTE(review): gappy listing — the `switch (alg_)` header line and the
// function's closing lines are not visible here.
658 template <cpu_isa_t isa>
659 int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
661 case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2;
662 case alg_kind::eltwise_elu: return 4;
663 case alg_kind::eltwise_tanh: return 5;
664 case alg_kind::eltwise_square: return 0;
665 case alg_kind::eltwise_abs: return 0;
666 case alg_kind::eltwise_sqrt: return 2;
667 case alg_kind::eltwise_linear: return 1;
668 case alg_kind::eltwise_bounded_relu: return 0;
669 case alg_kind::eltwise_soft_relu: return 4;
670 case alg_kind::eltwise_logistic: return 4;
671 case alg_kind::eltwise_clamp: return 0;
672 case alg_kind::eltwise_exp: return 4;
673 default: assert(!"unsupported eltwise algorithm");
// Applies the configured eltwise algorithm, in place, to every Vmm register
// in [start_idx, end_idx) by dispatching to the per-alg compute function.
// NOTE(review): gappy listing — the switch header and case label for
// eltwise_relu, and the closing braces, are not visible here.
679 template <cpu_isa_t isa>
680 void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
682 using namespace alg_kind;
683 for (size_t idx = start_idx; idx < end_idx; idx++) {
// relu with alpha == 0 uses the cheaper max(x, 0) path.
686 if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx));
687 else relu_compute_vector(Vmm(idx));
689 case eltwise_elu: elu_compute_vector(Vmm(idx)); break;
690 case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break;
691 case eltwise_square: square_compute_vector(Vmm(idx)); break;
692 case eltwise_abs: abs_compute_vector(Vmm(idx)); break;
693 case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break;
694 case eltwise_linear: linear_compute_vector(Vmm(idx)); break;
695 case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break;
696 case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break;
697 case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break;
698 case eltwise_clamp: clamp_compute_vector(Vmm(idx)); break;
699 case eltwise_exp: exp_compute_vector(Vmm(idx)); break;
700 default: assert(!"unsupported eltwise algorithm");
// Public entry: applies the eltwise op to registers [start_idx, end_idx).
// Two-phase processing: first the registers that did not need to be borrowed
// as scratch ([start_idx_tail, end_idx)), then — after re-shuffling the
// scratch set in injector_preamble_tail — the borrowed tail
// ([start_idx, start_idx_tail)).
705 template <cpu_isa_t isa>
706 void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
708 assert(start_idx < end_idx && end_idx <= vecs_count);
710 injector_preamble(start_idx, end_idx);
711 compute_body(start_idx_tail, end_idx);
712 injector_preamble_tail(start_idx);
713 compute_body(start_idx, start_idx_tail);
714 injector_postamble();
// Emits the constant table for the configured algorithm (elu/tanh/logistic/
// exp share elu_prepare_table).  square needs no constants.
// NOTE(review): gappy listing — the gen_table guard, switch header and the
// case labels falling through to elu_prepare_table are only partially
// visible here.
717 template <cpu_isa_t isa>
718 void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) {
719 using namespace alg_kind;
726 case eltwise_relu: relu_prepare_table(); break;
729 case eltwise_logistic:
731 elu_prepare_table(); break;
732 case eltwise_soft_relu: soft_relu_prepare_table(); break;
733 case eltwise_abs: abs_prepare_table(); break;
734 case eltwise_sqrt: sqrt_prepare_table(); break;
735 case eltwise_linear: linear_prepare_table(); break;
736 case eltwise_bounded_relu: bounded_relu_prepare_table(); break;
737 case eltwise_square: break;
738 case eltwise_clamp: clamp_prepare_table(); break;
739 default: assert(!"unsupported eltwise algorithm");
// Explicit instantiations for the three supported ISAs.
744 template struct jit_uni_eltwise_injector_f32<avx512_common>;
745 template struct jit_uni_eltwise_injector_f32<avx2>;
746 template struct jit_uni_eltwise_injector_f32<sse42>;
// NOTE(review): the line below is a field of the jit_args struct whose
// declaration start is not visible in this listing (used via
// GET_OFF(for_comparison) in the relu kernel).
751 const void *for_comparison;
// Abstract base for the JIT eltwise kernels: stores the descriptor and the
// generated code entry point; operator() invokes the generated kernel.
756 struct jit_uni_eltwise_kernel_f32 : public c_compatible {
757 const eltwise_desc_t &desc_;
759 void (*ker_)(const jit_args *);
760 void operator()(const jit_args *args) { assert(ker_); ker_(args); }
762 jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc)
763 : desc_(desc), ker_(nullptr) {}
764 virtual ~jit_uni_eltwise_kernel_f32() {}
767 bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; }
// JIT ReLU kernel (fwd and bwd) with optional bf16 support: loads up to `uf`
// vectors per step (bf16 inputs are widened to f32 via vpermw with an
// interleaving index table), computes x>0 ? x : alpha*x per ISA, and stores
// back (converting to bf16 via vcvtneps2bf16 or the software emulator on
// pre-bf16 hardware).
// NOTE(review): gappy listing — branch headers (is_bf16_/vectorize checks),
// Label declarations, loop labels, Ymm_src/Zmm_src bodies, the idx_table
// label and several closing braces are not visible here; code kept
// byte-identical.
773 template <cpu_isa_t isa>
774 struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32,
777 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32)
// Emits code for one unrolled step: load uf vectors (+ the fwd inputs when
// in bwd mode), compute relu, store uf results.
779 void compute_step(bool vectorize, const int uf, const int shift) {
780 for (int i = 0; i < uf; i++) {
781 auto addr_fwd = ptr[reg_from + i * shift];
782 auto addr_bwd = ptr[reg_for_comparison + i * shift];
// Vectorized bf16 load: widen 16-bit lanes into f32 positions via vpermw.
785 vmovups(Ymm_src(i + 1), addr_fwd);
786 vpermw(Vmm(i + 1) | k_mask_cvt | T_z, zmm_idx, Zmm_src(i + 1));
788 uni_vmovups(Vmm(i + 1), addr_fwd);
792 vmovups(Ymm_src(uf + i + 1), addr_bwd);
793 vpermw(Vmm(uf + i + 1) | k_mask_cvt | T_z,
794 zmm_idx, Zmm_src(uf + i + 1));
796 uni_vmovups(Vmm(uf + i + 1), addr_bwd);
// Scalar (tail) loads: masked 16-bit move for bf16, movss for f32.
801 vmovdqu16(Ymm_src(i + 1) | k_tail_mask, addr_fwd);
802 vpermw(Vmm(i + 1) | k_mask_cvt | T_z, zmm_idx,
805 movss(Xmm(i + 1), addr_fwd);
809 vmovdqu16(Ymm_src(uf + i + 1) | k_tail_mask, addr_bwd);
810 vpermw(Vmm(uf + i + 1) | k_mask_cvt | T_z, zmm_idx,
811 Zmm_src(uf + i + 1));
813 movss(Xmm(uf + i + 1), addr_bwd);
// sse42 compute path: mask register is implicit Xmm(0) for blendvps;
// in bwd mode the comparison uses the saved fwd input (Vmm(uf+i+1)).
820 for (int i = 0; i < uf; i++) {
821 movups(Vmm(2 * uf + i + 1), Vmm(i + 1));
822 mulps(Vmm(2 * uf + i + 1), vmm_ns);
826 movups(mask, Vmm(uf + i + 1));
827 cmpps(mask, vmm_zero, _cmp_nle_us);
829 movups(mask, Vmm(i + 1));
830 cmpps(mask, vmm_zero, _cmp_nle_us);
832 blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1));
// avx2 / avx512 compute paths.
835 for (int i = 0; i < uf; i++) {
836 vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns);
839 vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero);
841 vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero);
843 vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1),
844 Vmm(i + 1), vmm_mask);
848 vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us);
850 vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us);
851 vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1),
// Helper: narrow f32 result to bf16 (hardware or emulated) and do a
// masked 16-bit store.
856 auto store_data =[&] (opmask_t _kmask, int i) {
858 bf16_emu_->r_vcvtneps2bf16(Ymm_src(2 * uf + i + 1),
859 Zmm(2 * uf + i + 1));
861 vcvtneps2bf16(Ymm_src(2 * uf + i + 1), Vmm(2 * uf + i + 1));
862 vmovdqu16(ptr[reg_to + i * shift] | _kmask, Ymm_src(2 * uf + i + 1));
865 for (int i = 0; i < uf; i++) {
868 store_data(k_full_mask, i);
870 uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1));
873 store_data(k_tail_mask, i);
875 movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1));
879 ~jit_uni_relu_kernel_f32() { delete bf16_emu_; }
// Constructor generates the whole kernel: mask setup, parameter loads,
// broadcast of alpha, then a vectorized main loop followed by a scalar
// tail loop (loop_dec/uf/shift tables select per-phase step sizes).
881 jit_uni_relu_kernel_f32(const eltwise_desc_t &desc)
882 : jit_uni_eltwise_kernel_f32(desc)
884 , bf16_emu_(nullptr) {
885 assert(desc.alg_kind == alg_kind::eltwise_relu);
886 assert(isa == sse42 || isa == avx2 || isa == avx512_common);
888 Reg64 param = abi_param1;
// Software bf16 converter is only needed when bf16 data is used on
// hardware without native avx512_core_bf16.
890 is_cpx_ = mayiuse(avx512_core_bf16);
891 is_bf16_ = (desc.data_desc.data_type == data_type::bf16);
892 if (!is_cpx_ && is_bf16_)
893 bf16_emu_ = new bf16_emulation_t(this,
894 bf16_emu_reserv_1, bf16_emu_reserv_2,
895 bf16_emu_reserv_3, bf16_emu_reserv_4,
896 bf16_emu_reserv_5, bf16_emu_reserv_6);
898 const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
899 const int loop_dec[] = {simd_w, 1};
900 const int uf[] = {1, 1};
// Per-element byte width and per-vector byte width depend on bf16.
902 int _shift = (is_bf16_) ? sizeof(mkldnn_bfloat16_t) : sizeof(float);
903 int _vlen = (is_bf16_)
904 ? cpu_isa_traits<isa>::vlen / 2
905 : cpu_isa_traits<isa>::vlen;
907 const int shift[] = {_vlen, _shift};
908 const bool loop_vectorize[] = {true, false};
// Opmask setup: alternating-lane mask for vpermw widening, tail mask,
// and an all-ones store mask.
913 mov(mask_reg, 0xAAAAAAAA);
914 kmovd(k_mask_cvt, mask_reg);
917 kmovd(k_tail_mask, mask_reg);
919 mov(mask_reg, 0xffff);
920 kmovd(k_full_mask, mask_reg);
922 if (!is_cpx_ && is_bf16_)
923 bf16_emu_->init_vcvtneps2bf16();
925 mov(reg_from, ptr[param + GET_OFF(from)]);
927 mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]);
928 mov(reg_to, ptr[param + GET_OFF(to)]);
929 mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
932 mov(p_idx_table, idx_table);
933 vmovups(zmm_idx, ptr[p_idx_table]);
// Broadcast alpha (negative slope) and materialize the zero vector.
936 mov(imm_addr64, float2int(desc.alpha));
937 movq(xmm_ns, imm_addr64);
938 uni_vbroadcastss(vmm_ns, xmm_ns);
940 uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
// id 0 = vectorized loop, id 1 = scalar tail loop.
944 for (int id = 0; id < 2; id++) {
946 cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
947 jle(loop_label[id + 1], T_NEAR);
949 compute_step(loop_vectorize[id], uf[id], shift[id]);
951 add(reg_from, uf[id] * shift[id]);
952 add(reg_to, uf[id] * shift[id]);
954 add(reg_for_comparison, uf[id] * shift[id]);
956 sub(reg_work_amount, uf[id] * loop_dec[id]);
// Data section: interleave index table for the bf16 vpermw widening.
966 const uint16_t _idx[] = { 0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,
967 9,9,10,10,11,11,12,12,13,13,14,14,15,15 };
968 for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
972 ker_ = (decltype(ker_))this->getCode();
976 using Vmm = typename utils::conditional3<isa == sse42, Xmm,
977 isa == avx2, Ymm, Zmm>::type;
978 using opmask_t = const Xbyak::Opmask;
// Register assignments.  NOTE(review): reg_for_comparison aliases reg_from
// in fwd mode (no separate comparison stream is needed).
980 Reg64 reg_from = rax;
981 Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from;
983 Reg64 reg_work_amount = rsi;
984 Reg64 imm_addr64 = rbx;
986 Reg32 mask_reg = r14d;
987 Reg32 reg32_tmp = mask_reg;
989 Reg64 p_idx_table = r13;
991 Xmm xmm_ns = Xmm(14);
993 Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14);
994 Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15);
996 Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12);
997 Opmask k_mask = Opmask(1);
999 inline Ymm Ymm_src(int i) {
1002 inline Zmm Zmm_src(int i) {
1005 Zmm zmm_idx = Zmm(29);
1007 Zmm bf16_emu_reserv_1 = Zmm(24);
1008 Zmm bf16_emu_reserv_2 = Zmm(25);
1009 Zmm bf16_emu_reserv_3 = Zmm(26);
1010 Reg64 bf16_emu_reserv_4 = r14;
1011 Zmm bf16_emu_reserv_5 = Zmm(27);
1012 Zmm bf16_emu_reserv_6 = Zmm(27);
1014 opmask_t k_mask_cvt = k7;
1015 opmask_t k_tail_mask = k6;
1016 opmask_t k_full_mask = k5;
// Owned bf16 emulator (nullptr when native bf16 or f32); freed in dtor.
1023 bf16_emulation_t *bf16_emu_;
// Generic forward eltwise JIT kernel: the activation math itself is emitted
// by jit_uni_eltwise_injector_f32; this class only generates the load /
// compute / store loops. On avx512 it additionally supports bf16 data,
// falling back to software conversion (bf16_emulation_t) when the CPU lacks
// native avx512_core_bf16 support.
1026 template <cpu_isa_t isa>
1027 struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
1028 public jit_generator {
1029 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32)
1031 jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc)
1032 : jit_uni_eltwise_kernel_f32(desc)
1034 , bf16_emu_(nullptr) {
// Native bf16 conversion available? If not, and the data is bf16,
// instantiate the emulation helper (it clobbers the reserv_* registers).
1036 is_cpx_ = mayiuse(avx512_core_bf16);
1037 bool is_bf16_ = (desc.data_desc.data_type == data_type::bf16);
1039 if (!is_cpx_ && is_bf16_)
1040 bf16_emu_ = new bf16_emulation_t(this,
1041 bf16_emu_reserv_1, bf16_emu_reserv_2,
1042 bf16_emu_reserv_3, bf16_emu_reserv_4,
1043 bf16_emu_reserv_5, bf16_emu_reserv_6);
// Injector emits the actual activation; r9 / Opmask(1) are handed to it
// as scratch, so they must not be reused by this kernel's own code.
1045 eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
1046 desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1));
1048 using namespace alg_kind;
// This kernel handles every forward alg except eltwise_relu, which has a
// dedicated kernel (see the fwd primitive constructor below).
1050 assert(is_bwd() == false);
1051 assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
1052 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
1053 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
1054 eltwise_clamp, eltwise_exp));
// k_mask = 0xAAAAAAAA selects every odd 16-bit lane; used with vpermw
// below, presumably to place bf16 words into the high half of each f32
// lane (i.e. a bf16 -> f32 upconvert) — TODO confirm against zmm_idx.
1059 mov(mask_reg, 0xAAAAAAAA);
1060 kmovd(k_mask, mask_reg);
// NOTE(review): the value loaded into mask_reg for the tail mask is set
// between these lines in the full source — verify it matches the tail size.
1063 kmovd(k_tail_mask, mask_reg);
// k_full_mask = 16 lanes, i.e. a whole zmm of f32 elements.
1065 mov(mask_reg, 0xffff);
1066 kmovd(k_full_mask, mask_reg);
1068 if (!is_cpx_ && is_bf16_)
1069 bf16_emu_->init_vcvtneps2bf16();
// Kernel prologue: unpack the jit_args structure passed in abi_param1.
1071 Reg64 param = abi_param1;
1072 mov(reg_from, ptr[param + GET_OFF(from)]);
1073 mov(reg_to, ptr[param + GET_OFF(to)]);
// zmm_idx <- permutation table emitted at label idx_table (see _idx below).
1075 mov(p_idx_table, idx_table);
1076 vmovups(zmm_idx, ptr[p_idx_table]);
1078 mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
1080 eltwise_injector_->load_table_addr();
1082 Label reminder_loop_start, reminder_loop_end;
1083 Label vectorized_loop_start, vectorized_loop_end;
// Main loop processes simd_w elements per iteration; anything smaller
// falls through to the scalar/masked remainder loop.
1085 cmp(reg_work_amount, simd_w);
1086 jl(reminder_loop_start, T_NEAR);
1088 L(vectorized_loop_start);
// store_data: f32 -> bf16 downconvert + masked 16-bit store. Uses the
// emulated r_vcvtneps2bf16 when native avx512_core_bf16 is absent.
1090 auto store_data =[&] (opmask_t _kmask) {
1092 bf16_emu_->r_vcvtneps2bf16(ymm_src, zmm_src_1);
1094 vcvtneps2bf16(ymm_src, vmm_src);
1095 vmovdqu16(ptr[reg_to] | _kmask, ymm_src);
// bf16 path: load packed bf16 words, widen to f32 via vpermw with the
// odd-lane mask (zeroing the low halves via T_z), compute, store back.
1099 vmovups(ymm_src, ptr[reg_from]);
1100 vpermw(vmm_src | k_mask | T_z, zmm_idx, zmm_src);
1101 eltwise_injector_->compute_vector(vmm_src.getIdx());
1102 store_data(k_full_mask);
// f32 path: straight vector load / compute / store.
1104 uni_vmovups(vmm_src, ptr[reg_from]);
1105 eltwise_injector_->compute_vector(vmm_src.getIdx());
1106 uni_vmovups(ptr[reg_to], vmm_src);
// bf16 elements are half the width of f32, hence half the byte stride.
1108 auto shift = (is_bf16_) ? vlen / 2 : vlen;
1109 add(reg_from, shift);
1112 sub(reg_work_amount, simd_w);
1113 cmp(reg_work_amount, simd_w);
1114 jge(vectorized_loop_start, T_NEAR);
1116 L(vectorized_loop_end);
// Remainder loop: one element at a time (masked vector ops for bf16,
// scalar movss for f32) until work_amount reaches zero.
1118 L(reminder_loop_start);
1120 cmp(reg_work_amount, 0);
1121 jle(reminder_loop_end, T_NEAR);
1123 vmovups(ymm_src | k_tail_mask, ptr[reg_from]);
1124 vpermw(vmm_src | k_mask | T_z, zmm_idx, zmm_src);
1125 eltwise_injector_->compute_vector(vmm_src.getIdx());
1126 store_data(k_tail_mask);
1128 movss(xmm_src, ptr[reg_from]);
1129 eltwise_injector_->compute_vector(xmm_src.getIdx());
1130 movss(ptr[reg_to], xmm_src);
1132 auto size_step = (is_bf16_) ? sizeof(mkldnn_bfloat16_t) : sizeof(float);
1133 add(reg_from, size_step);
1134 add(reg_to, size_step);
1136 dec(reg_work_amount);
1137 jmp(reminder_loop_start, T_NEAR);
1139 L(reminder_loop_end);
// Constant data emitted after the code: injector LUTs, then the vpermw
// index table that duplicates each 16-bit lane index (0,0,1,1,...,15,15).
1143 eltwise_injector_->prepare_table();
1148 const uint16_t _idx[] = { 0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,
1149 9,9,10,10,11,11,12,12,13,13,14,14,15,15 };
1150 for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
1154 ker_ = (decltype(ker_))this->getCode();
// NOTE(review): deletion of bf16_emu_ is not visible in this excerpt —
// confirm the destructor frees it as well, otherwise it leaks.
1157 ~jit_uni_kernel_fwd_f32() {
1158 delete eltwise_injector_;
// Vmm is the widest vector register for the ISA: Xmm/Ymm/Zmm.
1163 using Vmm = typename utils::conditional3<isa == sse42, Xmm,
1164 isa == avx2, Ymm, Zmm>::type;
1165 using opmask_t = const Xbyak::Opmask;
// simd_w = number of f32 lanes per vector; vlen = vector width in bytes.
1167 const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
1168 const int vlen = cpu_isa_traits<isa>::vlen;
1170 Reg64 reg_from = rax;
1172 Reg64 reg_work_amount = rsi;
1173 Reg64 imm_addr64 = rbx;
1174 Reg32 mask_reg = edx;
1175 Reg64 reg_idx = r15;
// NOTE(review): reg32_tmp (r14d) aliases bf16_emu_reserv_4 (r14) —
// presumably they are never live at the same time; verify.
1176 Reg32 reg32_tmp = r14d;
1177 Reg64 p_idx_table = r13;
// xmm/vmm/zmm_src_1 all alias register #1 at different widths, as do
// ymm_src/zmm_src at #30.
1179 Xmm xmm_src = Xmm(1);
1180 Vmm vmm_src = Vmm(1);
1181 Zmm zmm_src_1 = Zmm(1);
1183 Ymm ymm_src = Ymm(30);
1184 Zmm zmm_src = Zmm(30);
1185 Zmm zmm_idx = Zmm(31);
// Scratch registers handed to bf16_emulation_t.
1187 Zmm bf16_emu_reserv_1 = Zmm(26);
1188 Zmm bf16_emu_reserv_2 = Zmm(27);
1189 Zmm bf16_emu_reserv_3 = Zmm(28);
1190 Reg64 bf16_emu_reserv_4 = r14;
1191 Zmm bf16_emu_reserv_5 = Zmm(29);
// NOTE(review): reserv_6 aliases reserv_5 (both Zmm(29)) — looks
// intentional only if the emulation never uses both at once; confirm
// against bf16_emulation_t's register usage.
1192 Zmm bf16_emu_reserv_6 = Zmm(29);
1196 opmask_t k_mask = k7;
1197 opmask_t k_tail_mask = k6;
1198 opmask_t k_full_mask = k5;
// Owned helpers, allocated in the constructor.
1202 jit_uni_eltwise_injector_f32<isa> *eltwise_injector_;
1203 bf16_emulation_t *bf16_emu_;
// Forward eltwise descriptor check. Accepts the primitive only when:
// the ISA is available at runtime, prop_kind is a forward flavor, the data
// type matches the template parameter, no dimension is zero, the alg_kind
// is one this implementation generates code for, and the memory is dense
// including padded area — or, if only the unpadded elements are dense,
// the activation must map 0 -> 0 so that padded zeros stay zero
// (the kernels iterate over the padded buffer, see execute_forward).
1208 template <cpu_isa_t isa, data_type_t d_type>
1209 status_t jit_uni_eltwise_fwd_t<isa, d_type>::pd_t::init() {
1210 using namespace alg_kind;
1212 assert(engine()->kind() == engine_kind::cpu);
1213 bool ok = true && mayiuse(isa)
1214 && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
1215 prop_kind::forward_inference)
1216 && desc()->data_desc.data_type == d_type
1217 && !has_zero_dim_memory()
1218 && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh,
1219 eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
1220 eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
1221 eltwise_logistic, eltwise_clamp, eltwise_exp)
1222 && memory_desc_wrapper(src_pd()).is_dense(true)
1223 && IMPLICATION(!memory_desc_wrapper(src_pd()).is_dense(false),
1224 math::eltwise_fwd_preserves_zero(desc()->alg_kind, true))
1225 && attr()->has_default_values();
1227 return ok ? status::success : status::unimplemented;
// Forward primitive constructor: picks the JIT kernel by alg_kind.
// eltwise_relu gets a dedicated specialized kernel; every other supported
// algorithm goes through the generic injector-based kernel.
1230 template <cpu_isa_t isa, data_type_t d_type>
1231 jit_uni_eltwise_fwd_t<isa, d_type>::jit_uni_eltwise_fwd_t(const pd_t *apd,
1232 const input_vector &inputs, const output_vector &outputs)
1233 : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
1234 const auto &desc = *pd()->desc();
1235 switch (desc.alg_kind) {
1236 case alg_kind::eltwise_relu:
1237 kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
// default / fall-through case: generic kernel for all other algs.
1239 kernel_ = new jit_uni_kernel_fwd_f32<isa>(desc);
// Destructor: releases the JIT kernel allocated in the constructor.
1243 template <cpu_isa_t isa, data_type_t d_type>
1244 jit_uni_eltwise_fwd_t<isa, d_type>::~jit_uni_eltwise_fwd_t()
// Runs the forward kernel over the whole (padded) buffer.
// Work is split among threads in cache_line-sized (16-element) chunks so
// no two threads write into the same cache line.
1247 template <cpu_isa_t isa, data_type_t d_type>
1248 void jit_uni_eltwise_fwd_t<isa, d_type>::execute_forward() const {
1249 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
1250 auto dst = reinterpret_cast<data_t *>(this->memory(0));
1252 const memory_desc_wrapper data_d(pd()->src_pd());
// nelems(true): element count *including* padding — valid because pd_t
// requires density / zero-preservation over the padded area.
1254 const size_t nelems = data_d.nelems(true);
1256 src += data_d.blocking_desc().offset_padding;
1257 dst += data_d.blocking_desc().offset_padding;
1259 const int cache_line = 16;
1260 parallel(0, utils::div_up(nelems, cache_line), [&](const int ithr, const int nthr) {
1261 size_t start{0}, end{0};
// Balance in units of chunks, then convert back to element indices,
// clamping the last chunk to nelems.
1262 balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1263 start = nstl::min(nelems, start * cache_line);
1264 end = nstl::min(nelems, end * cache_line);
1266 auto arg = jit_args();
1267 arg.from = (const void*)&src[start];
// for_comparison mirrors 'from' here; presumably only meaningful for
// backward (relu) kernels — see execute_backward.
1268 arg.for_comparison = (const void*)&src[start];
1269 arg.to = (const void*)&dst[start];
1270 arg.work_amount = end - start;
// Invoke the JIT kernel only for non-empty ranges.
1271 if (arg.work_amount)
// Backward eltwise descriptor check. Stricter than forward: only
// eltwise_relu is implemented, src must be dense (no padded-area
// tolerance), and diff_dst must have the exact same layout as src.
1276 template <cpu_isa_t isa, data_type_t d_type>
1277 status_t jit_uni_eltwise_bwd_t<isa, d_type>::pd_t::init() {
1278 assert(engine()->kind() == engine_kind::cpu);
1281 && desc()->prop_kind == prop_kind::backward_data
1282 && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu)
1283 && src_pd()->desc()->data_type == d_type
1284 && !has_zero_dim_memory()
1286 && memory_desc_wrapper(src_pd()).is_dense()
1287 && memory_desc_wrapper(diff_dst_pd()) == memory_desc_wrapper(src_pd())
1288 && attr()->has_default_values();
1290 return ok ? status::success : status::unimplemented;
// Backward primitive constructor. Only eltwise_relu is supported (matching
// pd_t::init above); anything else trips the assert in debug builds.
1293 template <cpu_isa_t isa, data_type_t d_type>
1294 jit_uni_eltwise_bwd_t<isa, d_type>::jit_uni_eltwise_bwd_t(const pd_t *apd,
1295 const input_vector &inputs, const output_vector &outputs)
1296 : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
1297 const auto &desc = *pd()->desc();
1298 switch (desc.alg_kind) {
1299 case alg_kind::eltwise_relu:
1300 kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1301 default: assert(!"unknown eltwise alg_kind");
// Destructor: releases the JIT kernel allocated in the constructor.
1305 template <cpu_isa_t isa, data_type_t d_type>
1306 jit_uni_eltwise_bwd_t<isa, d_type>::~jit_uni_eltwise_bwd_t()
// Runs the backward (relu) kernel: diff_src = f'(src) * diff_dst.
// Same chunked parallelization scheme as execute_forward, but over the
// unpadded element count (src is required to be dense by pd_t::init).
1309 template <cpu_isa_t isa, data_type_t d_type>
1310 void jit_uni_eltwise_bwd_t<isa, d_type>::execute_backward() const {
1311 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
1312 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
1313 auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
1315 const memory_desc_wrapper data_d(pd()->src_pd());
1316 const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
// nelems() without argument: unpadded count (contrast with forward).
1318 const size_t nelems = data_d.nelems();
1320 src += data_d.blocking_desc().offset_padding;
1321 diff_dst += diff_data_d.blocking_desc().offset_padding;
1322 diff_src += diff_data_d.blocking_desc().offset_padding;
1324 const int cache_line = 16;
1326 parallel(0, utils::div_up(nelems, cache_line), [&](const int ithr, const int nthr) {
1327 size_t start{0}, end{0};
1329 balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1330 start = nstl::min(nelems, start * cache_line);
1331 end = nstl::min(nelems, end * cache_line);
1333 auto arg = jit_args();
// Kernel reads the incoming gradient from 'from' and the original
// forward input from 'for_comparison' (needed for relu's derivative).
1334 arg.from = (const void*)&diff_dst[start];
1335 arg.to = (const void*)&diff_src[start];
1336 arg.for_comparison = (const void*)&src[start];
1337 arg.work_amount = end - start;
1338 if (arg.work_amount) {
// Explicit template instantiations: f32 forward/backward for every
// supported ISA, plus bf16 variants on avx512_common (the bf16 kernels
// check for native avx512_core_bf16 at runtime and emulate otherwise).
1344 template struct jit_uni_eltwise_fwd_t<sse42, data_type::f32>;
1345 template struct jit_uni_eltwise_bwd_t<sse42, data_type::f32>;
1346 template struct jit_uni_eltwise_fwd_t<avx2, data_type::f32>;
1347 template struct jit_uni_eltwise_bwd_t<avx2, data_type::f32>;
1348 template struct jit_uni_eltwise_fwd_t<avx512_common, data_type::f32>;
1349 template struct jit_uni_eltwise_fwd_t<avx512_common, data_type::bf16>;
1350 template struct jit_uni_eltwise_bwd_t<avx512_common, data_type::f32>;
1351 template struct jit_uni_eltwise_bwd_t<avx512_common, data_type::bf16>;