[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_eltwise.cpp
1 /*******************************************************************************
2 * Copyright 2017-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <mkldnn_types.h>
18 #include "mkldnn_types.h"
19 #include "mkldnn_thread.hpp"
20 #include "nstl.hpp"
21 #include "utils.hpp"
22 #include "jit_generator.hpp"
23
24 #include "jit_uni_eltwise.hpp"
25 #include "jit_avx512_core_bf16cvt.hpp"
26
27 #define GET_OFF(field) offsetof(jit_args, field)
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace Xbyak;
34
35 template <cpu_isa_t isa>
36 void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
37         size_t end_idx) {
38     preserved_vecs_count = 0;
39     vecs_to_preserve = (size_t)aux_vecs_count(alg_);
40     start_idx_tail = start_idx;
41
42     // For sse42 the mask register has to be Xmm(0)
43     if (isa == sse42 && vecs_to_preserve > 0) {
44         size_t idx = 0;
45         assert(idx < start_idx);
46         preserved_vec_idxs[preserved_vecs_count++] = idx;
47     }
48
49     for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) {
50         if (preserved_vecs_count >= vecs_to_preserve) break;
51         if (start_idx <= idx && idx < end_idx) continue;
52
53         preserved_vec_idxs[preserved_vecs_count++] = idx;
54     }
55
56     size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
57     for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
58         preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++;
59     }
60
61     assert(preserved_vecs_count == vecs_to_preserve);
62
63     if (save_state_) {
64         h->push(p_table);
65
66         if (preserved_vecs_count)
67             h->sub(h->rsp, preserved_vecs_count * vlen);
68
69         for (size_t i = 0; i < preserved_vecs_count; ++i)
70             h->uni_vmovups(h->ptr[h->rsp + i * vlen],
71                     Vmm(preserved_vec_idxs[i]));
72
73         load_table_addr();
74     }
75
76     assign_regs();
77 }
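// Worked example of the bookkeeping above (illustrative only): for
// eltwise_tanh (5 aux vecs) injected into avx2 code that keeps idx 11..15 busy
// (start_idx = 11, end_idx = 16), the loop picks idx 0..4 from outside the busy
// range, so start_idx_tail stays equal to start_idx and the tail path below is
// a no-op. Only when too few registers are free outside [start_idx, end_idx)
// do the remaining aux registers land in [start_idx, start_idx_tail) and get
// swapped in by injector_preamble_tail().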
78
79 template <cpu_isa_t isa>
80 void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx)
81 {
82     size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
83     if (tail_vecs_to_preserve == 0) return;
84
85     const int idx_off = vecs_to_preserve - tail_vecs_to_preserve;
86
87     if (save_state_) {
88         if (idx_off)
89             h->add(h->rsp, idx_off * vlen);
90
91         for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
92             h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
93                     h->ptr[h->rsp + i * vlen]);
94     }
95
96     for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
97         preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
98
99     if (save_state_) {
100         for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
101             h->uni_vmovups(h->ptr[h->rsp + i * vlen],
102                     Vmm(preserved_vec_idxs[idx_off + i]));
103
104         if (idx_off)
105             h->sub(h->rsp, idx_off * vlen);
106     }
107
108     assign_regs();
109 }
110
111 template <cpu_isa_t isa>
112 void jit_uni_eltwise_injector_f32<isa>::injector_postamble() {
113     if (!save_state_) return;
114
115     for (size_t i = 0; i < preserved_vecs_count; ++i)
116         h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
117                 h->ptr[h->rsp + i * vlen]);
118
119     if (preserved_vecs_count)
120         h->add(h->rsp, preserved_vecs_count * vlen);
121
122     h->pop(p_table);
123 }
124
125 template <cpu_isa_t isa>
126 void jit_uni_eltwise_injector_f32<isa>::assign_regs() {
127     vmm_mask = Vmm(preserved_vec_idxs[0]);
128     vmm_aux0 = Vmm(preserved_vec_idxs[0]);
129     vmm_aux1 = Vmm(preserved_vec_idxs[1]);
130     vmm_aux2 = Vmm(preserved_vec_idxs[2]);
131     vmm_aux3 = Vmm(preserved_vec_idxs[3]);
132     vmm_aux4 = Vmm(preserved_vec_idxs[4]);
133 }
134
135 template <cpu_isa_t isa>
136 void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
137     h->uni_vminps(vmm_src, vmm_src, table_val(10));
138     h->uni_vmaxps(vmm_src, vmm_src, table_val(11));
139     h->uni_vmovups(vmm_aux0, vmm_src);
140     //calculate exp(x)
141     // fx = x * log2ef + 0.5
142     h->uni_vmulps(vmm_src, vmm_src, table_val(2));
143     h->uni_vaddps(vmm_src, vmm_src, table_val(1));
144
145     // tmp = floorf(fx)
146     h->uni_vroundps(vmm_aux1, vmm_src, _op_floor);
147
148     //keep fx for further computations
149     h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx
150
151     //x = x - fx * ln2
152     h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3));
153
154     // compute 2^n
155     h->uni_vcvtps2dq(vmm_aux1, vmm_src);
156     h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
157     h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^n
158
159     // y = p5
160     h->uni_vmovups(vmm_src, table_val(9));
161     // y = y * x + p4
162     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8));
163     // y = y * x + p3
164     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7));
165     // y = y * x + p2
166     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6));
167     // y = y * x + p1
168     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0));
169     // y = y * x + p0
170     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5));  //exp(q)
171     // y = y * 2^n
172     h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
173 }
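// Scalar reading of the sequence above (an illustrative sketch only; the JIT
// code and elu_prepare_table() are authoritative). ref_exp is a hypothetical
// helper name; bracketed indices refer to the table entries used via table_val():
//
//   float ref_exp(float x) {
//       x = std::min(x, 88.3762589f);                   // [10] max logf
//       x = std::max(x, -14.5f);                        // [11] min logf
//       float fx = std::floor(x * 1.44269502f + 0.5f);  // [2] log2ef, [1] 0.5
//       float r = x - fx * 0.69314718f;                 // [3] ln2f
//       float p = 0.008369149f;                         // [9] p5
//       p = p * r + 0.041917507f;                       // [8] p4
//       p = p * r + 0.16666505f;                        // [7] p3
//       p = p * r + 0.4999887f;                         // [6] p2
//       p = p * r + 1.0f;                               // [0] p1
//       p = p * r + 1.0000001f;                         // [5] p0
//       return std::ldexp(p, (int)fx);                  // scale by 2^n ([4] bias)
//   }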
174
175 template <cpu_isa_t isa>
176 void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(const Vmm &vmm_src)
177 {
178     const int alpha_off = 0, zero_off = 1;
179
180     h->uni_vmovups(vmm_aux1, vmm_src);
181     if (isa == sse42) {
182         h->movups(vmm_mask, vmm_src);
183         h->mulps(vmm_src, table_val(alpha_off));
184         h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us);
185         h->blendvps(vmm_src, vmm_aux1);
186     } else if (isa == avx2) {
187         h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
188         h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off));
189         h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
190     } else if (isa == avx512_common) {
191         h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
192         h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us);
193         h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
194     }
195 }
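// Scalar equivalent of the branchless sequence above (illustrative):
//     y = x > 0 ? x : alpha * x
// alpha sits at table offset alpha_off and the zero constant at zero_off; the
// comparison against zero drives the blend between the scaled and original value.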
196
197 template <cpu_isa_t isa>
198 void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector(
199         const Vmm &vmm_src) {
200     const int zero_off = 1;
201     h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off));
202 }
203
204 template <cpu_isa_t isa>
205 void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) {
206     const int alpha_off = 23, zero_off = 24;
207
208     // compute exponent
209     h->uni_vmovups(vmm_aux2, vmm_src);
210     exp_compute_vector(vmm_src);
211
212     // alpha * (exp(x) - 1)
213     h->uni_vsubps(vmm_src, vmm_src, table_val(0));
214     h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off));
215
216     // combine with mask
217     if (isa == sse42) {
218         h->pxor(vmm_mask, vmm_mask);
219         h->cmpps(vmm_mask,  vmm_aux2, _cmp_le_os);
220         h->blendvps(vmm_src, vmm_aux2);
221     } else if (isa == avx2) {
222         h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off));
223         h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask);
224     } else if (isa == avx512_common) {
225         h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us);
226         h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2);
227     }
228 }
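// Scalar equivalent (illustrative): y = x > 0 ? x : alpha * (exp(x) - 1), with
// exp(x) coming from exp_compute_vector() above and the original input kept in
// vmm_aux2 for the final mask-based blend.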
229
230 template <cpu_isa_t isa>
231 void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(const Vmm &vmm_src)
232 {
233     // # comes from Taylor expansion error bound
234     //  > linear_sat_point = single(sqrt(3) * 1b-12);
235     // # comes from the exp formula cancellation
236     //  > exp_bound_point = (single(log(3)/2));
237     // # comes from rounding accuracy in float
238     //  > one_sat_point = round(atanh(1 - 1b-25), single, RU);
239     //  > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |],
240     //            [linear_sat_point, exp_bound_point], relative, floating);
241     //  > err_bound = D(sup(supnorm(P, tanh(x),
242     //          [linear_sat_point, exp_bound_point], relative, theta)));
243     //    0x1.fffd6f00b9539p-25
244     //  > P;
245     //    x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 *
246     //        (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5
247     //        + x^0x1p1 * 0x1.09fa1p-6))))
248
249     // register mapping
250     // vmm_src contains input
251     // vmm_aux0 contains the mask of lanes that still need computation:
252     //     1 means needs computation, 0 means already computed
253     // vmm_aux1 contains current output
254     // vmm_aux2, vmm_aux3 contains auxiliary values
255     // vmm_aux4 contains the original sign of inputs
256
257     Label end_tanh_label;
258
259     auto test_exit =[&](Xbyak::Address threshold){
260         // the copy is not necessary for AVX and above, but should not matter on perf
261         h->uni_vmovups(vmm_aux0, vmm_src);
262         if (isa == avx512_common){
263             h->vcmpps(k_mask, vmm_aux0, threshold, 0x5);
264             h->kortestw(k_mask, k_mask);
265         } else {
266             h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold);
267             h->uni_vtestps(vmm_aux0, vmm_aux0);
268         }
269         h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR);
270     };
271
272     auto blend_results=[&](Vmm vmm_partial_res){
273         if (isa == avx512_common)
274             h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res);
275         else
276             h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0);
277     };
278
279     // because tanh(x) = -tanh(-x), we extract the sign to make x positive
280     // and reapply sign at the end
281     // mov is not necessary for >AVX, but should not matter for performance
282     h->uni_vmovups(vmm_aux4, vmm_src);
283     h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12));
284     h->uni_vandps(vmm_src, vmm_src, table_val(17));
285
286     // if x < linear_sat_point for all inputs, we just return the input
287     h->uni_vmovups(vmm_aux1, vmm_src);
288     test_exit(table_val(13));
289
290     // if any mask lane is set, we have to compute a better approximation
291     h->uni_vmovups(vmm_aux2, vmm_src);
292     h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2);
293     h->uni_vmovups(vmm_aux3, table_val(22));
294     h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21));
295     h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20));
296     h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19));
297     h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18));
298     h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src);
299
300     // we blend only the results that need an update
301     blend_results(vmm_aux3);
302
303     // if x < exp_bound_point, we go to return point
304     test_exit(table_val(14));
305
306     // if not, we use a better approximation: 1 - 2 / (1 + exp(2x))
307     // compute 2x
308     h->uni_vmovups(vmm_aux3, vmm_src);
309     h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3);
310
311     // Compute exp(2x)
312     // We need to save k_mask, vmm_aux0, vmm_aux1 and vmm_src, as exp can use them;
313     // they are restored right after the call since they are still read below
314     auto stack_size = 3 * vlen + (isa == avx512_common) * 4;
315     h->sub(h->rsp, stack_size);
316     h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0);
317     h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1);
318     h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src);
319     if (isa == avx512_common)
320         h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask);
321
322     exp_compute_vector(vmm_aux3);
323
324     h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]);
325     h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]);
326     h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]);
327     if (isa == avx512_common)
328         h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]);
329     h->add(h->rsp, stack_size);
330
331     // 1 + exp(2x)
332     h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0));
333
334     // 1 - 2 / (1 + exp(2x))
335     h->uni_vmovups(vmm_aux2, table_val(16));
336     h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3);
337     h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0));
338
339     // we blend only the results that need an update
340     blend_results(vmm_aux2);
341
342     // finally, we saturate to 1 if needed
343     // TODO: maybe move that up if most inputs saturate in practice
344     if (isa == avx512_common)
345         h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5);
346     else {
347         h->uni_vmovups(vmm_aux0, vmm_src);
348         h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15));
349     }
350     h->uni_vmovups(vmm_aux2, table_val(0));
351     blend_results(vmm_aux2);
352
353     h->L(end_tanh_label);
354     {
355         // we apply the sign of x to the result and we are done
356         h->uni_vmovups(vmm_src, vmm_aux1);
357         h->uni_vpxor(vmm_src, vmm_src, vmm_aux4);
358     }
359 }
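// Piecewise scalar reading of the algorithm above (illustrative sketch; ref_exp
// and ref_tanh are hypothetical helper names, P() is the odd polynomial with
// coefficients [18]..[22], and the thresholds are table entries [13], [14], [15]):
//
//   float ref_tanh(float x) {
//       float ax = std::fabs(x), y;
//       if (ax < linear_sat_point)      y = ax;              // tanh(x) ~= x
//       else if (ax < exp_bound_point)  y = ax * P(ax * ax); // polynomial branch
//       else if (ax < one_sat_point)    y = 1.f - 2.f / (1.f + ref_exp(2.f * ax));
//       else                            y = 1.f;             // saturated
//       return std::copysign(y, x);                          // tanh(-x) = -tanh(x)
//   }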
360
361 template <cpu_isa_t isa>
362 void jit_uni_eltwise_injector_f32<isa>::square_compute_vector(
363         const Vmm &vmm_src) {
364     h->uni_vmulps(vmm_src, vmm_src, vmm_src);
365 }
366
367 template <cpu_isa_t isa>
368 void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) {
369     // compute abs(x) = _mm_and_ps(x, 0111...111)
370     h->uni_vandps(vmm_src, vmm_src, table_val(0));
371 }
372
373 template <cpu_isa_t isa>
374 void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(const Vmm &vmm_src)
375 {
376     if (isa == avx512_common) {
377         h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us);
378         h->uni_vsqrtps(vmm_aux1, vmm_src);
379         h->uni_vmovups(vmm_src, table_val(0));
380         h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
381     } else {
382         h->uni_vmovups(vmm_mask, vmm_src);
383         h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0));
384         h->uni_vsqrtps(vmm_aux1, vmm_src);
385         h->uni_vmovups(vmm_src, table_val(0));
386         h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
387     }
388 }
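// Scalar equivalent (illustrative): y = x > 0 ? sqrt(x) : 0. The mask keeps
// non-positive inputs at the zero constant, so negative inputs produce 0
// instead of NaN.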
389
390 template <cpu_isa_t isa>
391 void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector(
392         const Vmm &vmm_src) {
393     // compute x = alpha * x + beta;
394     h->uni_vmovups(vmm_aux0, table_val(0));
395     h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1));
396 }
397
398 template <cpu_isa_t isa>
399 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector(
400         const Vmm &vmm_src) {
401     // compute bounded relu
402     h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
403     h->uni_vminps(vmm_src, vmm_src, table_val(0));
404 }
405
406 template <cpu_isa_t isa>
407 void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
408         const Vmm &vmm_src) {
409     // duplicate src
410     h->uni_vmovups(vmm_aux2, vmm_src);
411
412     h->uni_vminps(vmm_src, vmm_src, table_val(24));
413     h->uni_vmaxps(vmm_src, vmm_src, table_val(25));
414     h->uni_vmovups(vmm_aux1, vmm_src);
415     // calculate exp(x)
416     // fx = x * log2ef + 0.5
417     h->uni_vmulps(vmm_src, vmm_src, table_val(2));
418     h->uni_vaddps(vmm_src, vmm_src, table_val(1));
419
420     // tmp = floorf(fx)
421     h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);
422
423     // keep fx for further computations
424     h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
425     // calculate fx * ln2
426     h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3));
427     // x = x - fx * ln2
428     h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
429     // y = p5
430     h->uni_vmovups(vmm_aux3, table_val(22));
431     // y = y * x + p4
432     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21));
433     // y = y * x + p3
434     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20));
435     // y = y * x + p2
436     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19));
437     // y = y * x + p1
438     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0));
439     // y = y * x + p0
440     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17));
441
442     // compute 2^(-n)
443     if (isa == avx512_common) {
444         h->vmulps(vmm_aux1, vmm_src, table_val(23));
445         h->vcvtps2dq(vmm_aux1, vmm_aux1);
446     } else {
447         h->uni_vcvtps2dq(vmm_aux1, vmm_src);
448         h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23));
449     }
450
451     h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
452     h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
453     // calculate ln(1 + y)
454     h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
455     // x = y; y is free; keep x for further computations
456     h->uni_vmovups(vmm_src, vmm_aux3);
457     // frexp()
458     h->uni_vpsrld(vmm_src, vmm_src, 23);
459     h->uni_vcvtdq2ps(vmm_src, vmm_src);
460     // got n, where x = 2^n * y and y = 0.5 .. 1
461     h->uni_vsubps(vmm_src, vmm_src, table_val(5));
462
463     h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6));
464     // got y (mantissa), 0.5 <= y < 1
465     h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7));
466     // y  = y - 1
467     h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0));
468     // y = p8
469     h->uni_vmovups(vmm_aux1, table_val(16));
470     // y = y * x + p7
471     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15));
472     // y = y * x + p6
473     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14));
474     // y = y * x + p5
475     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13));
476     // y = y * x + p4
477     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12));
478     // y = y * x + p3
479     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11));
480     // y = y * x + p2
481     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10));
482     // y = y * x + p1
483     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9));
484     // y = y * x + p0 ; p0 = 0
485     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8));
486     //calculate ln(2) * n
487     h->uni_vmulps(vmm_src, vmm_src, table_val(3));
488     h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
489     h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);
490
491     // get vmm_mask = src > max logf
492     h->uni_vmovups(vmm_mask, vmm_aux2);
493     if (isa == avx512_common) {
494         // y = (x < max log f) ? soft_relu(x) : x
495         h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us);
496         h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
497     } else {
498         // y = (x < max log f) ? soft_relu(x) : x
499         h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24));
500         h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
501     }
502
503     h->uni_vmovups(vmm_src, vmm_aux1);
504 }
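// Scalar reading of the above (illustrative): for x <= max logf ([24]),
//     soft_relu(x) = ln(1 + e^x) = n*ln2 + ln(e^r + 2^-n), with x = n*ln2 + r,
// where e^r uses the exp polynomial ([17]..[22]) and the remaining logarithm is
// evaluated via the frexp-style exponent extraction plus the ln(1 + t)
// polynomial ([8]..[16]) on the mantissa minus one. Inputs above max logf fall
// through to soft_relu(x) ~= x in the final blend. In reference-semantics form
// (hypothetical helpers, not the exact evaluation order):
//
//   float ref_soft_relu(float x) {
//       return x > 88.3762589f ? x : std::log1p(ref_exp(x));
//   }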
505
506 template <cpu_isa_t isa>
507 void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector(
508         const Vmm &vmm_src) {
509     // we store the original sign and make x negative
510     // IMPORTANT: we assume vmm_aux0 to be xmm0, as the sse4.2 path requires it
511     // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it.
512     h->uni_vmovups(vmm_aux2, vmm_src);
513     h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12));
514     h->uni_vorps(vmm_src, vmm_src, table_val(12));
515
516     exp_compute_vector(vmm_src);
517     // dup exp(x)
518     h->uni_vmovups(vmm_aux1, vmm_src);
519     // (exp(x) + 1)
520     h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0));
521     // y = exp(x) / (exp(x) + 1)
522     h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
523
524     // Now we have to apply the "symmetry" based on original sign
525     h->uni_vmovups(vmm_aux3, table_val(0));
526     h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src);
527     if (isa == avx512_common) {
528         h->vptestmd(k_mask, vmm_aux2, vmm_aux2);
529         h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src);
530     } else {
531         h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2
532         h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0);
533     }
534     h->uni_vmovups(vmm_src, vmm_aux3);
535 }
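// Scalar equivalent (illustrative sketch; ref_exp is a hypothetical helper):
//
//   float ref_logistic(float x) {
//       float e = ref_exp(-std::fabs(x));  // exp is only fed non-positive args
//       float y = e / (1.f + e);           // logistic(-|x|)
//       return x < 0.f ? y : 1.f - y;      // mirror for positive inputs
//   }
//
// Forcing a non-positive argument avoids overflow in exp() for large x.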
536
537 template <cpu_isa_t isa>
538 void jit_uni_eltwise_injector_f32<isa>::clamp_compute_vector(
539         const Vmm &vmm_src) {
540     // compute clamp
541     h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
542     h->uni_vminps(vmm_src, vmm_src, table_val(0));
543 }
544
545 template <cpu_isa_t isa>
546 void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
547     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
548     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
549 }
550
551 template <cpu_isa_t isa>
552 void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() {
553     const unsigned int cvals[] = {
554             0x3f800000, // [0] 1.0f
555             0x3f000000, // [1] 0.5f
556             0x3fb8aa3b, // [2] log2ef = 1.44269502f
557             0x3f317218, // [3] ln2f =   0.69314718f
558             0x0000007f, // [4] 0x7f
559             // exp(x) polynom
560             0x3f800001, // [5] p0 = 1.0000001f
561             0x3efffe85, // [6] p2 = 0.4999887f
562             0x3e2aaa3e, // [7] p3 = 0.16666505f
563             0x3d2bb1b1, // [8] p4 = 0.041917507f
564             0x3c091ec1, // [9] p5 = 0.008369149f
565             0x42b0c0a5, //[10] max logf = 88.3762589f
566             0xc1766666, //[11] min logf = -14.5f
567             // tanh(x) constants,
568             0x80000000, //[12] mask to extract sign
569             0x39ddb3d7, //[13] arg below which tanh(x) = x
570             0x3f0c9f54, //[14] arg below which pol approx is valid
571             0x41102cb4, //[15] arg after which tanh(x) = 1
572             0xc0000000, //[16] -2.0f
573             0x7fffffff, //[17] mask to make positive
574             // tanh pol approx
575             0x3f7fffff, //[18] p0
576             0xbeaaa9cf, //[19] p1
577             0x3e085f1f, //[20] p2
578             0xbd572bda, //[21] p3
579             0x3c84fd08, //[22] p4
580     };
581
582     for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
583         for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]);
584     }
585
586     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
587     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
588 }
589
590 template <cpu_isa_t isa>
591 void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() {
592     const unsigned int cvals[] = {
593             0x3f800000, // [0] 1.0f
594             0x3f000000, // [1] 0.5f
595             0x3fb8aa3b, // [2] log2ef = 1.44269502f
596             0x3f317218, // [3] ln2f =   0.69314718f
597             0x0000007f, // [4] 0x7f
598             0x42fc0000, // [5] 126
599             0x807fffff, // [6] and with (to get 0.5 * mantissa)
600             0x3f000000, // [7] or with (to get 0.5 * mantissa)
601             // ln(1 + x) polynomial
602             0xb2b4637d, // [8]  p0 = 0.0000000244f
603             0x3f7fff8e, // [9]  p1 = 0.9999976971f
604             0xbf001759, //[10]  p2 = -0.5002478215f
605             0x3ea70608, //[11]  p3 = 0.3272714505f
606             0xbea3d7bf, //[12]  p4 = -0.3153830071f
607             0xbe361d04, //[13]  p5 = -0.1701777461f
608             0xbfa8f1e6, //[14]  p6 = -1.3254635147f
609             0xbfe1e812, //[15]  p7 = -1.7971917960f
610             0xbfc4d30e, //[16]  p8 = -1.5652673123f
611             // exp(x) polynomial
612             0x3f800001, //[17]  p0 = 1.0000001f
613             0x3f800000, //[18]  p1 = 1.0f
614             0x3efffe85, //[19]  p2 = 0.4999887f
615             0x3e2aaa3e, //[20]  p3 = 0.16666505f
616             0x3d2bb1b1, //[21]  p4 = 0.041917507f
617             0x3c091ec1, //[22]  p5 = 0.008369149f
618             0xbf800000, //[23] is required for sign changing
619             0x42b0c0a5, //[24] max logf = 88.3762589f
620             0xc1766666  //[25] min logf = -14.5f
621     };
622
623     for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
624         for (size_t d = 0; d < vlen / sizeof(float); ++d) {
625             h->dd(cvals[i]);
626         }
627     }
628 }
629
630 template <cpu_isa_t isa>
631 void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() {
632     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff);
633 }
634
635 template <cpu_isa_t isa>
636 void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() {
637     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
638 }
639
640 template <cpu_isa_t isa>
641 void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() {
642     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
643     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
644 }
645
646 template <cpu_isa_t isa>
647 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() {
648     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
649     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
650 }
651
652 template <cpu_isa_t isa>
653 void jit_uni_eltwise_injector_f32<isa>::clamp_prepare_table() {
654     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
655     for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
656 }
657
658 template <cpu_isa_t isa>
659 int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
660     switch (alg_) {
661     case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2;
662     case alg_kind::eltwise_elu: return 4;
663     case alg_kind::eltwise_tanh: return 5;
664     case alg_kind::eltwise_square: return 0;
665     case alg_kind::eltwise_abs: return 0;
666     case alg_kind::eltwise_sqrt: return 2;
667     case alg_kind::eltwise_linear: return 1;
668     case alg_kind::eltwise_bounded_relu: return 0;
669     case alg_kind::eltwise_soft_relu: return 4;
670     case alg_kind::eltwise_logistic: return 4;
671     case alg_kind::eltwise_clamp: return 0;
672     case alg_kind::eltwise_exp: return 4;
673     default: assert(!"unsupported eltwise algorithm");
674     }
675
676     return 0;
677 }
678
679 template <cpu_isa_t isa>
680 void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
681         size_t end_idx) {
682     using namespace alg_kind;
683     for (size_t idx = start_idx; idx < end_idx; idx++) {
684         switch (alg_) {
685         case eltwise_relu:
686             if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx));
687             else relu_compute_vector(Vmm(idx));
688             break;
689         case eltwise_elu: elu_compute_vector(Vmm(idx)); break;
690         case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break;
691         case eltwise_square: square_compute_vector(Vmm(idx)); break;
692         case eltwise_abs: abs_compute_vector(Vmm(idx)); break;
693         case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break;
694         case eltwise_linear: linear_compute_vector(Vmm(idx)); break;
695         case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break;
696         case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break;
697         case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break;
698         case eltwise_clamp: clamp_compute_vector(Vmm(idx)); break;
699         case eltwise_exp: exp_compute_vector(Vmm(idx)); break;
700         default: assert(!"unsupported eltwise algorithm");
701         }
702     }
703 }
704
705 template <cpu_isa_t isa>
706 void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
707         size_t end_idx) {
708     assert(start_idx < end_idx && end_idx <= vecs_count);
709
710     injector_preamble(start_idx, end_idx);
711     compute_body(start_idx_tail, end_idx);
712     injector_preamble_tail(start_idx);
713     compute_body(start_idx, start_idx_tail);
714     injector_postamble();
715 }
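// Typical usage from a consumer kernel (a condensed sketch of what
// jit_uni_kernel_fwd_f32 further below does; see that class for the full context):
//
//   eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
//           desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1));
//   ...                                            // inside the generated code:
//   eltwise_injector_->load_table_addr();          // needed since save_state is false
//   eltwise_injector_->compute_vector(vmm_src.getIdx());
//   ...
//   eltwise_injector_->prepare_table();            // after postamble()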
716
717 template <cpu_isa_t isa>
718 void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) {
719     using namespace alg_kind;
720
721     h->align(64);
722     h->L(l_table);
723
724     if (gen_table) {
725         switch (alg_) {
726         case eltwise_relu: relu_prepare_table(); break;
727         case eltwise_elu:
728         case eltwise_tanh:
729         case eltwise_logistic:
730         case eltwise_exp:
731             elu_prepare_table(); break;
732         case eltwise_soft_relu: soft_relu_prepare_table(); break;
733         case eltwise_abs: abs_prepare_table(); break;
734         case eltwise_sqrt: sqrt_prepare_table(); break;
735         case eltwise_linear: linear_prepare_table(); break;
736         case eltwise_bounded_relu: bounded_relu_prepare_table(); break;
737         case eltwise_square: break;
738         case eltwise_clamp: clamp_prepare_table(); break;
739         default: assert(!"unsupported eltwise algorithm");
740     }
741     }
742 }
743
744 template struct jit_uni_eltwise_injector_f32<avx512_common>;
745 template struct jit_uni_eltwise_injector_f32<avx2>;
746 template struct jit_uni_eltwise_injector_f32<sse42>;
747
748
749 struct jit_args {
750     const void *from;
751     const void *for_comparison;
752     const void *to;
753     size_t work_amount;
754 };
755
756 struct jit_uni_eltwise_kernel_f32 : public c_compatible {
757     const eltwise_desc_t &desc_;
758
759     void (*ker_)(const jit_args *);
760     void operator()(const jit_args *args) { assert(ker_); ker_(args); }
761
762     jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc)
763         : desc_(desc), ker_(nullptr) {}
764     virtual ~jit_uni_eltwise_kernel_f32() {}
765
766 protected:
767     bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; }
768 };
769
770 /* jit kernels */
771 namespace {
772
773 template <cpu_isa_t isa>
774 struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32,
775     public jit_generator
776 {
777     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32)
778
779     void compute_step(bool vectorize, const int uf, const int shift) {
780         for (int i = 0; i < uf; i++) {
781             auto addr_fwd = ptr[reg_from + i * shift];
782             auto addr_bwd = ptr[reg_for_comparison + i * shift];
783             if (vectorize) {
784                 if (is_bf16_) {
785                     vmovups(Ymm_src(i + 1), addr_fwd);
786                     vpermw(Vmm(i + 1) | k_mask_cvt | T_z, zmm_idx, Zmm_src(i + 1));
787                 } else {
788                     uni_vmovups(Vmm(i + 1), addr_fwd);
789                 }
790                 if (is_bwd()) {
791                     if (is_bf16_) {
792                         vmovups(Ymm_src(uf + i + 1), addr_bwd);
793                         vpermw(Vmm(uf + i + 1) | k_mask_cvt | T_z,
794                                 zmm_idx, Zmm_src(uf + i + 1));
795                     } else {
796                         uni_vmovups(Vmm(uf + i + 1), addr_bwd);
797                     }
798                 }
799             } else {
800                 if (is_bf16_) {
801                     vmovdqu16(Ymm_src(i + 1) | k_tail_mask, addr_fwd);
802                     vpermw(Vmm(i + 1) | k_mask_cvt | T_z, zmm_idx,
803                             Zmm_src(i + 1));
804                 } else {
805                     movss(Xmm(i + 1), addr_fwd);
806                 }
807                 if (is_bwd()) {
808                     if (is_bf16_) {
809                         vmovdqu16(Ymm_src(uf + i + 1) | k_tail_mask, addr_bwd);
810                         vpermw(Vmm(uf + i + 1) | k_mask_cvt | T_z, zmm_idx,
811                                 Zmm_src(uf + i + 1));
812                     } else {
813                         movss(Xmm(uf + i + 1), addr_bwd);
814                     }
815                 }
816             }
817         }
818
819         if (isa == sse42) {
820             for (int i = 0; i < uf; i++) {
821                 movups(Vmm(2 * uf + i + 1), Vmm(i + 1));
822                 mulps(Vmm(2 * uf + i + 1), vmm_ns);
823
824                 Vmm mask = Vmm(0);
825                 if (is_bwd()) {
826                     movups(mask, Vmm(uf + i + 1));
827                     cmpps(mask, vmm_zero, _cmp_nle_us);
828                 } else {
829                     movups(mask, Vmm(i + 1));
830                     cmpps(mask, vmm_zero, _cmp_nle_us);
831                 }
832                 blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1));
833             }
834         } else {
835             for (int i = 0; i < uf; i++) {
836                 vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns);
837                 if (isa == avx2) {
838                     if (is_bwd())
839                         vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero);
840                     else
841                         vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero);
842
843                     vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1),
844                               Vmm(i + 1), vmm_mask);
845
846                 } else {
847                     if (is_bwd())
848                         vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us);
849                     else
850                         vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us);
851                     vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1),
852                               Vmm(i + 1));
853                 }
854             }
855         }
856         auto store_data =[&] (opmask_t _kmask, int i) {
857             if (!is_cpx_)
858                 bf16_emu_->r_vcvtneps2bf16(Ymm_src(2 * uf + i + 1),
859                     Zmm(2 * uf + i + 1));
860             else
861                 vcvtneps2bf16(Ymm_src(2 * uf + i + 1), Vmm(2 * uf + i + 1));
862             vmovdqu16(ptr[reg_to + i * shift] | _kmask, Ymm_src(2 * uf + i + 1));
863         };
864
865         for (int i = 0; i < uf; i++) {
866             if (vectorize)
867                 if(is_bf16_)
868                     store_data(k_full_mask, i);
869                 else
870                     uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1));
871             else
872                 if (is_bf16_)
873                     store_data(k_tail_mask, i);
874                 else
875                     movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1));
876         }
877     }
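    // bf16 handling sketch (illustrative): loads read bf16 words into a Ymm,
    // then vpermw with zmm_idx (word pairs 0,0,1,1,...,15,15), the 0xAAAAAAAA
    // mask and zeroing places each bf16 word into the high half of a 32-bit
    // lane with zeros below, i.e. the corresponding f32 value with a truncated
    // mantissa. Stores convert back with vcvtneps2bf16, either natively
    // (is_cpx_) or through bf16_emu_.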
878
879     ~jit_uni_relu_kernel_f32() { delete bf16_emu_; }
880
881     jit_uni_relu_kernel_f32(const eltwise_desc_t &desc)
882         : jit_uni_eltwise_kernel_f32(desc)
883         , jit_generator()
884         , bf16_emu_(nullptr) {
885         assert(desc.alg_kind == alg_kind::eltwise_relu);
886         assert(isa == sse42 || isa == avx2 || isa == avx512_common);
887
888         Reg64 param = abi_param1;
889
890         is_cpx_ = mayiuse(avx512_core_bf16);
891         is_bf16_ = (desc.data_desc.data_type == data_type::bf16);
892         if (!is_cpx_ && is_bf16_)
893             bf16_emu_ = new bf16_emulation_t(this,
894                     bf16_emu_reserv_1, bf16_emu_reserv_2,
895                     bf16_emu_reserv_3, bf16_emu_reserv_4,
896                     bf16_emu_reserv_5, bf16_emu_reserv_6);
897
898         const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
899         const int loop_dec[] = {simd_w, 1};
900         const int uf[] = {1, 1};
901
902         int _shift = (is_bf16_) ? sizeof(mkldnn_bfloat16_t) : sizeof(float);
903         int _vlen = (is_bf16_)
904             ? cpu_isa_traits<isa>::vlen / 2
905             : cpu_isa_traits<isa>::vlen;
906
907         const int shift[] = {_vlen, _shift};
908         const bool loop_vectorize[] = {true, false};
909
910         preamble();
911
912         if (is_bf16_) {
913             mov(mask_reg, 0xAAAAAAAA);
914             kmovd(k_mask_cvt, mask_reg);
915
916             mov(mask_reg, 0x1);
917             kmovd(k_tail_mask, mask_reg);
918
919             mov(mask_reg, 0xffff);
920             kmovd(k_full_mask, mask_reg);
921         }
922         if (!is_cpx_ && is_bf16_)
923             bf16_emu_->init_vcvtneps2bf16();
924
925         mov(reg_from, ptr[param + GET_OFF(from)]);
926         if (is_bwd())
927             mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]);
928         mov(reg_to, ptr[param + GET_OFF(to)]);
929         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
930
931         if (is_bf16_) {
932             mov(p_idx_table, idx_table);
933             vmovups(zmm_idx, ptr[p_idx_table]);
934         }
935
936         mov(imm_addr64, float2int(desc.alpha));
937         movq(xmm_ns, imm_addr64);
938         uni_vbroadcastss(vmm_ns, xmm_ns);
939
940         uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
941
942         Label loop_label[3];
943
944         for (int id = 0; id < 2; id++) {
945             L(loop_label[id]);
946             cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
947             jle(loop_label[id + 1], T_NEAR);
948
949             compute_step(loop_vectorize[id], uf[id], shift[id]);
950
951             add(reg_from, uf[id] * shift[id]);
952             add(reg_to, uf[id] * shift[id]);
953             if (is_bwd())
954                 add(reg_for_comparison, uf[id] * shift[id]);
955
956             sub(reg_work_amount, uf[id] * loop_dec[id]);
957             jmp(loop_label[id]);
958         }
959
960         L(loop_label[2]);
961         postamble();
962
963         if (is_bf16_) {
964             align(64);
965             L(idx_table);
966             const uint16_t _idx[] = { 0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,
967                                       9,9,10,10,11,11,12,12,13,13,14,14,15,15 };
968             for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
969                 dw(_idx[i]);
970         }
971
972         ker_ = (decltype(ker_))this->getCode();
973     }
974
975 private:
976     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
977                                              isa == avx2, Ymm, Zmm>::type;
978     using opmask_t = const Xbyak::Opmask;
979
980     Reg64 reg_from = rax;
981     Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from;
982     Reg64 reg_to = r8;
983     Reg64 reg_work_amount = rsi;
984     Reg64 imm_addr64 = rbx;
985
986     Reg32 mask_reg = r14d;
987     Reg32 reg32_tmp = mask_reg;
988     Reg64 reg_idx = r15;
989     Reg64 p_idx_table = r13;
990
991     Xmm xmm_ns = Xmm(14);
992
993     Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14);
994     Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15);
995
996     Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12);
997     Opmask k_mask = Opmask(1);
998
999     inline Ymm Ymm_src(int i) {
1000         return Ymm(15 + i);
1001     }
1002     inline Zmm Zmm_src(int i) {
1003         return Zmm(15 + i);
1004     }
1005     Zmm zmm_idx = Zmm(29);
1006
1007     Zmm bf16_emu_reserv_1 = Zmm(24);
1008     Zmm bf16_emu_reserv_2 = Zmm(25);
1009     Zmm bf16_emu_reserv_3 = Zmm(26);
1010     Reg64 bf16_emu_reserv_4 = r14;
1011     Zmm bf16_emu_reserv_5 = Zmm(27);
1012     Zmm bf16_emu_reserv_6 = Zmm(27);
1013
1014     opmask_t k_mask_cvt = k7;
1015     opmask_t k_tail_mask = k6;
1016     opmask_t k_full_mask = k5;
1017
1018     Label idx_table;
1019
1020     bool is_bf16_;
1021     bool is_cpx_;
1022
1023     bf16_emulation_t *bf16_emu_;
1024 };
1025
1026 template <cpu_isa_t isa>
1027 struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
1028     public jit_generator {
1029     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32)
1030
1031     jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc)
1032         : jit_uni_eltwise_kernel_f32(desc)
1033         , jit_generator()
1034         , bf16_emu_(nullptr) {
1035
1036         is_cpx_ = mayiuse(avx512_core_bf16);
1037         bool is_bf16_ = (desc.data_desc.data_type == data_type::bf16);
1038
1039         if (!is_cpx_ && is_bf16_)
1040             bf16_emu_ = new bf16_emulation_t(this,
1041                     bf16_emu_reserv_1, bf16_emu_reserv_2,
1042                     bf16_emu_reserv_3, bf16_emu_reserv_4,
1043                     bf16_emu_reserv_5, bf16_emu_reserv_6);
1044
1045         eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
1046                 desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1));
1047
1048         using namespace alg_kind;
1049
1050         assert(is_bwd() == false);
1051         assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
1052                     eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
1053                     eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
1054                     eltwise_clamp, eltwise_exp));
1055
1056         preamble();
1057
1058         if (is_bf16_) {
1059             mov(mask_reg, 0xAAAAAAAA);
1060             kmovd(k_mask, mask_reg);
1061
1062             mov(mask_reg, 0x1);
1063             kmovd(k_tail_mask, mask_reg);
1064
1065             mov(mask_reg, 0xffff);
1066             kmovd(k_full_mask, mask_reg);
1067         }
1068         if (!is_cpx_ && is_bf16_)
1069             bf16_emu_->init_vcvtneps2bf16();
1070
1071         Reg64 param = abi_param1;
1072         mov(reg_from, ptr[param + GET_OFF(from)]);
1073         mov(reg_to, ptr[param + GET_OFF(to)]);
1074         if (is_bf16_) {
1075             mov(p_idx_table, idx_table);
1076             vmovups(zmm_idx, ptr[p_idx_table]);
1077         }
1078         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
1079
1080         eltwise_injector_->load_table_addr();
1081
1082         Label reminder_loop_start, reminder_loop_end;
1083         Label vectorized_loop_start, vectorized_loop_end;
1084
1085         cmp(reg_work_amount, simd_w);
1086         jl(reminder_loop_start, T_NEAR);
1087
1088         L(vectorized_loop_start);
1089
1090         auto store_data =[&] (opmask_t _kmask) {
1091             if (!is_cpx_)
1092                 bf16_emu_->r_vcvtneps2bf16(ymm_src, zmm_src_1);
1093             else
1094                 vcvtneps2bf16(ymm_src, vmm_src);
1095             vmovdqu16(ptr[reg_to] | _kmask, ymm_src);
1096         };
1097
1098         if (is_bf16_) {
1099             vmovups(ymm_src, ptr[reg_from]);
1100             vpermw(vmm_src | k_mask  | T_z, zmm_idx, zmm_src);
1101             eltwise_injector_->compute_vector(vmm_src.getIdx());
1102             store_data(k_full_mask);
1103         } else {
1104             uni_vmovups(vmm_src, ptr[reg_from]);
1105             eltwise_injector_->compute_vector(vmm_src.getIdx());
1106             uni_vmovups(ptr[reg_to], vmm_src);
1107         }
1108         auto shift = (is_bf16_) ? vlen / 2 : vlen;
1109         add(reg_from, shift);
1110         add(reg_to, shift);
1111
1112         sub(reg_work_amount, simd_w);
1113         cmp(reg_work_amount, simd_w);
1114         jge(vectorized_loop_start, T_NEAR);
1115
1116         L(vectorized_loop_end);
1117
1118         L(reminder_loop_start);
1119
1120         cmp(reg_work_amount, 0);
1121         jle(reminder_loop_end, T_NEAR);
1122         if (is_bf16_) {
1123             vmovups(ymm_src | k_tail_mask, ptr[reg_from]);
1124             vpermw(vmm_src | k_mask | T_z, zmm_idx, zmm_src);
1125             eltwise_injector_->compute_vector(vmm_src.getIdx());
1126             store_data(k_tail_mask);
1127         } else {
1128             movss(xmm_src, ptr[reg_from]);
1129             eltwise_injector_->compute_vector(xmm_src.getIdx());
1130             movss(ptr[reg_to], xmm_src);
1131         }
1132         auto size_step = (is_bf16_) ? sizeof(mkldnn_bfloat16_t) : sizeof(float);
1133         add(reg_from, size_step);
1134         add(reg_to, size_step);
1135
1136         dec(reg_work_amount);
1137         jmp(reminder_loop_start, T_NEAR);
1138
1139         L(reminder_loop_end);
1140
1141         postamble();
1142
1143         eltwise_injector_->prepare_table();
1144
1145         if (is_bf16_) {
1146             align(64);
1147             L(idx_table);
1148             const uint16_t _idx[] = { 0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,
1149                                       9,9,10,10,11,11,12,12,13,13,14,14,15,15 };
1150             for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
1151                 dw(_idx[i]);
1152         }
1153
1154         ker_ = (decltype(ker_))this->getCode();
1155     }
1156
1157     ~jit_uni_kernel_fwd_f32() {
1158         delete eltwise_injector_;
1159         delete bf16_emu_;
1160     }
1161
1162 private:
1163     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
1164                 isa == avx2, Ymm, Zmm>::type;
1165     using opmask_t = const Xbyak::Opmask;
1166
1167     const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
1168     const int vlen   = cpu_isa_traits<isa>::vlen;
1169
1170     Reg64 reg_from = rax;
1171     Reg64 reg_to = r8;
1172     Reg64 reg_work_amount = rsi;
1173     Reg64 imm_addr64 = rbx;
1174     Reg32 mask_reg = edx;
1175     Reg64 reg_idx = r15;
1176     Reg32 reg32_tmp = r14d;
1177     Reg64 p_idx_table = r13;
1178
1179     Xmm xmm_src = Xmm(1);
1180     Vmm vmm_src = Vmm(1);
1181     Zmm zmm_src_1 = Zmm(1);
1182
1183     Ymm ymm_src = Ymm(30);
1184     Zmm zmm_src = Zmm(30);
1185     Zmm zmm_idx = Zmm(31);
1186
1187     Zmm bf16_emu_reserv_1 = Zmm(26);
1188     Zmm bf16_emu_reserv_2 = Zmm(27);
1189     Zmm bf16_emu_reserv_3 = Zmm(28);
1190     Reg64 bf16_emu_reserv_4 = r14;
1191     Zmm bf16_emu_reserv_5 = Zmm(29);
1192     Zmm bf16_emu_reserv_6 = Zmm(29);
1193
1194     bool is_cpx_;
1195
1196     opmask_t k_mask = k7;
1197     opmask_t k_tail_mask = k6;
1198     opmask_t k_full_mask = k5;
1199
1200     Label idx_table;
1201
1202     jit_uni_eltwise_injector_f32<isa> *eltwise_injector_;
1203     bf16_emulation_t *bf16_emu_;
1204 };
1205
1206 } /* namespace */
1207
1208 template <cpu_isa_t isa, data_type_t d_type>
1209 status_t jit_uni_eltwise_fwd_t<isa, d_type>::pd_t::init() {
1210     using namespace alg_kind;
1211
1212     assert(engine()->kind() == engine_kind::cpu);
1213     bool ok = true && mayiuse(isa)
1214         && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
1215                 prop_kind::forward_inference)
1216         && desc()->data_desc.data_type == d_type
1217         && !has_zero_dim_memory()
1218         && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh,
1219                 eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
1220                 eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
1221                 eltwise_logistic, eltwise_clamp, eltwise_exp)
1222         && memory_desc_wrapper(src_pd()).is_dense(true)
1223         && IMPLICATION(!memory_desc_wrapper(src_pd()).is_dense(false),
1224                 math::eltwise_fwd_preserves_zero(desc()->alg_kind, true))
1225         && attr()->has_default_values();
1226
1227     return ok ? status::success : status::unimplemented;
1228 }
1229
1230 template <cpu_isa_t isa, data_type_t d_type>
1231 jit_uni_eltwise_fwd_t<isa, d_type>::jit_uni_eltwise_fwd_t(const pd_t *apd,
1232         const input_vector &inputs, const output_vector &outputs)
1233     : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
1234     const auto &desc = *pd()->desc();
1235     switch (desc.alg_kind) {
1236     case alg_kind::eltwise_relu:
1237         kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1238     default:
1239         kernel_ = new jit_uni_kernel_fwd_f32<isa>(desc);
1240     }
1241 }
1242
1243 template <cpu_isa_t isa, data_type_t d_type>
1244 jit_uni_eltwise_fwd_t<isa, d_type>::~jit_uni_eltwise_fwd_t()
1245 { delete kernel_; }
1246
1247 template <cpu_isa_t isa, data_type_t d_type>
1248 void jit_uni_eltwise_fwd_t<isa, d_type>::execute_forward() const {
1249     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
1250     auto dst = reinterpret_cast<data_t *>(this->memory(0));
1251
1252     const memory_desc_wrapper data_d(pd()->src_pd());
1253
1254     const size_t nelems = data_d.nelems(true);
1255
1256     src += data_d.blocking_desc().offset_padding;
1257     dst += data_d.blocking_desc().offset_padding;
1258
1259     const int cache_line = 16;
1260     parallel(0, utils::div_up(nelems, cache_line), [&](const int ithr, const int nthr) {
1261         size_t start{0}, end{0};
1262         balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1263         start = nstl::min(nelems, start * cache_line);
1264         end = nstl::min(nelems, end * cache_line);
1265
1266         auto arg = jit_args();
1267         arg.from = (const void*)&src[start];
1268         arg.for_comparison = (const void*)&src[start];
1269         arg.to = (const void*)&dst[start];
1270         arg.work_amount = end - start;
1271         if (arg.work_amount)
1272             (*kernel_)(&arg);
1273     });
1274 }
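// Work partitioning note (illustrative): nelems is split into cache_line-sized
// chunks of 16 elements, balance211() assigns each thread a contiguous range of
// chunks, and the range is converted back to element offsets clipped to nelems,
// so the last thread may receive a partial chunk that the kernel's scalar
// remainder loop handles.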
1275
1276 template <cpu_isa_t isa, data_type_t d_type>
1277 status_t jit_uni_eltwise_bwd_t<isa, d_type>::pd_t::init() {
1278     assert(engine()->kind() == engine_kind::cpu);
1279
1280     bool ok = true
1281         && desc()->prop_kind == prop_kind::backward_data
1282         && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu)
1283         && src_pd()->desc()->data_type == d_type
1284         && !has_zero_dim_memory()
1285         && mayiuse(isa)
1286         && memory_desc_wrapper(src_pd()).is_dense()
1287         && memory_desc_wrapper(diff_dst_pd()) == memory_desc_wrapper(src_pd())
1288         && attr()->has_default_values();
1289
1290     return ok ? status::success : status::unimplemented;
1291 }
1292
1293 template <cpu_isa_t isa, data_type_t d_type>
1294 jit_uni_eltwise_bwd_t<isa, d_type>::jit_uni_eltwise_bwd_t(const pd_t *apd,
1295         const input_vector &inputs, const output_vector &outputs)
1296     : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
1297     const auto &desc = *pd()->desc();
1298     switch (desc.alg_kind) {
1299     case alg_kind::eltwise_relu:
1300         kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1301     default: assert(!"unknown eltwise alg_kind");
1302     }
1303 }
1304
1305 template <cpu_isa_t isa, data_type_t d_type>
1306 jit_uni_eltwise_bwd_t<isa, d_type>::~jit_uni_eltwise_bwd_t()
1307 { delete kernel_; }
1308
1309 template <cpu_isa_t isa, data_type_t d_type>
1310 void jit_uni_eltwise_bwd_t<isa, d_type>::execute_backward() const {
1311     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
1312     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
1313     auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
1314
1315     const memory_desc_wrapper data_d(pd()->src_pd());
1316     const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
1317
1318     const size_t nelems = data_d.nelems();
1319
1320     src += data_d.blocking_desc().offset_padding;
1321     diff_dst += diff_data_d.blocking_desc().offset_padding;
1322     diff_src += diff_data_d.blocking_desc().offset_padding;
1323
1324     const int cache_line = 16;
1325
1326     parallel(0, utils::div_up(nelems, cache_line), [&](const int ithr, const int nthr) {
1327         size_t start{0}, end{0};
1328
1329         balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1330         start = nstl::min(nelems, start * cache_line);
1331         end = nstl::min(nelems, end * cache_line);
1332
1333         auto arg = jit_args();
1334         arg.from = (const void*)&diff_dst[start];
1335         arg.to = (const void*)&diff_src[start];
1336         arg.for_comparison = (const void*)&src[start];
1337         arg.work_amount = end - start;
1338         if (arg.work_amount) {
1339             (*kernel_)(&arg);
1340         }
1341     });
1342 }
1343
1344 template struct jit_uni_eltwise_fwd_t<sse42, data_type::f32>;
1345 template struct jit_uni_eltwise_bwd_t<sse42, data_type::f32>;
1346 template struct jit_uni_eltwise_fwd_t<avx2, data_type::f32>;
1347 template struct jit_uni_eltwise_bwd_t<avx2, data_type::f32>;
1348 template struct jit_uni_eltwise_fwd_t<avx512_common, data_type::f32>;
1349 template struct jit_uni_eltwise_fwd_t<avx512_common, data_type::bf16>;
1350 template struct jit_uni_eltwise_bwd_t<avx512_common, data_type::f32>;
1351 template struct jit_uni_eltwise_bwd_t<avx512_common, data_type::bf16>;
1352
1353 }
1354 }
1355 }