Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_eltwise.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <mkldnn_types.h>
18 #include "mkldnn_types.h"
19 #include "mkldnn_thread.hpp"
20 #include "nstl.hpp"
21 #include "utils.hpp"
22 #include "jit_generator.hpp"
23
24 #include "jit_uni_eltwise.hpp"
25
26 #define GET_OFF(field) offsetof(jit_args, field)
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace Xbyak;
33
34 template <cpu_isa_t isa>
35 void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
36         size_t end_idx) {
37     preserved_vecs_count = 0;
38     vecs_to_preserve = (size_t)aux_vecs_count(alg_);
39     start_idx_tail = start_idx;
40
41     // For sse42 mask register has to be Xmm(0)
42     if (isa == sse42 && vecs_to_preserve > 0) {
43         size_t idx = 0;
44         assert(idx < start_idx);
45         preserved_vec_idxs[preserved_vecs_count++] = idx;
46     }
47
48     for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) {
49         if (preserved_vecs_count >= vecs_to_preserve) break;
50         if (start_idx <= idx && idx < end_idx) continue;
51
52         preserved_vec_idxs[preserved_vecs_count++] = idx;
53     }
54
55     size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
56     for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
57         preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++;
58     }
59
60     assert(preserved_vecs_count == vecs_to_preserve);
61
62     if (save_state_) {
63         h->push(p_table);
64
65         if (preserved_vecs_count)
66             h->sub(h->rsp, preserved_vecs_count * vlen);
67
68         for (size_t i = 0; i < preserved_vecs_count; ++i)
69             h->uni_vmovups(h->ptr[h->rsp + i * vlen],
70                     Vmm(preserved_vec_idxs[i]));
71
72         load_table_addr();
73     }
74
75     assign_regs();
76 }
77
78 template <cpu_isa_t isa>
79 void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx)
80 {
81     size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
82     if (tail_vecs_to_preserve == 0) return;
83
84     const int idx_off = vecs_to_preserve - tail_vecs_to_preserve;
85
86     if (save_state_) {
87         if (idx_off)
88             h->add(h->rsp, idx_off * vlen);
89
90         for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
91             h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
92                     h->ptr[h->rsp + i * vlen]);
93     }
94
95     for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
96         preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
97
98     if (save_state_) {
99         for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
100             h->uni_vmovups(h->ptr[h->rsp + i * vlen],
101                     Vmm(preserved_vec_idxs[idx_off + i]));
102
103         if (idx_off)
104             h->sub(h->rsp, idx_off * vlen);
105     }
106
107     assign_regs();
108 }
109
110 template <cpu_isa_t isa>
111 void jit_uni_eltwise_injector_f32<isa>::injector_postamble() {
112     if (!save_state_) return;
113
114     for (size_t i = 0; i < preserved_vecs_count; ++i)
115         h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
116                 h->ptr[h->rsp + i * vlen]);
117
118     if (preserved_vecs_count)
119         h->add(h->rsp, preserved_vecs_count * vlen);
120
121     h->pop(p_table);
122 }
123
124 template <cpu_isa_t isa>
125 void jit_uni_eltwise_injector_f32<isa>::assign_regs() {
126     vmm_mask = Vmm(preserved_vec_idxs[0]);
127     vmm_aux0 = Vmm(preserved_vec_idxs[0]);
128     vmm_aux1 = Vmm(preserved_vec_idxs[1]);
129     vmm_aux2 = Vmm(preserved_vec_idxs[2]);
130     vmm_aux3 = Vmm(preserved_vec_idxs[3]);
131     vmm_aux4 = Vmm(preserved_vec_idxs[4]);
132 }
133
134 template <cpu_isa_t isa>
135 void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
136     h->uni_vminps(vmm_src, vmm_src, table_val(10));
137     h->uni_vmaxps(vmm_src, vmm_src, table_val(11));
138     h->uni_vmovups(vmm_aux0, vmm_src);
139     //calculate exp(x)
140     // fx = x * log2ef + 0.5
141     h->uni_vmulps(vmm_src, vmm_src, table_val(2));
142     h->uni_vaddps(vmm_src, vmm_src, table_val(1));
143
144     // tmp = floorf(fx)
145     if (isa == avx512_common) {
146         h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src);
147         h->vcvtdq2ps(vmm_aux1, vmm_aux1);
148
149         h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us);
150         h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));
151
152         h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3);
153     } else {
154         h->uni_vroundps(vmm_aux1, vmm_src, _op_floor);
155     }
156
157     //keep fx for further computations
158     h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx
159
160     //x = x - fx * ln2
161     h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3));
162
163     // compute 2^n
164     h->uni_vcvtps2dq(vmm_aux1, vmm_src);
165     h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
166     h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx
167
168     // y = p5
169     h->uni_vmovups(vmm_src, table_val(9));
170     // y = y * x + p4
171     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8));
172     // y = y * x + p3
173     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7));
174     // y = y * x + p2
175     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6));
176     // y = y * x + p1
177     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0));
178     // y = y * x + p0
179     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5));  //exp(q)
180     // y = y * 2^n
181     h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
182 }
183
184 template <cpu_isa_t isa>
185 void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(const Vmm &vmm_src)
186 {
187     const int alpha_off = 0, zero_off = 1;
188
189     h->uni_vmovups(vmm_aux1, vmm_src);
190     if (isa == sse42) {
191         h->movups(vmm_mask, vmm_src);
192         h->mulps(vmm_src, table_val(alpha_off));
193         h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us);
194         h->blendvps(vmm_src, vmm_aux1);
195     } else if (isa == avx2) {
196         h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
197         h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off));
198         h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
199     } else if (isa == avx512_common) {
200         h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
201         h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us);
202         h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
203     }
204 }
205
206 template <cpu_isa_t isa>
207 void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector(
208         const Vmm &vmm_src) {
209     const int zero_off = 1;
210     h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off));
211 }
212
213 template <cpu_isa_t isa>
214 void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) {
215     const int alpha_off = 23, zero_off = 24;
216
217     // compute exponent
218     h->uni_vmovups(vmm_aux2, vmm_src);
219     exp_compute_vector(vmm_src);
220
221     // alpha * (exp(x) - 1)
222     h->uni_vsubps(vmm_src, vmm_src, table_val(0));
223     h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off));
224
225     // combine with mask
226     if (isa == sse42) {
227         h->pxor(vmm_mask, vmm_mask);
228         h->cmpps(vmm_mask,  vmm_aux2, _cmp_le_os);
229         h->blendvps(vmm_src, vmm_aux2);
230     } else if (isa == avx2) {
231         h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off));
232         h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask);
233     } else if (isa == avx512_common) {
234         h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us);
235         h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2);
236     }
237 }
238
239 template <cpu_isa_t isa>
240 void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(const Vmm &vmm_src)
241 {
242     // # comes from Taylor expansion error bound
243     //  > linear_sat_point = single(sqrt(3) * 1b-12);
244     // # comes from the exp formula cancellation
245     //  > exp_bound_point = (single(log(3)/2));
246     // # comes from rounding accuracy in float
247     //  > one_sat_point = round(atanh(1 - 1b-25), single, RU);
248     //  > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |],
249     //            [linear_sat_point, exp_bound_point], relative, floating);
250     //  > err_bound = D(sup(supnorm(P, tanh(x),
251     //          [linear_sat_point, exp_bound_point], relative, theta)));
252     //    0x1.fffd6f00b9539p-25
253     //  > P;
254     //    x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 *
255     //        (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5
256     //        + x^0x1p1 * 0x1.09fa1p-6))))
257
258     // register mapping
259     // vmm_src contains input
260     // vmm_aux0 contains mask of currently valid results.
261     //     1 is need computation, 0 is already computed
262     // vmm_aux1 contains current output
263     // vmm_aux2, vmm_aux3 contains auxiliary values
264     // vmm_aux4 contains the original sign of inputs
265
266     Label end_tanh_label;
267
268     auto test_exit =[&](Xbyak::Address threshold){
269         // is not necessary for >AVX, but should not matter on perf
270         h->uni_vmovups(vmm_aux0, vmm_src);
271         if (isa == avx512_common){
272             h->vcmpps(k_mask, vmm_aux0, threshold, 0x5);
273             h->kortestw(k_mask, k_mask);
274         } else {
275             h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold);
276             h->uni_vtestps(vmm_aux0, vmm_aux0);
277         }
278         h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR);
279     };
280
281     auto blend_results=[&](Vmm vmm_partial_res){
282         if (isa == avx512_common)
283             h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res);
284         else
285             h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0);
286     };
287
288     // because tanh(x) = -tanh(-x), we extract sign to make x postive
289     // and reapply sign at the end
290     // mov is not necessary for >AVX, but should not matter for performance
291     h->uni_vmovups(vmm_aux4, vmm_src);
292     h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12));
293     h->uni_vandps(vmm_src, vmm_src, table_val(17));
294
295     // if x < linear_sat_point for all inputs, we just return the input
296     h->uni_vmovups(vmm_aux1, vmm_src);
297     test_exit(table_val(13));
298
299     // if one of the mask is one, we have to compute an better approx
300     h->uni_vmovups(vmm_aux2, vmm_src);
301     h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2);
302     h->uni_vmovups(vmm_aux3, table_val(22));
303     h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21));
304     h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20));
305     h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19));
306     h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18));
307     h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src);
308
309     // we blend only the result that need update
310     blend_results(vmm_aux3);
311
312     // if x < exp_bound_point, we go to return point
313     test_exit(table_val(14));
314
315     // if not we use a better approx 1 - 2 / (1 + exp(2x))
316     // compute 2x
317     h->uni_vmovups(vmm_aux3, vmm_src);
318     h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3);
319
320     // Compute exp(2x)
321     // We need to save kmask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them
322     // vmm_src is not more read afterwards, so we do not have to save it
323     auto stack_size = 3 * vlen + (isa == avx512_common) * 4;
324     h->sub(h->rsp, stack_size);
325     h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0);
326     h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1);
327     h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src);
328     if (isa == avx512_common)
329         h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask);
330
331     exp_compute_vector(vmm_aux3);
332
333     h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]);
334     h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]);
335     h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]);
336     if (isa == avx512_common)
337         h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]);
338     h->add(h->rsp, stack_size);
339
340     // 1 + exp(2x)
341     h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0));
342
343     // 1 - 2 / (1 + exp(2x))
344     h->uni_vmovups(vmm_aux2, table_val(16));
345     h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3);
346     h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0));
347
348     // we blend only the result that need update
349     blend_results(vmm_aux2);
350
351     // finally, we saturate to 1 if needed
352     // TODO: maybe move that up if most inputs saturate in practice
353     if (isa == avx512_common)
354         h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5);
355     else {
356         h->uni_vmovups(vmm_aux0, vmm_src);
357         h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15));
358     }
359     h->uni_vmovups(vmm_aux2, table_val(0));
360     blend_results(vmm_aux2);
361
362     h->L(end_tanh_label);
363     {
364         // we apply the sign of x to the result and we are done
365         h->uni_vmovups(vmm_src, vmm_aux1);
366         h->uni_vpxor(vmm_src, vmm_src, vmm_aux4);
367     }
368 }
369
370 template <cpu_isa_t isa>
371 void jit_uni_eltwise_injector_f32<isa>::square_compute_vector(
372         const Vmm &vmm_src) {
373     h->uni_vmulps(vmm_src, vmm_src, vmm_src);
374 }
375
376 template <cpu_isa_t isa>
377 void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) {
378     // compute abs(x) = _mm_and_ps(x, 01111..111));
379     h->uni_vandps(vmm_src, vmm_src, table_val(0));
380 }
381
382 template <cpu_isa_t isa>
383 void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(const Vmm &vmm_src)
384 {
385     if (isa == avx512_common) {
386         h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us);
387         h->uni_vsqrtps(vmm_aux1, vmm_src);
388         h->uni_vmovups(vmm_src, table_val(0));
389         h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
390     } else {
391         h->uni_vmovups(vmm_mask, vmm_src);
392         h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0));
393         h->uni_vsqrtps(vmm_aux1, vmm_src);
394         h->uni_vmovups(vmm_src, table_val(0));
395         h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
396     }
397 }
398
399 template <cpu_isa_t isa>
400 void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector(
401         const Vmm &vmm_src) {
402     // compute x = alpha * x + beta;
403     h->uni_vmovups(vmm_aux0, table_val(0));
404     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1));
405 }
406
407 template <cpu_isa_t isa>
408 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector(
409         const Vmm &vmm_src) {
410     // compute bounded relu */
411     h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
412     h->uni_vminps(vmm_src, vmm_src, table_val(0));
413 }
414
415 template <cpu_isa_t isa>
416 void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
417         const Vmm &vmm_src) {
418     // duplicate src
419     h->uni_vmovups(vmm_aux2, vmm_src);
420
421     h->uni_vminps(vmm_src, vmm_src, table_val(24));
422     h->uni_vmaxps(vmm_src, vmm_src, table_val(25));
423     h->uni_vmovups(vmm_aux1, vmm_src);
424     // calculate exp(x)
425     // fx = x * log2ef + 0.5
426     h->uni_vmulps(vmm_src, vmm_src, table_val(2));
427     h->uni_vaddps(vmm_src, vmm_src, table_val(1));
428
429     // tmp = floorf(fx)
430     if (isa == avx512_common) {
431         h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src);
432         h->vcvtdq2ps(vmm_aux0, vmm_aux0);
433
434         h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us);
435         h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));
436
437         h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3);
438     } else {
439         h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);
440     }
441
442     // keep fx for further computations
443     h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
444     // calculation fx * ln2
445     h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3));
446     // x = x - fx * ln2
447     h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
448     // y = p5
449     h->uni_vmovups(vmm_aux3, table_val(22));
450     // y = y * x + p4
451     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21));
452     // y = y * x + p3
453     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20));
454     // y = y * x + p2
455     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19));
456     // y = y * x + p1
457     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0));
458     // y = y * x + p0
459     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17));
460
461     // compute 2^(-n)
462     if (isa == avx512_common) {
463         h->vmulps(vmm_aux1, vmm_src, table_val(23));
464         h->vcvtps2dq(vmm_aux1, vmm_aux1);
465     } else {
466         h->uni_vcvtps2dq(vmm_aux1, vmm_src);
467         h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23));
468     }
469
470     h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
471     h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
472     // calculate ln(1 + y)
473     h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
474     // x = y; y is free; keep x for further computations
475     h->uni_vmovups(vmm_src, vmm_aux3);
476     // frexp()
477     h->uni_vpsrld(vmm_src, vmm_src, 23);
478     h->uni_vcvtdq2ps(vmm_src, vmm_src);
479     // got n. where n is x = 2^n * y. y = 0.5 .. 1
480     h->uni_vsubps(vmm_src, vmm_src, table_val(5));
481
482     h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6));
483     // got y. (mantisa)  0.5 < y < 1
484     h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7));
485     // y  = y - 1
486     h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0));
487     // y = p8
488     h->uni_vmovups(vmm_aux1, table_val(16));
489     // y = y * x + p7
490     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15));
491     // y = y * x + p6
492     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14));
493     // y = y * x + p5
494     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13));
495     // y = y * x + p4
496     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12));
497     // y = y * x + p3
498     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11));
499     // y = y * x + p2
500     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10));
501     // y = y * x + p1
502     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9));
503     // y = y * x + p0 ; p0 = 0
504     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8));
505     //calculate ln(2) * n
506     h->uni_vmulps(vmm_src, vmm_src, table_val(3));
507     h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
508     h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);
509
510     // get vmm_mask = src > max logf
511     h->uni_vmovups(vmm_mask, vmm_aux2);
512     if (isa == avx512_common) {
513         // y = (x < max log f) ? soft_relu(x) : x
514         h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us);
515         h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
516     } else {
517         // y = (x < max log f) ? soft_relu(x) : x
518         h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24));
519         h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
520     }
521
522     h->uni_vmovups(vmm_src, vmm_aux1);
523 }
524
525 template <cpu_isa_t isa>
526 void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector(
527         const Vmm &vmm_src) {
528     // we store the original sign and make x negative
529     // IMPORTANT: we assume vmm_aux0 to be xmm0, as for sse4.2 path it is required
530     // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it.
531     h->uni_vmovups(vmm_aux2, vmm_src);
532     h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12));
533     h->uni_vorps(vmm_src, vmm_src, table_val(12));
534
535     exp_compute_vector(vmm_src);
536     // dup exp(x)
537     h->uni_vmovups(vmm_aux1, vmm_src);
538     // (exp(x) + 1)
539     h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0));
540     // y = exp(x) / (exp(x) + 1)
541     h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
542
543     // Now we have to apply the "symmetry" based on original sign
544     h->uni_vmovups(vmm_aux3, table_val(0));
545     h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src);
546     if (isa == avx512_common) {
547         h->vptestmd(k_mask, vmm_aux2, vmm_aux2);
548         h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src);
549     } else {
550         h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2
551         h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0);
552     }
553     h->uni_vmovups(vmm_src, vmm_aux3);
554 }
555
556 template <cpu_isa_t isa>
557 void jit_uni_eltwise_injector_f32<isa>::clamp_compute_vector(
558         const Vmm &vmm_src) {
559     // compute clamp */
560     h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
561     h->uni_vminps(vmm_src, vmm_src, table_val(0));
562 }
563
564 template <cpu_isa_t isa>
565 void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
566     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
567     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
568 }
569
570 template <cpu_isa_t isa>
571 void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() {
572     const unsigned int cvals[] = {
573             0x3f800000, // [0] 1.0f
574             0x3f000000, // [1] 0.5f
575             0x3fb8aa3b, // [2] log2ef = 1.44269502f
576             0x3f317218, // [3] ln2f =   0.69314718f
577             0x0000007f, // [4] 0x7f
578             // exp(x) polynom
579             0x3f800001, // [5] p0 = 1.0000001f
580             0x3efffe85, // [6] p2 = 0.4999887f
581             0x3e2aaa3e, // [7] p3 = 0.16666505f
582             0x3d2bb1b1, // [8] p4 = 0.041917507f
583             0x3c091ec1, // [9] p5 = 0.008369149f
584             0x42b0c0a5, //[10] max logf = 88.3762589f
585             0xc1766666, //[11] min logf = -14.5f
586             // tanh(x) constants,
587             0x80000000, //[12] mask to extract sign
588             0x39ddb3d7, //[13] arg below which tanh(x) = x
589             0x3f0c9f54, //[14] arg below which pol approx is valid
590             0x41102cb4, //[15] arg after which tanh(x) = 1
591             0xc0000000, //[16] -2.0f
592             0x7fffffff, //[17] mask to make positive
593             // tanh pol approx
594             0x3f7fffff, //[18] p0
595             0xbeaaa9cf, //[19] p1
596             0x3e085f1f, //[20] p2
597             0xbd572bda, //[21] p3
598             0x3c84fd08, //[22] p4
599     };
600
601     for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
602         for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]);
603     }
604
605     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
606     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
607 }
608
609 template <cpu_isa_t isa>
610 void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() {
611     const unsigned int cvals[] = {
612             0x3f800000, // [0] 1.0f
613             0x3f000000, // [1] 0.5f
614             0x3fb8aa3b, // [2] log2ef = 1.44269502f
615             0x3f317218, // [3] ln2f =   0.69314718f
616             0x0000007f, // [4] 0x7f
617             0x42fc0000, // [5] 126
618             0x807fffff, // [6] and with (to get 0.5 * mantissa)
619             0x3f000000, // [7] or with (to get 0.5 * mantissa)
620             // ln(1 + x) polynomial
621             0xb2b4637d, // [8]  p0 = 0.0000000244f
622             0x3f7fff8e, // [9]  p1 = 0.9999976971f
623             0xbf001759, //[10]  p2 = -0.5002478215f
624             0x3ea70608, //[11]  p3 = 0.3272714505f
625             0xbea3d7bf, //[12]  p4 = -0.3153830071f
626             0xbe361d04, //[13]  p5 = -0.1701777461f
627             0xbfa8f1e6, //[14]  p6 = -1.3254635147f
628             0xbfe1e812, //[15]  p7 = -1.7971917960f
629             0xbfc4d30e, //[16]  p8 = -1.5652673123f
630             // exp(x) polynomial
631             0x3f800001, //[17]  p0 = 1.0000001f
632             0x3f800000, //[18]  p1 = 1.0f
633             0x3efffe85, //[19]  p2 = 0.4999887f
634             0x3e2aaa3e, //[20]  p3 = 0.16666505f
635             0x3d2bb1b1, //[21]  p4 = 0.041917507f
636             0x3c091ec1, //[22]  p5 = 0.008369149f
637             0xbf800000, //[23] is required for sign changing
638             0x42b0c0a5, //[24] max logf = 88.3762589f
639             0xc1766666  //[25] min logf = -14.5f
640     };
641
642     for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
643         for (size_t d = 0; d < vlen / sizeof(float); ++d) {
644             h->dd(cvals[i]);
645         }
646     }
647 }
648
649 template <cpu_isa_t isa>
650 void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() {
651     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff);
652 }
653
654 template <cpu_isa_t isa>
655 void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() {
656     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
657 }
658
659 template <cpu_isa_t isa>
660 void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() {
661     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
662     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
663 }
664
665 template <cpu_isa_t isa>
666 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() {
667     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
668     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
669 }
670
671 template <cpu_isa_t isa>
672 void jit_uni_eltwise_injector_f32<isa>::clamp_prepare_table() {
673     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
674     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
675 }
676
677 template <cpu_isa_t isa>
678 int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
679     switch (alg_) {
680     case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2;
681     case alg_kind::eltwise_elu: return 4;
682     case alg_kind::eltwise_tanh: return 5;
683     case alg_kind::eltwise_square: return 0;
684     case alg_kind::eltwise_abs: return 0;
685     case alg_kind::eltwise_sqrt: return 2;
686     case alg_kind::eltwise_linear: return 1;
687     case alg_kind::eltwise_bounded_relu: return 0;
688     case alg_kind::eltwise_soft_relu: return 4;
689     case alg_kind::eltwise_logistic: return 4;
690     case alg_kind::eltwise_clamp: return 0;
691     case alg_kind::eltwise_exp: return 4;
692     default: assert(!"unsupported eltwise algorithm");
693     }
694
695     return 0;
696 }
697
698 template <cpu_isa_t isa>
699 void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
700         size_t end_idx) {
701     using namespace alg_kind;
702     for (size_t idx = start_idx; idx < end_idx; idx++) {
703         switch (alg_) {
704         case eltwise_relu:
705             if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx));
706             else relu_compute_vector(Vmm(idx));
707             break;
708         case eltwise_elu: elu_compute_vector(Vmm(idx)); break;
709         case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break;
710         case eltwise_square: square_compute_vector(Vmm(idx)); break;
711         case eltwise_abs: abs_compute_vector(Vmm(idx)); break;
712         case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break;
713         case eltwise_linear: linear_compute_vector(Vmm(idx)); break;
714         case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break;
715         case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break;
716         case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break;
717         case eltwise_clamp: clamp_compute_vector(Vmm(idx)); break;
718         case eltwise_exp: exp_compute_vector(Vmm(idx)); break;
719         default: assert(!"unsupported eltwise algorithm");
720         }
721     }
722 }
723
724 template <cpu_isa_t isa>
725 void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
726         size_t end_idx) {
727     assert(start_idx < end_idx && end_idx <= vecs_count);
728
729     injector_preamble(start_idx, end_idx);
730     compute_body(start_idx_tail, end_idx);
731     injector_preamble_tail(start_idx);
732     compute_body(start_idx, start_idx_tail);
733     injector_postamble();
734 }
735
736 template <cpu_isa_t isa>
737 void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) {
738     using namespace alg_kind;
739
740     h->align(64);
741     h->L(l_table);
742
743     if (gen_table) {
744         switch (alg_) {
745         case eltwise_relu: relu_prepare_table(); break;
746         case eltwise_elu:
747         case eltwise_tanh:
748         case eltwise_logistic:
749         case eltwise_exp:
750             elu_prepare_table(); break;
751         case eltwise_soft_relu: soft_relu_prepare_table(); break;
752         case eltwise_abs: abs_prepare_table(); break;
753         case eltwise_sqrt: sqrt_prepare_table(); break;
754         case eltwise_linear: linear_prepare_table(); break;
755         case eltwise_bounded_relu: bounded_relu_prepare_table(); break;
756         case eltwise_square: break;
757         case eltwise_clamp: clamp_prepare_table(); break;
758         default: assert(!"unsupported eltwise algorithm");
759     }
760     }
761 }
762
763 template struct jit_uni_eltwise_injector_f32<avx512_common>;
764 template struct jit_uni_eltwise_injector_f32<avx2>;
765 template struct jit_uni_eltwise_injector_f32<sse42>;
766
767
/**
 * Argument block passed by pointer to a generated kernel (see ker_ in
 * jit_uni_eltwise_kernel_f32). The field order is part of the contract:
 * GET_OFF() resolves fields through offsetof(jit_args, field).
 */
struct jit_args {
    // NOTE(review): presumably the source data pointer -- confirm in the kernel
    const float *from;
    // NOTE(review): presumably only read by backward kernels -- confirm
    const float *for_comparison;
    // NOTE(review): presumably the destination data pointer -- confirm
    const float *to;
    // NOTE(review): presumably the number of elements to process -- confirm
    size_t work_amount;
};
774
775 struct jit_uni_eltwise_kernel_f32 : public c_compatible {
776     const eltwise_desc_t &desc_;
777
778     void (*ker_)(const jit_args *);
779     void operator()(const jit_args *args) { assert(ker_); ker_(args); }
780
781     jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc)
782         : desc_(desc), ker_(nullptr) {}
783     virtual ~jit_uni_eltwise_kernel_f32() {}
784
785 protected:
786     bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; }
787 };
788
789 /* jit kernels */
790 namespace {
791
792 template <cpu_isa_t isa>
793 struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32,
794     public jit_generator
795 {
796     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32)
797
798     void compute_step(bool vectorize, const int uf, const int shift) {
799         for (int i = 0; i < uf; i++) {
800             if (vectorize) {
801                 uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]);
802                 if (is_bwd())
803                     uni_vmovups(Vmm(uf + i + 1),
804                                 ptr[reg_for_comparison + i * shift]);
805             } else {
806                 movss(Xmm(i + 1), ptr[reg_from + i * shift]);
807                 if (is_bwd())
808                     movss(Xmm(uf + i + 1),
809                           ptr[reg_for_comparison + i * shift]);
810             }
811         }
812
813         if (isa == sse42) {
814             for (int i = 0; i < uf; i++) {
815                 movups(Vmm(2 * uf + i + 1), Vmm(i + 1));
816                 mulps(Vmm(2 * uf + i + 1), vmm_ns);
817
818                 Vmm mask = Vmm(0);
819                 if (is_bwd()) {
820                     movups(mask, Vmm(uf + i + 1));
821                     cmpps(mask, vmm_zero, _cmp_nle_us);
822                 } else {
823                     movups(mask, Vmm(i + 1));
824                     cmpps(mask, vmm_zero, _cmp_nle_us);
825                 }
826                 blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1));
827             }
828         } else {
829             for (int i = 0; i < uf; i++) {
830                 vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns);
831                 if (isa == avx2) {
832                     if (is_bwd())
833                         vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero);
834                     else
835                         vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero);
836
837                     vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1),
838                               Vmm(i + 1), vmm_mask);
839
840                 } else {
841                     if (is_bwd())
842                         vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us);
843                     else
844                         vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us);
845                     vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1),
846                               Vmm(i + 1));
847                 }
848             }
849         }
850
851         for (int i = 0; i < uf; i++) {
852             if (vectorize) {
853                 uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1));
854             } else {
855                 movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1));
856             }
857         }
858     }
859
860     jit_uni_relu_kernel_f32(const eltwise_desc_t &desc)
861         : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
862         assert(desc.alg_kind == alg_kind::eltwise_relu);
863         assert(isa == sse42 || isa == avx2 || isa == avx512_common);
864
865         Reg64 param = abi_param1;
866
867         const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
868         const int loop_dec[] = {simd_w, 1};
869         const int uf[] = {1, 1};
870         const int shift[] = {cpu_isa_traits<isa>::vlen, sizeof(float)};
871         const bool loop_vectorize[] = {true, false};
872
873         this->preamble();
874
875         mov(reg_from, ptr[param + GET_OFF(from)]);
876         if (is_bwd())
877             mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]);
878         mov(reg_to, ptr[param + GET_OFF(to)]);
879         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
880
881         mov(imm_addr64, float2int(desc.alpha));
882         movq(xmm_ns, imm_addr64);
883         uni_vbroadcastss(vmm_ns, xmm_ns);
884
885         uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
886
887         Label loop_label[3];
888
889         for (int id = 0; id < 2; id++) {
890             L(loop_label[id]);
891             cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
892             jle(loop_label[id + 1], T_NEAR);
893
894             compute_step(loop_vectorize[id], uf[id], shift[id]);
895
896             add(reg_from, uf[id] * shift[id]);
897             add(reg_to, uf[id] * shift[id]);
898             if (is_bwd())
899                 add(reg_for_comparison, uf[id] * shift[id]);
900
901             sub(reg_work_amount, uf[id] * loop_dec[id]);
902             jmp(loop_label[id]);
903         }
904
905         L(loop_label[2]);
906         this->postamble();
907
908         ker_ = (decltype(ker_))this->getCode();
909     }
910
911 private:
912     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
913                                              isa == avx2, Ymm, Zmm>::type;
914
915     Reg64 reg_from = rax;
916     Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from;
917     Reg64 reg_to = r8;
918     Reg64 reg_work_amount = rsi;
919     Reg64 imm_addr64 = rbx;
920
921     Xmm xmm_ns = Xmm(14);
922
923     Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14);
924     Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15);
925
926     Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12);
927     Opmask k_mask = Opmask(1);
928 };
929
930 template <cpu_isa_t isa>
931 struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
932     public jit_generator {
933     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32)
934
935     jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc)
936         : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
937
938         eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
939                 desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1));
940
941         using namespace alg_kind;
942
943         assert(is_bwd() == false);
944         assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
945                     eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
946                     eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
947                     eltwise_clamp, eltwise_exp));
948
949         preamble();
950
951         Reg64 param = abi_param1;
952         mov(reg_from, ptr[param + GET_OFF(from)]);
953         mov(reg_to, ptr[param + GET_OFF(to)]);
954         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
955         eltwise_injector_->load_table_addr();
956
957         Label reminder_loop_start, reminder_loop_end;
958         Label vectorized_loop_start, vectorized_loop_end;
959
960         cmp(reg_work_amount, simd_w);
961         jl(reminder_loop_start, T_NEAR);
962
963         L(vectorized_loop_start);
964
965         uni_vmovups(vmm_src, ptr[reg_from]);
966         eltwise_injector_->compute_vector(vmm_src.getIdx());
967         uni_vmovups(ptr[reg_to], vmm_src);
968
969         add(reg_from, vlen);
970         add(reg_to, vlen);
971
972         sub(reg_work_amount, simd_w);
973         cmp(reg_work_amount, simd_w);
974         jge(vectorized_loop_start, T_NEAR);
975
976         L(vectorized_loop_end);
977
978         L(reminder_loop_start);
979
980         cmp(reg_work_amount, 0);
981         jle(reminder_loop_end, T_NEAR);
982
983         movss(xmm_src, ptr[reg_from]);
984         eltwise_injector_->compute_vector(xmm_src.getIdx());
985         movss(ptr[reg_to], xmm_src);
986
987         add(reg_from, sizeof(float));
988         add(reg_to, sizeof(float));
989
990         dec(reg_work_amount);
991         jmp(reminder_loop_start, T_NEAR);
992
993         L(reminder_loop_end);
994
995         postamble();
996
997         eltwise_injector_->prepare_table();
998
999         ker_ = (decltype(ker_))this->getCode();
1000     }
1001
1002     ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; }
1003
1004 private:
1005     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
1006                 isa == avx2, Ymm, Zmm>::type;
1007
1008     const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
1009     const int vlen   = cpu_isa_traits<isa>::vlen;
1010
1011     Reg64 reg_from = rax;
1012     Reg64 reg_to = r8;
1013     Reg64 reg_work_amount = rsi;
1014     Reg64 imm_addr64 = rbx;
1015
1016     Xmm xmm_src = Xmm(1);
1017     Vmm vmm_src = Vmm(1);
1018
1019     jit_uni_eltwise_injector_f32<isa> *eltwise_injector_;
1020 };
1021
1022 } /* namespace */
1023
1024 template <cpu_isa_t isa>
1025 status_t jit_uni_eltwise_fwd_t<isa>::pd_t::init() {
1026     using namespace alg_kind;
1027
1028     assert(engine()->kind() == engine_kind::cpu);
1029     bool ok = true && mayiuse(isa)
1030         && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
1031                 prop_kind::forward_inference)
1032         && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
1033         && !has_zero_dim_memory()
1034         && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh,
1035                 eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
1036                 eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
1037                 eltwise_logistic, eltwise_clamp, eltwise_exp)
1038         && memory_desc_wrapper(src_pd()).is_dense(true)
1039         && IMPLICATION(!memory_desc_wrapper(src_pd()).is_dense(false),
1040                 math::eltwise_fwd_preserves_zero(desc()->alg_kind, true))
1041         && attr()->has_default_values();
1042
1043     return ok ? status::success : status::unimplemented;
1044 }
1045
1046 template <cpu_isa_t isa>
1047 jit_uni_eltwise_fwd_t<isa>::jit_uni_eltwise_fwd_t(const pd_t *apd,
1048         const input_vector &inputs, const output_vector &outputs)
1049     : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
1050     const auto &desc = *pd()->desc();
1051     switch (desc.alg_kind) {
1052     case alg_kind::eltwise_relu:
1053         kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1054     default:
1055         kernel_ = new jit_uni_kernel_fwd_f32<isa>(desc);
1056     }
1057 }
1058
1059 template <cpu_isa_t isa>
1060 jit_uni_eltwise_fwd_t<isa>::~jit_uni_eltwise_fwd_t()
1061 { delete kernel_; }
1062
1063 template <cpu_isa_t isa>
1064 void jit_uni_eltwise_fwd_t<isa>::execute_forward() const {
1065     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
1066     auto dst = reinterpret_cast<data_t *>(this->memory(0));
1067
1068     const memory_desc_wrapper data_d(pd()->src_pd());
1069
1070     const size_t nelems = data_d.nelems(true);
1071
1072     src += data_d.blocking_desc().offset_padding;
1073     dst += data_d.blocking_desc().offset_padding;
1074
1075     parallel(0, [&](const int ithr, const int nthr) {
1076         size_t start{0}, end{0};
1077
1078         const int cache_line = 16;
1079
1080         balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1081         start = nstl::min(nelems, start * cache_line);
1082         end = nstl::min(nelems, end * cache_line);
1083
1084         auto arg = jit_args();
1085         arg.from = &src[start];
1086         arg.for_comparison = &src[start];
1087         arg.to = &dst[start];
1088         arg.work_amount = end - start;
1089         if (arg.work_amount)
1090             (*kernel_)(&arg);
1091     });
1092 }
1093
1094 template <cpu_isa_t isa>
1095 status_t jit_uni_eltwise_bwd_t<isa>::pd_t::init() {
1096     assert(engine()->kind() == engine_kind::cpu);
1097
1098     bool ok = true
1099         && desc()->prop_kind == prop_kind::backward_data
1100         && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu)
1101         && src_pd()->desc()->data_type == data_type::f32
1102         && !has_zero_dim_memory()
1103         && mayiuse(isa)
1104         && memory_desc_wrapper(src_pd()).is_dense()
1105         && memory_desc_wrapper(diff_dst_pd()) == memory_desc_wrapper(src_pd())
1106         && attr()->has_default_values();
1107
1108     return ok ? status::success : status::unimplemented;
1109 }
1110
1111 template <cpu_isa_t isa>
1112 jit_uni_eltwise_bwd_t<isa>::jit_uni_eltwise_bwd_t(const pd_t *apd,
1113         const input_vector &inputs, const output_vector &outputs)
1114     : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
1115     const auto &desc = *pd()->desc();
1116     switch (desc.alg_kind) {
1117     case alg_kind::eltwise_relu:
1118         kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1119     default: assert(!"unknown eltwise alg_kind");
1120     }
1121 }
1122
1123 template <cpu_isa_t isa>
1124 jit_uni_eltwise_bwd_t<isa>::~jit_uni_eltwise_bwd_t()
1125 { delete kernel_; }
1126
1127 template <cpu_isa_t isa>
1128 void jit_uni_eltwise_bwd_t<isa>::execute_backward() const {
1129     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
1130     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
1131     auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
1132
1133     const memory_desc_wrapper data_d(pd()->src_pd());
1134     const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
1135
1136     const size_t nelems = data_d.nelems();
1137
1138     src += data_d.blocking_desc().offset_padding;
1139     diff_dst += diff_data_d.blocking_desc().offset_padding;
1140     diff_src += diff_data_d.blocking_desc().offset_padding;
1141
1142     parallel(0, [&](const int ithr, const int nthr) {
1143         size_t start{0}, end{0};
1144
1145         const int cache_line = 16;
1146
1147         balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1148         start = nstl::min(nelems, start * cache_line);
1149         end = nstl::min(nelems, end * cache_line);
1150
1151         auto arg = jit_args();
1152         arg.from = &diff_dst[start];
1153         arg.to = &diff_src[start];
1154         arg.for_comparison = &src[start];
1155         arg.work_amount = end - start;
1156         if (arg.work_amount)
1157             (*kernel_)(&arg);
1158     });
1159 }
1160
1161 template struct jit_uni_eltwise_fwd_t<sse42>;
1162 template struct jit_uni_eltwise_bwd_t<sse42>;
1163 template struct jit_uni_eltwise_fwd_t<avx2>;
1164 template struct jit_uni_eltwise_bwd_t<avx2>;
1165 template struct jit_uni_eltwise_fwd_t<avx512_common>;
1166 template struct jit_uni_eltwise_bwd_t<avx512_common>;
1167
1168 }
1169 }
1170 }