1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <mkldnn_types.h>
18 #include "mkldnn_types.h"
19 #include "mkldnn_thread.hpp"
22 #include "jit_generator.hpp"
24 #include "jit_uni_eltwise.hpp"
26 #define GET_OFF(field) offsetof(jit_args, field)
32 using namespace Xbyak;
// Returns whether vector register index `idx` has NOT yet been claimed as a
// temporary, i.e. it does not appear in preserved_vec_idxs[0..preserved_vecs_count).
34 template <cpu_isa_t isa>
35 bool jit_uni_eltwise_injector_f32<isa>::is_free_vec(size_t idx) {
36 for (size_t i = 0; i < preserved_vecs_count; i++) {
37 if (preserved_vec_idxs[i] == idx) {
// Chooses the auxiliary vector registers the injector needs for the current
// eltwise algorithm (aux_vecs_count) and, when save_vecs_state is set, spills
// them to the stack so the caller's register contents survive the injection.
// Registers outside the caller's [start_idx, end_idx) range are preferred;
// any shortfall is borrowed from the head of that range via start_idx_tail
// (and later re-shuffled by injector_preamble_tail).
44 template <cpu_isa_t isa>
45 void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
47 preserved_vecs_count = 0;
48 vecs_to_preserve = (size_t)jit_uni_eltwise_injector_f32<isa>::
49 aux_vecs_count(elt_alg);
50 start_idx_tail = start_idx;
52 // For sse42 mask register has to be Xmm(0)
53 if (isa == sse42 && vecs_to_preserve > 0) {
55 assert(idx < start_idx);
56 preserved_vec_idxs[preserved_vecs_count++] = idx;
// First pass: claim free registers lying outside the caller's active range.
59 for (size_t i = 0; i < vecs_count; i++) {
60 if (preserved_vecs_count >= vecs_to_preserve)
64 if (is_free_vec(idx) && (idx < start_idx || idx >= end_idx)) {
65 preserved_vec_idxs[preserved_vecs_count++] = idx;
// Second pass: not enough outside registers — borrow from the head of the
// caller's range; start_idx_tail tracks how far the borrowing advanced.
69 size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
70 for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
71 size_t idx = start_idx_tail;
72 if (is_free_vec(idx)) {
73 preserved_vec_idxs[preserved_vecs_count++] = idx;
78 assert(preserved_vecs_count == vecs_to_preserve);
// Spill every claimed register onto the stack (restored in injector_postamble).
80 if (save_vecs_state) {
83 h->sub(h->rsp, preserved_vecs_count * vlen);
84 for (size_t i = 0; i < preserved_vecs_count; ++i)
85 h->uni_vmovups(h->ptr[h->rsp + i * vlen],
86 Vmm(preserved_vec_idxs[i]));
// Re-homes the temporaries that injector_preamble borrowed from the head of
// the caller's register range: restores those registers from the stack,
// shifts the borrowed indices forward by tail_vecs_to_preserve, and re-spills
// the newly chosen registers so the remaining vectors can be processed.
92 template <cpu_isa_t isa>
93 void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(
95 size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
96 int idx_off = (vecs_to_preserve - tail_vecs_to_preserve);
98 if (tail_vecs_to_preserve > 0) {
99 if (save_vecs_state) {
// Skip over the non-borrowed spill slots, then reload the borrowed regs.
100 h->add(h->rsp, idx_off * vlen);
101 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
102 h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
103 h->ptr[h->rsp + i * vlen]);
// Slide the borrowed temporaries to the registers just past the old ones.
106 for (size_t i = 0; i < tail_vecs_to_preserve; ++i) {
107 preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
110 if (save_vecs_state) {
111 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
112 h->uni_vmovups(h->ptr[h->rsp + i * vlen],
113 Vmm(preserved_vec_idxs[idx_off + i]));
// Re-point rsp at the base of the spill area again.
114 h->sub(h->rsp, idx_off * vlen);
// Restores all vector registers spilled by injector_preamble and releases
// the stack space used for the spill area.
121 template <cpu_isa_t isa>
122 void jit_uni_eltwise_injector_f32<isa>::injector_postamble() {
123 if (save_vecs_state) {
124 for (size_t i = 0; i < preserved_vecs_count; ++i)
125 h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
126 h->ptr[h->rsp + i * vlen]);
127 h->add(h->rsp, preserved_vecs_count * vlen);
// Binds the named helper registers to the preserved indices picked by
// injector_preamble.  Note vmm_mask and vmm_aux0 intentionally alias the
// same register (index 0); algorithms use at most one of them at a time.
133 template <cpu_isa_t isa>
134 void jit_uni_eltwise_injector_f32<isa>::assign_regs() {
135 vmm_mask = Vmm(preserved_vec_idxs[0]);
136 vmm_aux0 = Vmm(preserved_vec_idxs[0]);
137 vmm_aux1 = Vmm(preserved_vec_idxs[1]);
138 vmm_aux2 = Vmm(preserved_vec_idxs[2]);
139 vmm_aux3 = Vmm(preserved_vec_idxs[3]);
141 p_table = Xbyak::Reg64(table_reg_idx);
142 k_mask = Xbyak::Opmask(opmask_idx);
// Computes exp(x) element-wise in vmm_src using the classic range-reduction
// scheme: clamp x to [min_logf, max_logf] (table slots 11/10), split off
// n = floor(x * log2e + 0.5), evaluate a degree-5 polynomial on the residual,
// then scale by 2^n built via exponent-field arithmetic (vpslld by 23).
// Table slot layout is established by elu_prepare_table().
145 template <cpu_isa_t isa>
146 void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
147 const unsigned char _op_floor = 1;
// Clamp input to the representable range of the approximation.
149 h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 10 * vlen]);
150 h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 11 * vlen]);
151 h->uni_vmovups(vmm_aux0, vmm_src);
153 // fx = x * log2ef + 0.5
154 h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + 2 * vlen]);
155 h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_table + 1 * vlen]);
// floor(fx): avx512 has no uni_vroundps path here, so emulate floor with
// round-down conversion plus a correction where the rounded value overshoots.
158 if (isa == avx512_common) {
159 h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src);
160 h->vcvtdq2ps(vmm_aux1, vmm_aux1);
// predicate 14 == _CMP_GT_OS (ordered, signaling greater-than)
162 unsigned char _cmp_gt_os = 14;
163 Xbyak::Opmask k_mask_tmp = Xbyak::Opmask(2);
164 h->vcmpps(k_mask_tmp, vmm_aux1, vmm_src, _cmp_gt_os);
165 h->vmovups(vmm_aux3 | k_mask_tmp | h->T_z,
166 h->zword[p_table + 0 * vlen]);
168 h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3);
170 h->uni_vroundps(vmm_aux1, vmm_src, _op_floor);
173 //keep fx for further computations
174 h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx
// residual: x - fx * ln2 (table slot 3 holds ln2f)
177 h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, h->ptr[p_table + 3 * vlen]);
// build 2^fx: add the exponent bias (slot 4 = 0x7f) and shift into the
// float exponent field.
180 h->uni_vcvtps2dq(vmm_aux1, vmm_src);
181 h->uni_vpaddd(vmm_aux1, vmm_aux1, h->ptr[p_table + 4 * vlen]);
182 h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx
// Horner evaluation of the exp polynomial (coefficients in slots 9..5, 0).
185 h->uni_vmovups(vmm_src, h->ptr[p_table + 9 * vlen]);
187 h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 8 * vlen]);
189 h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 7 * vlen]);
191 h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 6 * vlen]);
193 h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 0 * vlen]);
195 h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 5 * vlen]); //exp(q)
// final scale: exp(x) = poly(residual) * 2^n
197 h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
// Leaky ReLU: vmm_src = x > 0 ? x : alpha * x.  Table slot 0 holds alpha,
// slot 1 holds zero (see relu_prepare_table).  Each ISA uses its native
// compare/blend idiom; the predicate differs because AVX-512 vcmpps takes the
// full predicate set (14 = _CMP_GT_OS) while SSE cmpps only encodes 0..7 (6 = NLE).
200 template <cpu_isa_t isa>
201 void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(
202 const Vmm &vmm_src) {
203 unsigned char _cmp_gt_os = isa == avx512_common ? 14 : 6;
205 int alpha_off = 0 * vlen;
206 int zero_off = 1 * vlen;
// keep the original x so it can be blended back for the positive lanes
208 h->uni_vmovups(vmm_aux1, vmm_src);
// sse42 path: blendvps implicitly uses xmm0, which assign_regs put vmm_mask in
210 h->movups(vmm_mask, vmm_src);
211 h->mulps(vmm_src, h->ptr[p_table + alpha_off]);
212 h->cmpps(vmm_mask, h->ptr[p_table + zero_off], _cmp_gt_os);
213 h->blendvps(vmm_src, vmm_aux1);
214 } else if (isa == avx2) {
215 h->vmulps(vmm_src, vmm_src, h->ptr[p_table + alpha_off]);
216 h->vcmpgtps(vmm_mask, vmm_aux1, h->ptr[p_table + zero_off]);
217 h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
218 } else if (isa == avx512_common) {
219 h->vmulps(vmm_src, vmm_src, h->ptr[p_table + alpha_off]);
220 h->vcmpps(k_mask, vmm_aux1, h->ptr[p_table + zero_off], _cmp_gt_os);
221 h->vblendmps(vmm_src | k_mask, vmm_src,
// ReLU with zero negative slope (alpha == 0): simply vmm_src = max(x, 0).
226 template <cpu_isa_t isa>
227 void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector(
228 const Vmm &vmm_src) {
229 int zero_off = 1 * vlen;
230 h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + zero_off]);
// ELU: vmm_src = x > 0 ? x : alpha * (exp(x) - 1).  Reuses exp_compute_vector
// for the negative branch, then blends the original x back for positive lanes.
// Table slots 12/13 (alpha/zero) are appended by elu_prepare_table.
233 template <cpu_isa_t isa>
234 void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) {
235 const unsigned char _cmp_gt_os = 6;
236 const unsigned char _cmp_let_os = 2;
237 int alpha_off = 12 * vlen;
238 int zero_off = 13 * vlen;
// stash the original input; exp_compute_vector clobbers vmm_src and aux0/1/3
241 h->uni_vmovups(vmm_aux2, vmm_src);
242 exp_compute_vector(vmm_src);
244 // alpha * (exp(x) - 1)
// NOTE(review): `0 * 32` below is the 1.0f slot; offset is 0 either way but
// `0 * vlen` would match the convention used everywhere else in this file.
245 h->uni_vsubps(vmm_src, vmm_src, h->ptr[p_table + 0 * 32]);
246 h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + alpha_off]);
// sse42: compare against 0 via LE-or-equal predicate, then blendvps (xmm0 mask)
250 h->pxor(vmm_mask, vmm_mask);
251 h->cmpps(vmm_mask, vmm_aux2, _cmp_let_os);
252 h->blendvps(vmm_src, vmm_aux2);
253 } else if (isa == avx2) {
254 h->uni_vcmpgtps(vmm_mask, vmm_aux2, h->ptr[p_table + zero_off]);
255 h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask);
256 } else if (isa == avx512_common) {
257 h->vcmpps(k_mask, vmm_aux2, h->ptr[p_table + zero_off], _cmp_gt_os);
258 h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2);
// tanh(x) computed as (exp(2x) - 1) / (exp(2x) + 1); table slot 0 holds 1.0f.
262 template <cpu_isa_t isa>
263 void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(
264 const Vmm &vmm_src) {
// 2x, then exp(2x)
266 h->uni_vaddps(vmm_src, vmm_src, vmm_src);
267 exp_compute_vector(vmm_src);
269 h->uni_vmovups(vmm_aux0, vmm_src);
// numerator: exp(2x) - 1
271 h->uni_vsubps(vmm_src, vmm_src, h->ptr[p_table + 0 * vlen]);
// denominator: exp(2x) + 1
273 h->uni_vaddps(vmm_aux0, vmm_aux0, h->ptr[p_table + 0 * vlen]);
274 // y = (exp(2x) - 1) / (exp(2x) + 1)
275 h->uni_vdivps(vmm_src, vmm_src, vmm_aux0);
// Square: vmm_src = x * x.  Needs no table and no auxiliary registers.
278 template <cpu_isa_t isa>
279 void jit_uni_eltwise_injector_f32<isa>::square_compute_vector(
280 const Vmm &vmm_src) {
281 h->uni_vmulps(vmm_src, vmm_src, vmm_src);
// Absolute value: clear the sign bit by AND-ing with the 0x7fffffff mask
// stored in table slot 0 (see abs_prepare_table).
284 template <cpu_isa_t isa>
285 void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) {
286 // compute abs(x) = _mm_and_ps(x, 01111..111));
287 h->uni_vandps(vmm_src, vmm_src, h->ptr[p_table + 0*vlen]);
// sqrt(x) for x > 0, and 0 otherwise: compute sqrtps, then blend in the zero
// from table slot 0 for non-positive lanes (avoids propagating NaNs from
// sqrt of negative inputs).
290 template <cpu_isa_t isa>
291 void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(
292 const Vmm &vmm_src) {
293 if (isa == avx512_common) {
294 unsigned char _cmp_gt_os = 6;
// k_mask selects lanes where x > 0
296 h->vcmpps(k_mask, vmm_src, h->ptr[p_table + 0 * vlen], _cmp_gt_os);
297 h->uni_vsqrtps(vmm_aux1, vmm_src);
298 h->uni_vmovups(vmm_src, h->ptr[p_table + 0*vlen]);
299 h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
// sse42/avx2 path: same logic with a vector mask + blendvps
301 h->uni_vmovups(vmm_mask, vmm_src);
302 h->uni_vcmpgtps(vmm_mask, vmm_mask, h->ptr[p_table + 0*vlen]);
303 h->uni_vsqrtps(vmm_aux1, vmm_src);
304 h->uni_vmovups(vmm_src, h->ptr[p_table + 0*vlen]);
305 h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
// Linear: vmm_src = alpha * x + beta, with alpha in table slot 0 and beta in
// slot 1 (see linear_prepare_table), fused into a single FMA.
309 template <cpu_isa_t isa>
310 void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector(
311 const Vmm &vmm_src) {
312 // compute x = alpha * x + beta;
313 h->uni_vmovups(vmm_aux0, h->ptr[p_table + 0*vlen]);
314 h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 1*vlen]);
// Bounded ReLU: vmm_src = min(max(x, 0), alpha); slot 0 = alpha, slot 1 = 0
// (see bounded_relu_prepare_table).
317 template <cpu_isa_t isa>
318 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector(
319 const Vmm &vmm_src) {
320 // compute bounded relu */
321 h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 1*vlen]);
322 h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 0*vlen]);
// Clamp: vmm_src = min(max(x, beta), alpha); slot 0 = alpha (upper bound),
// slot 1 = beta (lower bound) per clamp_prepare_table.
325 template <cpu_isa_t isa>
326 void jit_uni_eltwise_injector_f32<isa>::clamp_compute_vector(
327 const Vmm &vmm_src) {
328 h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 1*vlen]);
329 h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 0*vlen]);
// Soft ReLU (softplus): y = ln(1 + exp(x)).  Inlines its own exp range
// reduction (same scheme as exp_compute_vector but against the soft_relu
// table layout, slots 17..25), then computes ln(1 + t) by splitting the
// argument into exponent and mantissa and evaluating a degree-8 polynomial
// (slots 8..16).  For x above max-logf the result is clamped back to x itself.
332 template <cpu_isa_t isa>
333 void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
334 const Vmm &vmm_src) {
335 const unsigned char _op_floor = 1;
// keep the raw input for the final "large x -> x" blend
337 h->uni_vmovups(vmm_aux2, vmm_src);
// clamp to [min_logf, max_logf] (slots 25/24)
339 h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 24 * vlen]);
340 h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 25 * vlen]);
341 h->uni_vmovups(vmm_aux1, vmm_src);
343 // fx = x * log2ef + 0.5
344 h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + 2 * vlen]);
345 h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_table + 1 * vlen]);
// floor(fx): avx512 emulates floor via round-down convert + correction
348 if (isa == avx512_common) {
349 h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src);
350 h->vcvtdq2ps(vmm_aux0, vmm_aux0);
352 unsigned char _cmp_gt_os = 14;
353 h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_gt_os);
354 h->vmovups(vmm_aux3 | k_mask | h->T_z, h->ptr[p_table + 0 * vlen]);
356 h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3);
358 h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);
361 // keep fx for further computations
362 h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
363 // calculation fx * ln2
364 h->uni_vmulps(vmm_aux0, vmm_aux0, h->ptr[p_table + 3 * vlen]);
// residual q = x - fx * ln2
366 h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
// Horner polynomial for exp(q), coefficients in slots 22..17 (and 0)
368 h->uni_vmovups(vmm_aux3, h->ptr[p_table + 22 * vlen]);
370 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 21 * vlen]);
372 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 20 * vlen]);
374 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 19 * vlen]);
376 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 0 * vlen]);
378 h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 17 * vlen]);
// build 2^fx; avx512 flips the sign via multiply by -1 (slot 23), others
// use psignd with the same constant
381 if (isa == avx512_common) {
382 h->vmulps(vmm_aux1, vmm_src, h->ptr[p_table + 23 * vlen]);
383 h->vcvtps2dq(vmm_aux1, vmm_aux1);
385 h->uni_vcvtps2dq(vmm_aux1, vmm_src);
386 h->uni_vpsignd(vmm_aux1, vmm_aux1, h->ptr[p_table + 23 * vlen]);
// bias (slot 4 = 0x7f) and shift into the float exponent field
389 h->uni_vpaddd(vmm_aux1, vmm_aux1, h->ptr[p_table + 4 * vlen]);
390 h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
391 // calculate ln(1 + y)
392 h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
393 // x = y; y is free; keep x for further computations
394 h->uni_vmovups(vmm_src, vmm_aux3);
// extract the raw exponent field to recover n, where arg = 2^n * mant
396 h->uni_vpsrld(vmm_src, vmm_src, 23);
397 h->uni_vcvtdq2ps(vmm_src, vmm_src);
398 // got n. where n is x = 2^n * y. y = 0.5 .. 1
399 h->uni_vsubps(vmm_src, vmm_src, h->ptr[p_table + 5 * vlen]);
// isolate the mantissa (slot 6) and rebias it into [0.5, 1) (slot 7)
401 h->uni_vandps(vmm_aux3, vmm_aux3, h->ptr[p_table + 6 * vlen]);
402 // got y. (mantisa) 0.5 < y < 1
403 h->uni_vorps(vmm_aux3, vmm_aux3, h->ptr[p_table + 7 * vlen]);
// shift to y - 1 so the ln polynomial is centered at 0
405 h->uni_vsubps(vmm_aux3, vmm_aux3, h->ptr[p_table + 0 * vlen]);
// Horner polynomial for ln(1 + t), coefficients in slots 16..8
407 h->uni_vmovups(vmm_aux1, h->ptr[p_table + 16 * vlen]);
409 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 15 * vlen]);
411 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 14 * vlen]);
413 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 13 * vlen]);
415 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 12 * vlen]);
417 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 11 * vlen]);
419 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 10 * vlen]);
421 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 9 * vlen]);
422 // y = y * x + p0 ; p0 = 0
423 h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 8 * vlen]);
424 //calculate ln(2) * n
425 h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + 3 * vlen]);
426 h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
427 h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);
429 // get vmm_mask = src > max logf
430 h->uni_vmovups(vmm_mask, vmm_aux2)
431 if (isa == avx512_common) {
432 unsigned char _cmp_gt_os = 6;
433 // y = (x < max log f) ? soft_relu(x) : x
434 h->vcmpps(k_mask, vmm_mask, h->ptr[p_table + 24 * vlen], _cmp_gt_os);
435 h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
437 // y = (x < max log f) ? soft_relu(x) : x
438 h->uni_vcmpgtps(vmm_mask, vmm_mask, h->ptr[p_table + 24 * vlen]);
439 h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
// publish the result back into the caller's register
442 h->uni_vmovups(vmm_src, vmm_aux1);
// Logistic (sigmoid): y = exp(x) / (exp(x) + 1); table slot 0 holds 1.0f.
445 template <cpu_isa_t isa>
446 void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector(
447 const Vmm &vmm_src) {
448 exp_compute_vector(vmm_src);
450 h->uni_vmovups(vmm_aux0, vmm_src);
// denominator: exp(x) + 1
452 h->uni_vaddps(vmm_aux0, vmm_aux0, h->ptr[p_table + 0 * vlen]);
453 // y = exp(x) / (exp(x) + 1)
454 h->uni_vdivps(vmm_src, vmm_src, vmm_aux0);
// Emits the relu constant table: one full vector of alpha (slot 0) followed
// by a vector of zeros (slot 1), each broadcast across vlen/4 lanes.
457 template <cpu_isa_t isa>
458 void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
459 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
460 h->dd(float2int(alpha));
462 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
// Emits the constant table shared by exp/elu/tanh/logistic: 12 broadcast
// constants (slots 0..11), then alpha (slot 12) and zeros (slot 13) used by
// the elu blend.  Slot indices here match the offsets used by
// exp_compute_vector and elu_compute_vector.
467 template <cpu_isa_t isa>
468 void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() {
469 const unsigned int cvals[] = {
470 0x3f800000, // [0] 1.0f
471 0x3f000000, // [1] 0.5f
472 0x3fb8aa3b, // [2] log2ef = 1.44269502f
473 0x3f317218, // [3] ln2f = 0.69314718f
474 0x0000007f, // [4] 0x7f
// exp(x) polynomial coefficients
476 0x3f800001, // [5] p0 = 1.0000001f
477 0x3efffe85, // [6] p2 = 0.4999887f
478 0x3e2aaa3e, // [7] p3 = 0.16666505f
479 0x3d2bb1b1, // [8] p4 = 0.041917507f
480 0x3c091ec1, // [9] p5 = 0.008369149f
481 0x42b0c0a5, //[10] max logf = 88.3762589f
// NOTE(review): 0xc1766666 decodes to about -15.4f, not -14.5f as the
// original comment states — confirm which value is intended.
482 0xc1766666 //[11] min logf = -14.5f
// broadcast each constant across the vector width
485 for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
486 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
// trailing slots: alpha (for elu) and zeros
490 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
491 h->dd(float2int(alpha));
493 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
// Emits the soft_relu constant table (26 broadcast slots): shared range
// reduction constants (0..4), mantissa/exponent manipulation constants (5..7),
// ln(1+x) polynomial (8..16), exp polynomial (17..22), a sign-flip constant
// (23) and the logf clamp bounds (24/25).  Offsets match
// soft_relu_compute_vector.
498 template <cpu_isa_t isa>
499 void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() {
500 const unsigned int cvals[] = {
501 0x3f800000, // [0] 1.0f
502 0x3f000000, // [1] 0.5f
503 0x3fb8aa3b, // [2] log2ef = 1.44269502f
504 0x3f317218, // [3] ln2f = 0.69314718f
505 0x0000007f, // [4] 0x7f
506 0x42fc0000, // [5] 126
507 0x807fffff, // [6] and with (to get 0.5 * mantissa)
508 0x3f000000, // [7] or with (to get 0.5 * mantissa)
509 // ln(1 + x) polynomial
510 0xb2b4637d, // [8] p0 = 0.0000000244f
511 0x3f7fff8e, // [9] p1 = 0.9999976971f
512 0xbf001759, //[10] p2 = -0.5002478215f
513 0x3ea70608, //[11] p3 = 0.3272714505f
514 0xbea3d7bf, //[12] p4 = -0.3153830071f
515 0xbe361d04, //[13] p5 = -0.1701777461f
516 0xbfa8f1e6, //[14] p6 = -1.3254635147f
517 0xbfe1e812, //[15] p7 = -1.7971917960f
518 0xbfc4d30e, //[16] p8 = -1.5652673123f
// exp(x) polynomial
520 0x3f800001, //[17] p0 = 1.0000001f
521 0x3f800000, //[18] p1 = 1.0f
522 0x3efffe85, //[19] p2 = 0.4999887f
523 0x3e2aaa3e, //[20] p3 = 0.16666505f
524 0x3d2bb1b1, //[21] p4 = 0.041917507f
525 0x3c091ec1, //[22] p5 = 0.008369149f
526 0xbf800000, //[23] is required for sign changing
527 0x42b0c0a5, //[24] max logf = 88.3762589f
// NOTE(review): 0xc1766666 decodes to about -15.4f, not -14.5f as the
// original comment states — confirm which value is intended.
528 0xc1766666 //[25] min logf = -14.5f
// broadcast each constant across the vector width
531 for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
532 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
// Emits the abs table: a single broadcast vector of the sign-clearing mask
// consumed by abs_compute_vector.
538 template <cpu_isa_t isa>
539 void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() {
540 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
// Emits the sqrt table: a single broadcast vector (the zero constant
// compared/blended against in sqrt_compute_vector).
545 template <cpu_isa_t isa>
546 void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() {
547 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
// Emits the linear table: alpha in slot 0, beta in slot 1, broadcast to a
// full vector each, matching linear_compute_vector's FMA operands.
552 template <cpu_isa_t isa>
553 void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() {
554 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
555 h->dd(float2int(alpha));
557 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
558 h->dd(float2int(beta));
// Emits the bounded-relu table: alpha (the upper bound) in slot 0 followed by
// a second broadcast slot (zero) used as the lower bound.
562 template <cpu_isa_t isa>
563 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() {
564 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
565 h->dd(float2int(alpha));
567 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
// Emits the clamp table: alpha (upper bound) in slot 0, beta (lower bound)
// in slot 1, broadcast to a full vector each.
572 template <cpu_isa_t isa>
573 void jit_uni_eltwise_injector_f32<isa>::clamp_prepare_table() {
574 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
575 h->dd(float2int(alpha));
577 for (size_t d = 0; d < vlen / sizeof(float); ++d) {
578 h->dd(float2int(beta));
// Number of auxiliary vector registers each algorithm needs; used by
// injector_preamble to decide how many registers to claim/spill.  relu with
// alpha == 0 takes the max(x,0) fast path and needs none.
582 template <cpu_isa_t isa>
583 int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t elt_alg) {
585 case alg_kind::eltwise_relu: return (alpha == 0.f) ? 0 : 2;
586 case alg_kind::eltwise_elu: return 4;
587 case alg_kind::eltwise_tanh: return 4;
588 case alg_kind::eltwise_square: return 0;
589 case alg_kind::eltwise_abs: return 0;
590 case alg_kind::eltwise_sqrt: return 2;
591 case alg_kind::eltwise_linear: return 1;
592 case alg_kind::eltwise_bounded_relu: return 0;
593 case alg_kind::eltwise_soft_relu: return 4;
594 case alg_kind::eltwise_logistic: return 4;
595 case alg_kind::eltwise_clamp: return 0;
596 default: assert(!"unsupported eltwise algorithm");
// Loads the constant-table pointer and dispatches each register in
// [start_idx, end_idx) to the compute routine of the configured algorithm.
602 template <cpu_isa_t isa>
603 void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
605 h->mov(p_table, l_table);
607 for (size_t idx = start_idx; idx < end_idx; idx++) {
609 case alg_kind::eltwise_relu:
// alpha == 0 uses the cheaper max(x, 0) variant with no aux registers
611 relu_zero_ns_compute_vector(Vmm(idx));
613 relu_compute_vector(Vmm(idx));
615 case alg_kind::eltwise_elu:
616 elu_compute_vector(Vmm(idx)); break;
617 case alg_kind::eltwise_tanh:
618 tanh_compute_vector(Vmm(idx)); break;
619 case alg_kind::eltwise_square:
620 square_compute_vector(Vmm(idx)); break;
621 case alg_kind::eltwise_abs:
622 abs_compute_vector(Vmm(idx)); break;
623 case alg_kind::eltwise_sqrt:
624 sqrt_compute_vector(Vmm(idx)); break;
625 case alg_kind::eltwise_linear:
626 linear_compute_vector(Vmm(idx)); break;
627 case alg_kind::eltwise_bounded_relu:
628 bounded_relu_compute_vector(Vmm(idx)); break;
629 case alg_kind::eltwise_soft_relu:
630 soft_relu_compute_vector(Vmm(idx)); break;
631 case alg_kind::eltwise_logistic:
632 logistic_compute_vector(Vmm(idx)); break;
633 case alg_kind::eltwise_clamp:
634 clamp_compute_vector(Vmm(idx)); break;
635 default: assert(!"unsupported eltwise algorithm");
// Applies the eltwise op to registers [start_idx, end_idx).  The range past
// start_idx_tail is processed first (its temporaries were claimed outside the
// range); injector_preamble_tail then re-homes the borrowed temporaries so
// the head [start_idx, start_idx_tail) can be processed as well.
640 template <cpu_isa_t isa>
641 void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
643 assert(start_idx < vecs_count);
644 assert(end_idx <= vecs_count);
645 assert(start_idx < end_idx);
647 injector_preamble(start_idx, end_idx);
648 compute_body(start_idx_tail, end_idx);
649 injector_preamble_tail(start_idx);
650 compute_body(start_idx, start_idx_tail);
651 injector_postamble();
// Convenience wrapper: apply the eltwise op to a single vector register.
654 template <cpu_isa_t isa>
655 void jit_uni_eltwise_injector_f32<isa>::compute_vector(size_t idx) {
656 compute_vector_range(idx, idx + 1);
// Emits the constant table for the configured algorithm into the code buffer.
// elu/tanh/logistic share elu_prepare_table because they all go through
// exp_compute_vector and use the same slot layout.
659 template <cpu_isa_t isa>
660 void jit_uni_eltwise_injector_f32<isa>::prepare_table() {
665 case alg_kind::eltwise_relu:
666 relu_prepare_table(); break;
667 case alg_kind::eltwise_elu:
668 case alg_kind::eltwise_tanh:
669 case alg_kind::eltwise_logistic:
670 elu_prepare_table(); break;
671 case alg_kind::eltwise_soft_relu:
672 soft_relu_prepare_table(); break;
673 case alg_kind::eltwise_abs:
674 abs_prepare_table(); break;
675 case alg_kind::eltwise_sqrt:
676 sqrt_prepare_table(); break;
677 case alg_kind::eltwise_linear:
678 linear_prepare_table(); break;
679 case alg_kind::eltwise_bounded_relu:
680 bounded_relu_prepare_table(); break;
// square needs no table; clamp reuses the alpha/beta pair
681 case alg_kind::eltwise_square:
683 case alg_kind::eltwise_clamp:
684 clamp_prepare_table(); break;
685 default: assert(!"unsupported eltwise algorithm");
// Explicit instantiations for every supported ISA.
689 template struct jit_uni_eltwise_injector_f32<avx512_common>;
690 template struct jit_uni_eltwise_injector_f32<avx2>;
691 template struct jit_uni_eltwise_injector_f32<sse42>;
696 const float *for_comparison;
// Abstract base for all JIT eltwise kernels: stores the eltwise descriptor
// and the generated-code entry point; operator() invokes the JIT-ed function.
701 struct jit_uni_eltwise_kernel_f32 : public c_compatible {
702 const eltwise_desc_t &desc_;
// set by derived constructors after code generation
704 void (*ker_)(const jit_args *);
705 void operator()(const jit_args *args) { assert(ker_); ker_(args); }
707 jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc)
708 : desc_(desc), ker_(nullptr) {}
709 virtual ~jit_uni_eltwise_kernel_f32() {}
712 bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; }
// Dedicated JIT kernel for (leaky) ReLU, fwd and bwd, for sse42/avx2/avx512.
// The generated code runs a two-phase loop: a vectorized phase stepping a
// full vector per iteration, then a scalar phase for the remainder.  For bwd,
// reg_for_comparison points at the fwd inputs so the mask is computed on the
// original x (diff_src = x > 0 ? diff_dst : alpha * diff_dst).
718 template <cpu_isa_t isa>
719 struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32,
722 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32)
// Emits one unrolled step: load uf elements (vector or scalar), compute
// relu/relu-grad, store.  `shift` is the byte stride per unroll slot.
724 void compute_step(bool vectorize, const int uf, const int shift) {
725 for (int i = 0; i < uf; i++) {
727 uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]);
// bwd additionally loads the fwd input used for the sign test
729 uni_vmovups(Vmm(uf + i + 1),
730 ptr[reg_for_comparison + i * shift]);
// scalar path: single float per slot
732 movss(Xmm(i + 1), ptr[reg_from + i * shift]);
734 movss(Xmm(uf + i + 1),
735 ptr[reg_for_comparison + i * shift]);
// sse42 path: explicit copy + cmpps/blendvps (mask implicitly in xmm0)
740 for (int i = 0; i < uf; i++) {
741 movups(Vmm(2 * uf + i + 1), Vmm(i + 1));
742 mulps(Vmm(2 * uf + i + 1), vmm_ns);
746 movups(mask, Vmm(uf + i + 1));
747 cmpps(mask, vmm_zero, _cmp_nle_us);
749 movups(mask, Vmm(i + 1));
750 cmpps(mask, vmm_zero, _cmp_nle_us);
752 blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1));
// avx2 path: three-operand forms with an explicit mask register
755 for (int i = 0; i < uf; i++) {
756 vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns);
759 vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero);
761 vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero);
763 vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1),
764 Vmm(i + 1), vmm_mask);
// avx512 path: opmask-based compare and blend
768 vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us);
770 vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us);
771 vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1),
// store results
777 for (int i = 0; i < uf; i++) {
779 uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1));
781 movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1));
// Constructor: generates the full kernel and publishes it via ker_.
786 jit_uni_relu_kernel_f32(const eltwise_desc_t &desc)
787 : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
788 assert(desc.alg_kind == alg_kind::eltwise_relu);
789 assert(isa == sse42 || isa == avx2 || isa == avx512_common);
791 Reg64 param = abi_param1;
// phase 0: vectorized (simd_w elems/iter); phase 1: scalar remainder
793 const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
794 const int loop_dec[] = {simd_w, 1};
795 const int uf[] = {1, 1};
796 const int shift[] = {cpu_isa_traits<isa>::vlen, sizeof(float)};
797 const bool loop_vectorize[] = {true, false};
// load kernel arguments from the jit_args struct
801 mov(reg_from, ptr[param + GET_OFF(from)]);
803 mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]);
804 mov(reg_to, ptr[param + GET_OFF(to)]);
805 mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
// broadcast the negative slope (alpha) and materialize zero
807 mov(imm_addr64, float2int(desc.alpha));
808 movq(xmm_ns, imm_addr64);
809 uni_vbroadcastss(vmm_ns, xmm_ns);
811 uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
// emit both loop phases
815 for (int id = 0; id < 2; id++) {
817 cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
818 jle(loop_label[id + 1], T_NEAR);
820 compute_step(loop_vectorize[id], uf[id], shift[id]);
822 add(reg_from, uf[id] * shift[id]);
823 add(reg_to, uf[id] * shift[id]);
825 add(reg_for_comparison, uf[id] * shift[id]);
827 sub(reg_work_amount, uf[id] * loop_dec[id]);
834 ker_ = (decltype(ker_))this->getCode();
838 using Vmm = typename utils::conditional3<isa == sse42, Xmm,
839 isa == avx2, Ymm, Zmm>::type;
// register assignment; fwd aliases reg_for_comparison onto reg_from
841 Reg64 reg_from = rax;
842 Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from;
844 Reg64 reg_work_amount = rsi;
845 Reg64 imm_addr64 = rbx;
847 Xmm xmm_ns = Xmm(14);
// high registers avoid the unrolled data registers Vmm(1..3*uf)
849 Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14);
850 Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15);
852 Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12);
853 Opmask k_mask = Opmask(1);
// Generic forward eltwise kernel: delegates the per-vector math to
// jit_uni_eltwise_injector_f32 and only emits the load/store loop skeleton —
// a vectorized loop over simd_w elements plus a scalar remainder loop.
856 template <cpu_isa_t isa>
857 struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
858 public jit_generator {
859 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32)
861 jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc)
862 : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
// injector owns the eltwise math; save_vecs_state=false since this kernel
// controls all register usage itself; table reg idx 9, opmask idx 1
864 eltwise_injector = new jit_uni_eltwise_injector_f32<isa>(this,
865 desc.alg_kind, desc.alpha, desc.beta, false, 9, 1);
867 using namespace alg_kind;
// relu has its own dedicated kernel and is excluded here
869 assert(is_bwd() == false);
870 assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
871 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
872 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));
// load kernel arguments
876 Reg64 param = abi_param1;
877 mov(reg_from, ptr[param + GET_OFF(from)]);
878 mov(reg_to, ptr[param + GET_OFF(to)]);
879 mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
// skip the vector loop entirely when fewer than simd_w elements remain
881 cmp(reg_work_amount, simd_w)
882 jl("reminder_loop_start", T_NEAR);
884 L("vectorized_loop_start");
886 uni_vmovups(vmm_src, ptr[reg_from]);
887 eltwise_injector->compute_vector(vmm_src.getIdx());
888 uni_vmovups(ptr[reg_to], vmm_src);
893 sub(reg_work_amount, simd_w);
894 cmp(reg_work_amount, simd_w);
895 jge("vectorized_loop_start", T_NEAR);
897 L("vectorized_loop_end");
// scalar remainder: one float per iteration through the same injector
899 L("reminder_loop_start");
901 cmp(reg_work_amount, 0);
902 jle("reminder_loop_end", T_NEAR);
904 movss(xmm_src, ptr[reg_from]);
905 eltwise_injector->compute_vector(xmm_src.getIdx());
906 movss(ptr[reg_to], xmm_src);
908 add(reg_from, sizeof(float));
909 add(reg_to, sizeof(float));
911 dec(reg_work_amount);
912 jmp("reminder_loop_start", T_NEAR);
914 L("reminder_loop_end");
// constant tables must be emitted after the code body
918 eltwise_injector->prepare_table();
920 ker_ = (decltype(ker_))this->getCode();
// NOTE(review): raw owning pointer freed here; unique_ptr would make
// ownership explicit, but the visible interface is kept as-is.
923 ~jit_uni_kernel_fwd_f32() {
924 delete eltwise_injector;
928 using Vmm = typename utils::conditional3<isa == sse42, Xmm,
929 isa == avx2, Ymm, Zmm>::type;
931 const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
932 const int vlen = cpu_isa_traits<isa>::vlen;
934 Reg64 reg_from = rax;
936 Reg64 reg_work_amount = rsi;
937 Reg64 imm_addr64 = rbx;
// Vmm(1)/Xmm(1) alias: the scalar loop reuses the same register index
939 Xmm xmm_src = Xmm(1);
940 Vmm vmm_src = Vmm(1);
942 jit_uni_eltwise_injector_f32<isa>* eltwise_injector;
// Primitive-descriptor check for the forward eltwise primitive: requires the
// target ISA, an f32 dense source, a supported alg_kind for the ISA (avx512
// is limited to relu/elu here), non-empty memory, and default attributes.
947 template <cpu_isa_t isa>
948 status_t jit_uni_eltwise_fwd_t<isa>::pd_t::init() {
949 using namespace alg_kind;
951 assert(engine()->kind() == engine_kind::cpu);
952 bool ok = true && mayiuse(isa)
953 && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
954 prop_kind::forward_inference)
955 && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
956 && !has_zero_dim_memory()
957 && utils::implication(isa > avx2, utils::one_of(desc()->alg_kind,
958 eltwise_relu, eltwise_elu))
959 && utils::implication(isa == sse42 || isa == avx2, utils::one_of(
960 desc()->alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
961 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
962 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic))
963 && memory_desc_wrapper(src_pd()).is_dense()
964 && attr()->has_default_values();
966 return ok ? status::success : status::unimplemented;
// Forward primitive constructor: relu gets the dedicated kernel; every other
// supported algorithm goes through the injector-based generic kernel.
969 template <cpu_isa_t isa>
970 jit_uni_eltwise_fwd_t<isa>::jit_uni_eltwise_fwd_t(const pd_t *pd,
971 const input_vector &inputs, const output_vector &outputs)
972 : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr) {
973 const auto &desc = *conf_.desc();
974 switch (desc.alg_kind) {
975 case alg_kind::eltwise_relu:
976 kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
978 kernel_ = new jit_uni_kernel_fwd_f32<isa>(desc);
// Destructor (body on an elided line; presumably releases kernel_ — confirm).
982 template <cpu_isa_t isa>
983 jit_uni_eltwise_fwd_t<isa>::~jit_uni_eltwise_fwd_t()
// Runs the forward kernel over the flattened tensor, parallelized by
// splitting nelems into cache-line-sized (16-float) chunks across threads.
986 template <cpu_isa_t isa>
987 void jit_uni_eltwise_fwd_t<isa>::execute_forward() {
988 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
989 auto dst = reinterpret_cast<data_t *>(this->memory(0));
991 const memory_desc_wrapper data_d(conf_.src_pd());
993 const size_t nelems = data_d.nelems();
// skip any leading padding in the memory layout
995 src += data_d.blocking_desc().offset_padding;
996 dst += data_d.blocking_desc().offset_padding;
998 parallel(0, [&](const int ithr, const int nthr) {
999 size_t start{0}, end{0};
1001 const int cache_line = 16;
// balance in cache-line units so threads don't false-share output lines
1003 balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1004 start = nstl::min(nelems, start * cache_line);
1005 end = nstl::min(nelems, end * cache_line);
1007 auto arg = jit_args();
1008 arg.from = &src[start];
// fwd: comparison input is the source itself (used by the relu kernel)
1009 arg.for_comparison = &src[start];
1010 arg.to = &dst[start];
1011 arg.work_amount = end - start;
1012 if (arg.work_amount)
// Primitive-descriptor check for the backward eltwise primitive: only relu
// is supported, f32 dense memory, diff_dst layout must match src.
1017 template <cpu_isa_t isa>
1018 status_t jit_uni_eltwise_bwd_t<isa>::pd_t::init() {
1019 assert(engine()->kind() == engine_kind::cpu);
1022 && desc()->prop_kind == prop_kind::backward_data
1023 && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu)
1024 && src_pd()->desc()->data_type == data_type::f32
1025 && !has_zero_dim_memory()
1027 && memory_desc_wrapper(src_pd()).is_dense()
1028 && memory_desc_wrapper(diff_dst_pd()) == memory_desc_wrapper(src_pd())
1029 && attr()->has_default_values();
1031 return ok ? status::success : status::unimplemented;
// Backward primitive constructor: only relu has a backward kernel.
1034 template <cpu_isa_t isa>
1035 jit_uni_eltwise_bwd_t<isa>::jit_uni_eltwise_bwd_t(const pd_t *pd,
1036 const input_vector &inputs, const output_vector &outputs)
1037 : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr) {
1038 const auto &desc = *conf_.desc();
1039 switch (desc.alg_kind) {
1040 case alg_kind::eltwise_relu:
1041 kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1042 default: assert(!"unknown eltwise alg_kind");
// Destructor (body on an elided line; presumably releases kernel_ — confirm).
1046 template <cpu_isa_t isa>
1047 jit_uni_eltwise_bwd_t<isa>::~jit_uni_eltwise_bwd_t()
// Runs the backward relu kernel: diff_src is computed from diff_dst, with the
// fwd input (src) passed as for_comparison so the kernel can test x > 0.
// Same cache-line-chunked parallel split as execute_forward.
1050 template <cpu_isa_t isa>
1051 void jit_uni_eltwise_bwd_t<isa>::execute_backward() {
1052 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
1053 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
1054 auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
1056 const memory_desc_wrapper data_d(conf_.src_pd());
1057 const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
1059 const size_t nelems = data_d.nelems();
// skip any leading padding in the memory layouts
1061 src += data_d.blocking_desc().offset_padding;
1062 diff_dst += diff_data_d.blocking_desc().offset_padding;
1063 diff_src += diff_data_d.blocking_desc().offset_padding;
1065 parallel(0, [&](const int ithr, const int nthr) {
1066 size_t start{0}, end{0};
1068 const int cache_line = 16;
1070 balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1071 start = nstl::min(nelems, start * cache_line);
1072 end = nstl::min(nelems, end * cache_line);
1074 auto arg = jit_args();
1075 arg.from = &diff_dst[start];
1076 arg.to = &diff_src[start];
1077 arg.for_comparison = &src[start];
1078 arg.work_amount = end - start;
1079 if (arg.work_amount)
// Explicit instantiations of the fwd/bwd primitives for every supported ISA.
1084 template struct jit_uni_eltwise_fwd_t<sse42>;
1085 template struct jit_uni_eltwise_bwd_t<sse42>;
1086 template struct jit_uni_eltwise_fwd_t<avx2>;
1087 template struct jit_uni_eltwise_bwd_t<avx2>;
1088 template struct jit_uni_eltwise_fwd_t<avx512_common>;
1089 template struct jit_uni_eltwise_bwd_t<avx512_common>;