1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <mkldnn_types.h>
18 #include "mkldnn_types.h"
19 #include "mkldnn_thread.hpp"
22 #include "jit_generator.hpp"
24 #include "jit_uni_eltwise.hpp"
26 #define GET_OFF(field) offsetof(jit_args, field)
32 using namespace Xbyak;
// Selects which vector registers the injector may clobber and spills them to
// the stack. Registers are first taken from outside the caller's in-flight
// range [start_idx, end_idx); any shortfall is borrowed from the start of the
// caller's own range (those become the "tail" later swapped by
// injector_preamble_tail()).
// NOTE(review): interior statements are elided in this excerpt (gaps in the
// embedded numbering) — comments describe only the visible lines.
34 template <cpu_isa_t isa>
35 void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
37 preserved_vecs_count = 0;
38 vecs_to_preserve = (size_t)aux_vecs_count(alg_);
39 start_idx_tail = start_idx;
// SSE4.2 blendvps implicitly uses xmm0 as its mask, so Xmm(0) must be in the
// preserved set on that ISA.
41 // For sse42 mask register has to be Xmm(0)
42 if (isa == sse42 && vecs_to_preserve > 0) {
44 assert(idx < start_idx);
45 preserved_vec_idxs[preserved_vecs_count++] = idx;
// Prefer registers not used by the caller's current vector range.
48 for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) {
49 if (preserved_vecs_count >= vecs_to_preserve) break;
50 if (start_idx <= idx && idx < end_idx) continue;
52 preserved_vec_idxs[preserved_vecs_count++] = idx;
// Not enough free registers: borrow from the head of the caller's range.
55 size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
56 for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
57 preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++;
60 assert(preserved_vecs_count == vecs_to_preserve);
// Spill the chosen registers below rsp (red-zone-free explicit allocation).
65 if (preserved_vecs_count)
66 h->sub(h->rsp, preserved_vecs_count * vlen);
68 for (size_t i = 0; i < preserved_vecs_count; ++i)
69 h->uni_vmovups(h->ptr[h->rsp + i * vlen],
70 Vmm(preserved_vec_idxs[i]));
// Called after the first compute_body() pass: the registers that were borrowed
// from the caller's range ("tail") have now been processed, so restore their
// original contents, rebase their preserved indices past the tail, and spill
// the replacement registers so the tail range itself can be computed next.
78 template <cpu_isa_t isa>
79 void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx)
81 size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
82 if (tail_vecs_to_preserve == 0) return;
84 const int idx_off = vecs_to_preserve - tail_vecs_to_preserve;
// Temporarily point rsp at the tail's save area.
88 h->add(h->rsp, idx_off * vlen);
// Reload the caller's borrowed registers from the stack.
90 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
91 h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
92 h->ptr[h->rsp + i * vlen]);
// Shift the scratch assignments to registers just past the tail.
95 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
96 preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
// Spill the newly selected scratch registers into the same slots.
99 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
100 h->uni_vmovups(h->ptr[h->rsp + i * vlen],
101 Vmm(preserved_vec_idxs[idx_off + i]));
// Restore rsp to cover the full save area again.
104 h->sub(h->rsp, idx_off * vlen);
// Restores all preserved vector registers from the stack and releases the
// stack space reserved by injector_preamble(). No-op when state saving was
// disabled at construction (save_state_ == false).
110 template <cpu_isa_t isa>
111 void jit_uni_eltwise_injector_f32<isa>::injector_postamble() {
112 if (!save_state_) return;
114 for (size_t i = 0; i < preserved_vecs_count; ++i)
115 h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
116 h->ptr[h->rsp + i * vlen]);
118 if (preserved_vecs_count)
119 h->add(h->rsp, preserved_vecs_count * vlen);
// Binds the injector's named scratch registers to the preserved indices.
// Note vmm_mask and vmm_aux0 alias the same register (preserved_vec_idxs[0]);
// on sse42 that index is forced to be xmm0 by injector_preamble().
124 template <cpu_isa_t isa>
125 void jit_uni_eltwise_injector_f32<isa>::assign_regs() {
126 vmm_mask = Vmm(preserved_vec_idxs[0]);
127 vmm_aux0 = Vmm(preserved_vec_idxs[0]);
128 vmm_aux1 = Vmm(preserved_vec_idxs[1]);
129 vmm_aux2 = Vmm(preserved_vec_idxs[2]);
130 vmm_aux3 = Vmm(preserved_vec_idxs[3]);
131 vmm_aux4 = Vmm(preserved_vec_idxs[4]);
// In-place vectorized exp(x) on vmm_src.
// Classic range-reduction scheme: clamp x, split off n = round(x*log2e),
// evaluate a degree-5 polynomial on the reduced argument, then scale by 2^n
// via integer exponent manipulation. Table indices refer to
// elu_prepare_table() constants ([10]/[11] = max/min logf, [2] = log2ef, ...).
134 template <cpu_isa_t isa>
135 void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
// Saturate the argument to [min logf, max logf] to avoid overflow/underflow.
136 h->uni_vminps(vmm_src, vmm_src, table_val(10));
137 h->uni_vmaxps(vmm_src, vmm_src, table_val(11));
138 h->uni_vmovups(vmm_aux0, vmm_src);
140 // fx = x * log2ef + 0.5
141 h->uni_vmulps(vmm_src, vmm_src, table_val(2));
142 h->uni_vaddps(vmm_src, vmm_src, table_val(1));
// Floor(fx): avx512 has no uni_vroundps path here, so emulate floor with
// round-down conversion (T_rd_sae) plus a mask-based -1 compensation.
145 if (isa == avx512_common) {
146 h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src);
147 h->vcvtdq2ps(vmm_aux1, vmm_aux1);
149 h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us);
150 h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));
152 h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3);
154 h->uni_vroundps(vmm_aux1, vmm_src, _op_floor);
157 //keep fx for further computations
158 h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx
// Reduced argument: q = x - fx * ln2 (kept in vmm_aux0).
161 h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3));
// Build 2^fx: bias the integer exponent ([4] = 0x7f) and shift into place.
164 h->uni_vcvtps2dq(vmm_aux1, vmm_src);
165 h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
166 h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx
// Horner evaluation of the exp polynomial (coefficients [5]..[9], 1.0 at [0]).
169 h->uni_vmovups(vmm_src, table_val(9));
171 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8));
173 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7));
175 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6));
177 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0));
179 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5)); //exp(q)
// Final scaling: exp(x) = exp(q) * 2^fx.
181 h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
// In-place leaky ReLU: dst = x > 0 ? x : alpha * x.
// alpha lives at table offset 0, the zero constant at offset 1
// (see relu_prepare_table()). Three ISA-specific blend strategies follow.
184 template <cpu_isa_t isa>
185 void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(const Vmm &vmm_src)
187 const int alpha_off = 0, zero_off = 1;
189 h->uni_vmovups(vmm_aux1, vmm_src);
// sse42: blendvps uses the implicit xmm0 mask (vmm_mask is xmm0 here).
191 h->movups(vmm_mask, vmm_src);
192 h->mulps(vmm_src, table_val(alpha_off));
193 h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us);
194 h->blendvps(vmm_src, vmm_aux1);
195 } else if (isa == avx2) {
196 h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
197 h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off));
198 h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
199 } else if (isa == avx512_common) {
// avx512: compare into an opmask and use masked blend.
200 h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
201 h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us);
202 h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
// ReLU with zero negative slope (alpha == 0): dst = max(x, 0).
// A single max against the zero constant suffices — no mask/blend needed.
206 template <cpu_isa_t isa>
207 void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector(
208 const Vmm &vmm_src) {
209 const int zero_off = 1;
210 h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off));
// In-place ELU: dst = x > 0 ? x : alpha * (exp(x) - 1).
// alpha/zero sit at table offsets 23/24, after the shared exp/tanh constants
// laid down by elu_prepare_table().
213 template <cpu_isa_t isa>
214 void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) {
215 const int alpha_off = 23, zero_off = 24;
// Keep the original input; exp_compute_vector() destroys vmm_src and
// vmm_aux0/vmm_aux1 but leaves vmm_aux2 untouched.
218 h->uni_vmovups(vmm_aux2, vmm_src);
219 exp_compute_vector(vmm_src);
221 // alpha * (exp(x) - 1)
222 h->uni_vsubps(vmm_src, vmm_src, table_val(0));
223 h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off));
// Select original x for positive lanes, ELU branch otherwise.
227 h->pxor(vmm_mask, vmm_mask);
228 h->cmpps(vmm_mask, vmm_aux2, _cmp_le_os);
229 h->blendvps(vmm_src, vmm_aux2);
230 } else if (isa == avx2) {
231 h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off));
232 h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask);
233 } else if (isa == avx512_common) {
234 h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us);
235 h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2);
// In-place tanh(x) with three accuracy regimes and early exits:
//   |x| < [13]  : tanh(x) ~= x (identity)
//   |x| < [14]  : odd polynomial in x^2 (coefficients [18]..[22])
//   otherwise   : 1 - 2/(1 + exp(2x)), saturating to 1 past [15]
// Sign is stripped first and re-applied at the end (tanh is odd).
239 template <cpu_isa_t isa>
240 void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(const Vmm &vmm_src)
242 // # comes from Taylor expansion error bound
243 // > linear_sat_point = single(sqrt(3) * 1b-12);
244 // # comes from the exp formula cancellation
245 // > exp_bound_point = (single(log(3)/2));
246 // # comes from rounding accuracy in float
247 // > one_sat_point = round(atanh(1 - 1b-25), single, RU);
248 // > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |],
249 // [linear_sat_point, exp_bound_point], relative, floating);
250 // > err_bound = D(sup(supnorm(P, tanh(x),
251 // [linear_sat_point, exp_bound_point], relative, theta)));
252 // 0x1.fffd6f00b9539p-25
254 // x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 *
255 // (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5
256 // + x^0x1p1 * 0x1.09fa1p-6))))
259 // vmm_src contains input
260 // vmm_aux0 contains mask of currently valid results.
261 // 1 is need computation, 0 is already computed
262 // vmm_aux1 contains current output
263 // vmm_aux2, vmm_aux3 contains auxiliary values
264 // vmm_aux4 contains the original sign of inputs
266 Label end_tanh_label;
// Emits "jump to end if every lane is below `threshold`" — i.e. all lanes
// already have their final value in vmm_aux1.
268 auto test_exit =[&](Xbyak::Address threshold){
269 // is not necessary for >AVX, but should not matter on perf
270 h->uni_vmovups(vmm_aux0, vmm_src);
271 if (isa == avx512_common){
272 h->vcmpps(k_mask, vmm_aux0, threshold, 0x5);
273 h->kortestw(k_mask, k_mask);
275 h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold);
276 h->uni_vtestps(vmm_aux0, vmm_aux0);
278 h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR);
// Merges a partial result into vmm_aux1 for the lanes flagged by the
// current mask (k_mask on avx512, vmm_aux0 elsewhere).
281 auto blend_results=[&](Vmm vmm_partial_res){
282 if (isa == avx512_common)
283 h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res);
285 h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0);
288 // because tanh(x) = -tanh(-x), we extract sign to make x postive
289 // and reapply sign at the end
290 // mov is not necessary for >AVX, but should not matter for performance
291 h->uni_vmovups(vmm_aux4, vmm_src);
292 h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12));
293 h->uni_vandps(vmm_src, vmm_src, table_val(17));
295 // if x < linear_sat_point for all inputs, we just return the input
296 h->uni_vmovups(vmm_aux1, vmm_src);
297 test_exit(table_val(13));
299 // if one of the mask is one, we have to compute an better approx
300 h->uni_vmovups(vmm_aux2, vmm_src);
301 h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2);
// Horner evaluation in x^2, then multiply by x (odd polynomial).
302 h->uni_vmovups(vmm_aux3, table_val(22));
303 h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21));
304 h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20));
305 h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19));
306 h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18));
307 h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src);
309 // we blend only the result that need update
310 blend_results(vmm_aux3);
312 // if x < exp_bound_point, we go to return point
313 test_exit(table_val(14));
315 // if not we use a better approx 1 - 2 / (1 + exp(2x))
317 h->uni_vmovups(vmm_aux3, vmm_src);
318 h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3);
321 // We need to save kmask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them
322 // vmm_src is not more read afterwards, so we do not have to save it
// Extra 4 bytes on avx512 are for the 16-bit opmask save (kmovw).
323 auto stack_size = 3 * vlen + (isa == avx512_common) * 4;
324 h->sub(h->rsp, stack_size);
325 h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0);
326 h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1);
327 h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src);
328 if (isa == avx512_common)
329 h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask);
331 exp_compute_vector(vmm_aux3);
333 h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]);
334 h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]);
335 h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]);
336 if (isa == avx512_common)
337 h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]);
338 h->add(h->rsp, stack_size);
// 1 + exp(2x)
341 h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0));
343 // 1 - 2 / (1 + exp(2x))  ([16] = -2.0f, [0] = 1.0f)
344 h->uni_vmovups(vmm_aux2, table_val(16));
345 h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3);
346 h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0));
348 // we blend only the result that need update
349 blend_results(vmm_aux2);
351 // finally, we saturate to 1 if needed
352 // TODO: maybe move that up if most inputs saturate in practice
353 if (isa == avx512_common)
354 h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5);
356 h->uni_vmovups(vmm_aux0, vmm_src);
357 h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15));
359 h->uni_vmovups(vmm_aux2, table_val(0));
360 blend_results(vmm_aux2);
362 h->L(end_tanh_label);
364 // we apply the sign of x to the result and we are done
365 h->uni_vmovups(vmm_src, vmm_aux1);
366 h->uni_vpxor(vmm_src, vmm_src, vmm_aux4);
// In-place square: dst = x * x. Needs no table constants or scratch regs.
370 template <cpu_isa_t isa>
371 void jit_uni_eltwise_injector_f32<isa>::square_compute_vector(
372 const Vmm &vmm_src) {
373 h->uni_vmulps(vmm_src, vmm_src, vmm_src);
// In-place abs: clears the IEEE-754 sign bit by masking with 0x7fffffff
// (the single constant written by abs_prepare_table()).
376 template <cpu_isa_t isa>
377 void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) {
378 // compute abs(x) = _mm_and_ps(x, 01111..111));
379 h->uni_vandps(vmm_src, vmm_src, table_val(0));
// In-place sqrt with negative-input masking: dst = x > 0 ? sqrt(x) : 0.
// The blend avoids propagating the NaN that sqrtps yields for x < 0.
382 template <cpu_isa_t isa>
383 void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(const Vmm &vmm_src)
385 if (isa == avx512_common) {
386 h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us);
387 h->uni_vsqrtps(vmm_aux1, vmm_src);
388 h->uni_vmovups(vmm_src, table_val(0));
389 h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
391 h->uni_vmovups(vmm_mask, vmm_src);
392 h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0));
393 h->uni_vsqrtps(vmm_aux1, vmm_src);
394 h->uni_vmovups(vmm_src, table_val(0));
395 h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
// In-place affine transform: dst = alpha * x + beta, done as a single FMA.
// Table: [0] = alpha, [1] = beta (see linear_prepare_table()).
399 template <cpu_isa_t isa>
400 void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector(
401 const Vmm &vmm_src) {
402 // compute x = alpha * x + beta;
403 h->uni_vmovups(vmm_aux0, table_val(0));
404 h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1));
// In-place bounded ReLU: dst = min(max(x, 0), alpha).
// Table: [0] = alpha, [1] = 0 (see bounded_relu_prepare_table()).
407 template <cpu_isa_t isa>
408 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector(
409 const Vmm &vmm_src) {
410 // compute bounded relu */
411 h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
412 h->uni_vminps(vmm_src, vmm_src, table_val(0));
// In-place soft ReLU: dst = log(1 + exp(x)), falling back to dst = x for
// x > max logf where exp would overflow (final blend at the bottom).
// Structure: inline exp-style range reduction, a ln(1+y) polynomial
// ([8]..[16]), exponent extraction via bit twiddling, then recombination.
// Table offsets refer to soft_relu_prepare_table().
415 template <cpu_isa_t isa>
416 void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
417 const Vmm &vmm_src) {
// Keep the raw input for the overflow blend at the end.
419 h->uni_vmovups(vmm_aux2, vmm_src);
// Clamp to [min logf, max logf] ([25]/[24]).
421 h->uni_vminps(vmm_src, vmm_src, table_val(24));
422 h->uni_vmaxps(vmm_src, vmm_src, table_val(25));
423 h->uni_vmovups(vmm_aux1, vmm_src);
425 // fx = x * log2ef + 0.5
426 h->uni_vmulps(vmm_src, vmm_src, table_val(2));
427 h->uni_vaddps(vmm_src, vmm_src, table_val(1));
// Floor(fx): avx512 emulates floor via round-down convert + compensation,
// same trick as in exp_compute_vector().
430 if (isa == avx512_common) {
431 h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src);
432 h->vcvtdq2ps(vmm_aux0, vmm_aux0);
434 h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us);
435 h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));
437 h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3);
439 h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);
442 // keep fx for further computations
443 h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
444 // calculation fx * ln2
445 h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3));
// Reduced argument: q = x - fx*ln2 in vmm_aux1.
447 h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
// exp polynomial (coefficients [17]..[22]) evaluated by Horner's rule.
449 h->uni_vmovups(vmm_aux3, table_val(22));
451 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21));
453 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20));
455 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19));
457 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0));
459 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17));
// Build 2^(-fx) as integer exponent bits; avx512 negates via multiply by
// -1.0f ([23]) before conversion, other ISAs use vpsignd.
462 if (isa == avx512_common) {
463 h->vmulps(vmm_aux1, vmm_src, table_val(23));
464 h->vcvtps2dq(vmm_aux1, vmm_aux1);
466 h->uni_vcvtps2dq(vmm_aux1, vmm_src);
467 h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23));
470 h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
471 h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
472 // calculate ln(1 + y)
473 h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
474 // x = y; y is free; keep x for further computations
475 h->uni_vmovups(vmm_src, vmm_aux3);
// Extract the biased exponent field (shift right past the mantissa).
477 h->uni_vpsrld(vmm_src, vmm_src, 23);
478 h->uni_vcvtdq2ps(vmm_src, vmm_src);
479 // got n. where n is x = 2^n * y. y = 0.5 .. 1
480 h->uni_vsubps(vmm_src, vmm_src, table_val(5));
// Isolate the mantissa ([6]) and re-bias it into [0.5, 1) with [7].
482 h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6));
483 // got y. (mantisa) 0.5 < y < 1
484 h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7));
// y - 1 is the argument of the ln(1 + x) polynomial below.
486 h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0));
// ln(1 + x) polynomial, coefficients [8]..[16], Horner evaluation.
488 h->uni_vmovups(vmm_aux1, table_val(16));
490 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15));
492 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14));
494 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13));
496 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12));
498 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11));
500 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10));
502 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9));
503 // y = y * x + p0 ; p0 = 0
504 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8));
505 //calculate ln(2) * n
506 h->uni_vmulps(vmm_src, vmm_src, table_val(3));
507 h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
508 h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);
510 // get vmm_mask = src > max logf
511 h->uni_vmovups(vmm_mask, vmm_aux2);
512 if (isa == avx512_common) {
513 // y = (x < max log f) ? soft_relu(x) : x
514 h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us);
515 h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
517 // y = (x < max log f) ? soft_relu(x) : x
518 h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24));
519 h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
522 h->uni_vmovups(vmm_src, vmm_aux1);
// In-place logistic (sigmoid): dst = 1 / (1 + exp(-x)).
// Computed as exp(x)/(exp(x)+1) on the negative half-range for stability,
// then mirrored with 1 - y for the lanes that were originally positive.
525 template <cpu_isa_t isa>
526 void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector(
527 const Vmm &vmm_src) {
528 // we store the original sign and make x negative
529 // IMPORTANT: we assume vmm_aux0 to be xmm0, as for sse4.2 path it is required
530 // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it.
// [12] = 0x80000000 sign mask: vmm_aux2 keeps the original signs, the vorps
// forces every lane of vmm_src negative before exp.
531 h->uni_vmovups(vmm_aux2, vmm_src);
532 h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12));
533 h->uni_vorps(vmm_src, vmm_src, table_val(12));
535 exp_compute_vector(vmm_src);
// Denominator exp(x) + 1 ([0] = 1.0f).
537 h->uni_vmovups(vmm_aux1, vmm_src);
539 h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0));
540 // y = exp(x) / (exp(x) + 1)
541 h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
543 // Now we have to apply the "symmetry" based on original sign
544 h->uni_vmovups(vmm_aux3, table_val(0));
545 h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src);
546 if (isa == avx512_common) {
// Non-zero sign bits (originally negative lanes) keep y; others get 1 - y.
547 h->vptestmd(k_mask, vmm_aux2, vmm_aux2);
548 h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src);
550 h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2
551 h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0);
553 h->uni_vmovups(vmm_src, vmm_aux3);
// In-place clamp: dst = min(max(x, beta), alpha).
// Table: [0] = alpha (upper bound), [1] = beta (lower bound) — see
// clamp_prepare_table().
556 template <cpu_isa_t isa>
557 void jit_uni_eltwise_injector_f32<isa>::clamp_compute_vector(
558 const Vmm &vmm_src) {
560 h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
561 h->uni_vminps(vmm_src, vmm_src, table_val(0));
// Emits the ReLU constant table: one full vector of alpha ([0]) followed by
// one full vector of zeros ([1]), broadcast element-wise.
564 template <cpu_isa_t isa>
565 void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
566 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
567 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
// Emits the constant table shared by exp/elu/tanh/logistic: entries [0]..[22]
// below, then alpha at [23] and zero at [24] (offsets used by
// elu_compute_vector()). Each scalar is broadcast across a full vector.
570 template <cpu_isa_t isa>
571 void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() {
572 const unsigned int cvals[] = {
573 0x3f800000, // [0] 1.0f
574 0x3f000000, // [1] 0.5f
575 0x3fb8aa3b, // [2] log2ef = 1.44269502f
576 0x3f317218, // [3] ln2f = 0.69314718f
577 0x0000007f, // [4] 0x7f
// exp(x) polynomial coefficients:
579 0x3f800001, // [5] p0 = 1.0000001f
580 0x3efffe85, // [6] p2 = 0.4999887f
581 0x3e2aaa3e, // [7] p3 = 0.16666505f
582 0x3d2bb1b1, // [8] p4 = 0.041917507f
583 0x3c091ec1, // [9] p5 = 0.008369149f
584 0x42b0c0a5, //[10] max logf = 88.3762589f
585 0xc1766666, //[11] min logf = -14.5f
586 // tanh(x) constants,
587 0x80000000, //[12] mask to extract sign
588 0x39ddb3d7, //[13] arg below which tanh(x) = x
589 0x3f0c9f54, //[14] arg below which pol approx is valid
590 0x41102cb4, //[15] arg after which tanh(x) = 1
591 0xc0000000, //[16] -2.0f
592 0x7fffffff, //[17] mask to make positive
// tanh polynomial coefficients:
594 0x3f7fffff, //[18] p0
595 0xbeaaa9cf, //[19] p1
596 0x3e085f1f, //[20] p2
597 0xbd572bda, //[21] p3
598 0x3c84fd08, //[22] p4
// Broadcast each constant over a full vector width.
601 for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
602 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]);
// [23] alpha, [24] zero — the ELU-specific entries.
605 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
606 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
// Emits the soft-ReLU constant table ([0]..[25]); offsets match
// soft_relu_compute_vector(). Each scalar is broadcast across a full vector.
609 template <cpu_isa_t isa>
610 void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() {
611 const unsigned int cvals[] = {
612 0x3f800000, // [0] 1.0f
613 0x3f000000, // [1] 0.5f
614 0x3fb8aa3b, // [2] log2ef = 1.44269502f
615 0x3f317218, // [3] ln2f = 0.69314718f
616 0x0000007f, // [4] 0x7f
617 0x42fc0000, // [5] 126
618 0x807fffff, // [6] and with (to get 0.5 * mantissa)
619 0x3f000000, // [7] or with (to get 0.5 * mantissa)
620 // ln(1 + x) polynomial
621 0xb2b4637d, // [8] p0 = 0.0000000244f
622 0x3f7fff8e, // [9] p1 = 0.9999976971f
623 0xbf001759, //[10] p2 = -0.5002478215f
624 0x3ea70608, //[11] p3 = 0.3272714505f
625 0xbea3d7bf, //[12] p4 = -0.3153830071f
626 0xbe361d04, //[13] p5 = -0.1701777461f
627 0xbfa8f1e6, //[14] p6 = -1.3254635147f
628 0xbfe1e812, //[15] p7 = -1.7971917960f
629 0xbfc4d30e, //[16] p8 = -1.5652673123f
// exp(x) polynomial coefficients:
631 0x3f800001, //[17] p0 = 1.0000001f
632 0x3f800000, //[18] p1 = 1.0f
633 0x3efffe85, //[19] p2 = 0.4999887f
634 0x3e2aaa3e, //[20] p3 = 0.16666505f
635 0x3d2bb1b1, //[21] p4 = 0.041917507f
636 0x3c091ec1, //[22] p5 = 0.008369149f
637 0xbf800000, //[23] is required for sign changing
638 0x42b0c0a5, //[24] max logf = 88.3762589f
639 0xc1766666 //[25] min logf = -14.5f
// Broadcast each constant over a full vector width.
642 for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
643 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
// Emits one vector of the sign-clearing mask 0x7fffffff used by
// abs_compute_vector().
649 template <cpu_isa_t isa>
650 void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() {
651 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff);
// Emits one vector of zeros, the compare/fallback constant for
// sqrt_compute_vector().
654 template <cpu_isa_t isa>
655 void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() {
656 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
// Emits alpha ([0]) and beta ([1]) vectors for linear_compute_vector().
659 template <cpu_isa_t isa>
660 void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() {
661 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
662 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
// Emits alpha ([0], upper bound) and zero ([1]) vectors for
// bounded_relu_compute_vector().
665 template <cpu_isa_t isa>
666 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() {
667 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
668 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
// Emits alpha ([0], upper bound) and beta ([1], lower bound) vectors for
// clamp_compute_vector().
671 template <cpu_isa_t isa>
672 void jit_uni_eltwise_injector_f32<isa>::clamp_prepare_table() {
673 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
674 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
// Returns how many scratch vector registers each eltwise algorithm needs;
// injector_preamble() preserves exactly this many. Keep in sync with the
// vmm_aux* usage of the corresponding *_compute_vector() above.
677 template <cpu_isa_t isa>
678 int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
// alpha == 0 takes the mask-free relu_zero_ns path, which needs no scratch.
680 case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2;
681 case alg_kind::eltwise_elu: return 4;
682 case alg_kind::eltwise_tanh: return 5;
683 case alg_kind::eltwise_square: return 0;
684 case alg_kind::eltwise_abs: return 0;
685 case alg_kind::eltwise_sqrt: return 2;
686 case alg_kind::eltwise_linear: return 1;
687 case alg_kind::eltwise_bounded_relu: return 0;
688 case alg_kind::eltwise_soft_relu: return 4;
689 case alg_kind::eltwise_logistic: return 4;
690 case alg_kind::eltwise_clamp: return 0;
691 case alg_kind::eltwise_exp: return 4;
692 default: assert(!"unsupported eltwise algorithm");
// Dispatches to the per-algorithm generator for each vector register in
// [start_idx, end_idx), transforming each Vmm(idx) in place.
698 template <cpu_isa_t isa>
699 void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
701 using namespace alg_kind;
702 for (size_t idx = start_idx; idx < end_idx; idx++) {
705 if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx));
706 else relu_compute_vector(Vmm(idx));
708 case eltwise_elu: elu_compute_vector(Vmm(idx)); break;
709 case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break;
710 case eltwise_square: square_compute_vector(Vmm(idx)); break;
711 case eltwise_abs: abs_compute_vector(Vmm(idx)); break;
712 case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break;
713 case eltwise_linear: linear_compute_vector(Vmm(idx)); break;
714 case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break;
715 case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break;
716 case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break;
717 case eltwise_clamp: clamp_compute_vector(Vmm(idx)); break;
718 case eltwise_exp: exp_compute_vector(Vmm(idx)); break;
719 default: assert(!"unsupported eltwise algorithm");
// Public entry point: applies the eltwise op to registers
// [start_idx, end_idx). The range is processed in two passes — first the
// registers not borrowed as scratch ([start_idx_tail, end_idx)), then, after
// injector_preamble_tail() swaps the scratch set, the borrowed head
// ([start_idx, start_idx_tail)).
724 template <cpu_isa_t isa>
725 void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
727 assert(start_idx < end_idx && end_idx <= vecs_count);
729 injector_preamble(start_idx, end_idx);
730 compute_body(start_idx_tail, end_idx);
731 injector_preamble_tail(start_idx);
732 compute_body(start_idx, start_idx_tail);
733 injector_postamble();
// Emits the constant table for the configured algorithm. tanh/elu/exp/
// logistic all share elu_prepare_table() since they use the same constant
// layout; eltwise_square needs no constants at all.
736 template <cpu_isa_t isa>
737 void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) {
738 using namespace alg_kind;
745 case eltwise_relu: relu_prepare_table(); break;
748 case eltwise_logistic:
750 elu_prepare_table(); break;
751 case eltwise_soft_relu: soft_relu_prepare_table(); break;
752 case eltwise_abs: abs_prepare_table(); break;
753 case eltwise_sqrt: sqrt_prepare_table(); break;
754 case eltwise_linear: linear_prepare_table(); break;
755 case eltwise_bounded_relu: bounded_relu_prepare_table(); break;
756 case eltwise_square: break;
757 case eltwise_clamp: clamp_prepare_table(); break;
758 default: assert(!"unsupported eltwise algorithm");
// Explicit instantiations for the three supported ISAs.
763 template struct jit_uni_eltwise_injector_f32<avx512_common>;
764 template struct jit_uni_eltwise_injector_f32<avx2>;
765 template struct jit_uni_eltwise_injector_f32<sse42>;
770 const float *for_comparison;
// Abstract base for the generated eltwise kernels: holds the descriptor and
// the JIT-compiled entry point ker_, invoked through operator().
775 struct jit_uni_eltwise_kernel_f32 : public c_compatible {
776 const eltwise_desc_t &desc_;
// ker_ is set by derived classes once code generation completes.
778 void (*ker_)(const jit_args *);
779 void operator()(const jit_args *args) { assert(ker_); ker_(args); }
781 jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc)
782 : desc_(desc), ker_(nullptr) {}
783 virtual ~jit_uni_eltwise_kernel_f32() {}
786 bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; }
// Dedicated (leaky-)ReLU kernel, used for both fwd and bwd. The generated
// code runs a vectorized main loop followed by a scalar remainder loop; the
// bwd pass additionally loads the original src via reg_for_comparison to
// decide the mask. NOTE(review): interior lines are elided in this excerpt
// (gaps in embedded numbering); comments cover only the visible lines.
792 template <cpu_isa_t isa>
793 struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32,
796 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32)
// Emits one unrolled step: load uf elements (vector or scalar), compute
// x>0 ? x : alpha*x, store. `vectorize` selects full-vector vs movss loads.
798 void compute_step(bool vectorize, const int uf, const int shift) {
799 for (int i = 0; i < uf; i++) {
801 uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]);
803 uni_vmovups(Vmm(uf + i + 1),
804 ptr[reg_for_comparison + i * shift]);
806 movss(Xmm(i + 1), ptr[reg_from + i * shift]);
808 movss(Xmm(uf + i + 1),
809 ptr[reg_for_comparison + i * shift]);
// sse42 path: blendvps with the implicit xmm0 mask.
814 for (int i = 0; i < uf; i++) {
815 movups(Vmm(2 * uf + i + 1), Vmm(i + 1));
816 mulps(Vmm(2 * uf + i + 1), vmm_ns);
// bwd compares against the saved forward input, fwd against the loaded x.
820 movups(mask, Vmm(uf + i + 1));
821 cmpps(mask, vmm_zero, _cmp_nle_us);
823 movups(mask, Vmm(i + 1));
824 cmpps(mask, vmm_zero, _cmp_nle_us);
826 blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1));
// avx2 path: explicit vector mask + vblendvps.
829 for (int i = 0; i < uf; i++) {
830 vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns);
833 vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero);
835 vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero);
837 vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1),
838 Vmm(i + 1), vmm_mask);
// avx512 path: opmask compare + masked blend.
842 vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us);
844 vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us);
845 vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1),
// Store the results back (vector or scalar to match the load).
851 for (int i = 0; i < uf; i++) {
853 uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1));
855 movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1));
// Constructor generates the whole kernel and publishes it via ker_.
860 jit_uni_relu_kernel_f32(const eltwise_desc_t &desc)
861 : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
862 assert(desc.alg_kind == alg_kind::eltwise_relu);
863 assert(isa == sse42 || isa == avx2 || isa == avx512_common);
865 Reg64 param = abi_param1;
// Two loop flavors: [0] vectorized by simd_w, [1] scalar remainder.
867 const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
868 const int loop_dec[] = {simd_w, 1};
869 const int uf[] = {1, 1};
870 const int shift[] = {cpu_isa_traits<isa>::vlen, sizeof(float)};
871 const bool loop_vectorize[] = {true, false};
// Load kernel arguments from the jit_args struct.
875 mov(reg_from, ptr[param + GET_OFF(from)]);
877 mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]);
878 mov(reg_to, ptr[param + GET_OFF(to)]);
879 mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
// Broadcast alpha (negative slope) and prepare the zero vector.
881 mov(imm_addr64, float2int(desc.alpha));
882 movq(xmm_ns, imm_addr64);
883 uni_vbroadcastss(vmm_ns, xmm_ns);
885 uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
// Emit the two loops (vectorized then remainder).
889 for (int id = 0; id < 2; id++) {
891 cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
892 jle(loop_label[id + 1], T_NEAR);
894 compute_step(loop_vectorize[id], uf[id], shift[id]);
896 add(reg_from, uf[id] * shift[id]);
897 add(reg_to, uf[id] * shift[id]);
899 add(reg_for_comparison, uf[id] * shift[id]);
901 sub(reg_work_amount, uf[id] * loop_dec[id]);
908 ker_ = (decltype(ker_))this->getCode();
// Register/type assignments. On fwd, reg_for_comparison aliases reg_from.
912 using Vmm = typename utils::conditional3<isa == sse42, Xmm,
913 isa == avx2, Ymm, Zmm>::type;
915 Reg64 reg_from = rax;
916 Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from;
918 Reg64 reg_work_amount = rsi;
919 Reg64 imm_addr64 = rbx;
921 Xmm xmm_ns = Xmm(14);
923 Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14);
924 Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15);
926 Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12);
927 Opmask k_mask = Opmask(1);
// Generic forward eltwise kernel: delegates the per-vector math to
// jit_uni_eltwise_injector_f32 and wraps it in a vectorized main loop plus a
// scalar remainder loop. Forward-only (see the is_bwd() assert).
930 template <cpu_isa_t isa>
931 struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
932 public jit_generator {
933 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32)
935 jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc)
936 : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
// save_state_ = false: this kernel owns all registers, so the injector
// need not spill/restore around each compute_vector call.
938 eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
939 desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1));
941 using namespace alg_kind;
943 assert(is_bwd() == false);
944 assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
945 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
946 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
947 eltwise_clamp, eltwise_exp));
// Load kernel arguments and the constant-table base address.
951 Reg64 param = abi_param1;
952 mov(reg_from, ptr[param + GET_OFF(from)]);
953 mov(reg_to, ptr[param + GET_OFF(to)]);
954 mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
955 eltwise_injector_->load_table_addr();
957 Label reminder_loop_start, reminder_loop_end;
958 Label vectorized_loop_start, vectorized_loop_end;
960 cmp(reg_work_amount, simd_w);
961 jl(reminder_loop_start, T_NEAR);
// Main loop: one full vector per iteration.
963 L(vectorized_loop_start);
965 uni_vmovups(vmm_src, ptr[reg_from]);
966 eltwise_injector_->compute_vector(vmm_src.getIdx());
967 uni_vmovups(ptr[reg_to], vmm_src);
972 sub(reg_work_amount, simd_w);
973 cmp(reg_work_amount, simd_w);
974 jge(vectorized_loop_start, T_NEAR);
976 L(vectorized_loop_end);
// Remainder loop: one scalar float per iteration (reuses the same
// injector, operating on lane 0 of xmm_src).
978 L(reminder_loop_start);
980 cmp(reg_work_amount, 0);
981 jle(reminder_loop_end, T_NEAR);
983 movss(xmm_src, ptr[reg_from]);
984 eltwise_injector_->compute_vector(xmm_src.getIdx());
985 movss(ptr[reg_to], xmm_src);
987 add(reg_from, sizeof(float));
988 add(reg_to, sizeof(float));
990 dec(reg_work_amount);
991 jmp(reminder_loop_start, T_NEAR);
993 L(reminder_loop_end);
// Constant table is emitted after the code; then publish the entry point.
997 eltwise_injector_->prepare_table();
999 ker_ = (decltype(ker_))this->getCode();
1002 ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; }
1005 using Vmm = typename utils::conditional3<isa == sse42, Xmm,
1006 isa == avx2, Ymm, Zmm>::type;
1008 const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
1009 const int vlen = cpu_isa_traits<isa>::vlen;
1011 Reg64 reg_from = rax;
1013 Reg64 reg_work_amount = rsi;
1014 Reg64 imm_addr64 = rbx;
1016 Xmm xmm_src = Xmm(1);
1017 Vmm vmm_src = Vmm(1);
// Owned injector; freed in the destructor above.
1019 jit_uni_eltwise_injector_f32<isa> *eltwise_injector_;
// Forward descriptor check: accepts f32 dense forward eltwise with a
// supported algorithm on a CPU that has the required ISA, default attributes
// only. Returns unimplemented otherwise so a fallback impl can be picked.
1024 template <cpu_isa_t isa>
1025 status_t jit_uni_eltwise_fwd_t<isa>::pd_t::init() {
1026 using namespace alg_kind;
1028 assert(engine()->kind() == engine_kind::cpu);
1029 bool ok = true && mayiuse(isa)
1030 && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
1031 prop_kind::forward_inference)
1032 && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
1033 && !has_zero_dim_memory()
1034 && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh,
1035 eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
1036 eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
1037 eltwise_logistic, eltwise_clamp, eltwise_exp)
// Padded-dense layouts are only allowed when the op maps 0 -> 0, since the
// kernel also processes padding elements.
1038 && memory_desc_wrapper(src_pd()).is_dense(true)
1039 && IMPLICATION(!memory_desc_wrapper(src_pd()).is_dense(false),
1040 math::eltwise_fwd_preserves_zero(desc()->alg_kind, true))
1041 && attr()->has_default_values();
1043 return ok ? status::success : status::unimplemented;
// Forward primitive constructor: ReLU gets the specialized kernel, every
// other supported algorithm uses the injector-based generic kernel.
1046 template <cpu_isa_t isa>
1047 jit_uni_eltwise_fwd_t<isa>::jit_uni_eltwise_fwd_t(const pd_t *apd,
1048 const input_vector &inputs, const output_vector &outputs)
1049 : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
1050 const auto &desc = *pd()->desc();
1051 switch (desc.alg_kind) {
1052 case alg_kind::eltwise_relu:
1053 kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1055 kernel_ = new jit_uni_kernel_fwd_f32<isa>(desc);
// Destructor — presumably deletes kernel_; body not visible in this excerpt.
1059 template <cpu_isa_t isa>
1060 jit_uni_eltwise_fwd_t<isa>::~jit_uni_eltwise_fwd_t()
// Runs the forward kernel over the (padded) dense buffer, parallelized by
// splitting the flat element range into cache-line-sized chunks.
1063 template <cpu_isa_t isa>
1064 void jit_uni_eltwise_fwd_t<isa>::execute_forward() const {
1065 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
1066 auto dst = reinterpret_cast<data_t *>(this->memory(0));
1068 const memory_desc_wrapper data_d(pd()->src_pd());
// nelems(true) includes padding — valid because pd::init() guaranteed the
// op preserves zero on padded layouts.
1070 const size_t nelems = data_d.nelems(true);
1072 src += data_d.blocking_desc().offset_padding;
1073 dst += data_d.blocking_desc().offset_padding;
1075 parallel(0, [&](const int ithr, const int nthr) {
1076 size_t start{0}, end{0};
// 16 floats = 64 bytes, one cache line per work unit.
1078 const int cache_line = 16;
1080 balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1081 start = nstl::min(nelems, start * cache_line);
1082 end = nstl::min(nelems, end * cache_line);
1084 auto arg = jit_args();
1085 arg.from = &src[start];
// for_comparison is unused by fwd kernels but filled for a uniform ABI.
1086 arg.for_comparison = &src[start];
1087 arg.to = &dst[start];
1088 arg.work_amount = end - start;
1089 if (arg.work_amount)
// Backward descriptor check: only eltwise_relu is supported (matching the
// dedicated bwd-capable jit_uni_relu_kernel_f32), f32, dense, with diff_dst
// layout identical to src.
1094 template <cpu_isa_t isa>
1095 status_t jit_uni_eltwise_bwd_t<isa>::pd_t::init() {
1096 assert(engine()->kind() == engine_kind::cpu);
1099 && desc()->prop_kind == prop_kind::backward_data
1100 && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu)
1101 && src_pd()->desc()->data_type == data_type::f32
1102 && !has_zero_dim_memory()
1104 && memory_desc_wrapper(src_pd()).is_dense()
1105 && memory_desc_wrapper(diff_dst_pd()) == memory_desc_wrapper(src_pd())
1106 && attr()->has_default_values();
1108 return ok ? status::success : status::unimplemented;
// Backward primitive constructor: only ReLU is supported (enforced both here
// and in pd_t::init() above).
1111 template <cpu_isa_t isa>
1112 jit_uni_eltwise_bwd_t<isa>::jit_uni_eltwise_bwd_t(const pd_t *apd,
1113 const input_vector &inputs, const output_vector &outputs)
1114 : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
1115 const auto &desc = *pd()->desc();
1116 switch (desc.alg_kind) {
1117 case alg_kind::eltwise_relu:
1118 kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1119 default: assert(!"unknown eltwise alg_kind");
// Destructor — presumably deletes kernel_; body not visible in this excerpt.
1123 template <cpu_isa_t isa>
1124 jit_uni_eltwise_bwd_t<isa>::~jit_uni_eltwise_bwd_t()
// Runs the backward ReLU kernel: diff_src = diff_dst masked by the sign of
// the forward src (passed via for_comparison). Same cache-line-chunked
// parallel split as execute_forward(). Note nelems() here excludes padding,
// unlike the forward path.
1127 template <cpu_isa_t isa>
1128 void jit_uni_eltwise_bwd_t<isa>::execute_backward() const {
1129 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
1130 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
1131 auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
1133 const memory_desc_wrapper data_d(pd()->src_pd());
1134 const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
1136 const size_t nelems = data_d.nelems();
1138 src += data_d.blocking_desc().offset_padding;
1139 diff_dst += diff_data_d.blocking_desc().offset_padding;
1140 diff_src += diff_data_d.blocking_desc().offset_padding;
1142 parallel(0, [&](const int ithr, const int nthr) {
1143 size_t start{0}, end{0};
// 16 floats = 64 bytes, one cache line per work unit.
1145 const int cache_line = 16;
1147 balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1148 start = nstl::min(nelems, start * cache_line);
1149 end = nstl::min(nelems, end * cache_line);
1151 auto arg = jit_args();
1152 arg.from = &diff_dst[start];
1153 arg.to = &diff_src[start];
// Forward input drives the >0 mask in the bwd kernel.
1154 arg.for_comparison = &src[start];
1155 arg.work_amount = end - start;
1156 if (arg.work_amount)
// Explicit instantiations of the fwd/bwd primitives for the three ISAs.
1161 template struct jit_uni_eltwise_fwd_t<sse42>;
1162 template struct jit_uni_eltwise_bwd_t<sse42>;
1163 template struct jit_uni_eltwise_fwd_t<avx2>;
1164 template struct jit_uni_eltwise_bwd_t<avx2>;
1165 template struct jit_uni_eltwise_fwd_t<avx512_common>;
1166 template struct jit_uni_eltwise_bwd_t<avx512_common>;