21fe6d0f8e23bc6be32537e556980b938c5983db
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_eltwise.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <mkldnn_types.h>
18 #include "mkldnn_types.h"
19 #include "mkldnn_thread.hpp"
20 #include "nstl.hpp"
21 #include "utils.hpp"
22 #include "jit_generator.hpp"
23
24 #include "jit_uni_eltwise.hpp"
25
26 #define GET_OFF(field) offsetof(jit_args, field)
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace Xbyak;
33
34 template <cpu_isa_t isa>
35 bool jit_uni_eltwise_injector_f32<isa>::is_free_vec(size_t idx) {
36     for (size_t i = 0; i < preserved_vecs_count; i++) {
37         if (preserved_vec_idxs[i] == idx) {
38             return false;
39         }
40     }
41     return true;
42 }
43
44 template <cpu_isa_t isa>
45 void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
46         size_t end_idx) {
47     preserved_vecs_count = 0;
48     vecs_to_preserve = (size_t)jit_uni_eltwise_injector_f32<isa>::
49             aux_vecs_count(elt_alg);
50     start_idx_tail = start_idx;
51
52     // For sse42 mask register has to be Xmm(0)
53     if (isa == sse42 && vecs_to_preserve > 0) {
54         size_t idx = 0;
55         assert(idx < start_idx);
56         preserved_vec_idxs[preserved_vecs_count++] = idx;
57     }
58
59     for (size_t i = 0; i < vecs_count; i++) {
60         if (preserved_vecs_count >= vecs_to_preserve)
61             break;
62
63         size_t idx = i;
64         if (is_free_vec(idx) && (idx < start_idx || idx >= end_idx)) {
65             preserved_vec_idxs[preserved_vecs_count++] = idx;
66         }
67     }
68
69     size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
70     for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
71         size_t idx = start_idx_tail;
72         if (is_free_vec(idx)) {
73             preserved_vec_idxs[preserved_vecs_count++] = idx;
74             start_idx_tail++;
75         }
76     }
77
78     assert(preserved_vecs_count == vecs_to_preserve);
79
80     if (save_vecs_state) {
81         h->push(p_table);
82
83         h->sub(h->rsp, preserved_vecs_count * vlen);
84         for (size_t i = 0; i < preserved_vecs_count; ++i)
85             h->uni_vmovups(h->ptr[h->rsp + i * vlen],
86                     Vmm(preserved_vec_idxs[i]));
87     }
88
89     assign_regs();
90 }
91
92 template <cpu_isa_t isa>
93 void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(
94         size_t start_idx) {
95     size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
96     int idx_off = (vecs_to_preserve - tail_vecs_to_preserve);
97
98     if (tail_vecs_to_preserve > 0) {
99         if (save_vecs_state) {
100             h->add(h->rsp, idx_off * vlen);
101             for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
102                 h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
103                         h->ptr[h->rsp + i * vlen]);
104         }
105
106         for (size_t i = 0; i < tail_vecs_to_preserve; ++i) {
107             preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
108         }
109
110         if (save_vecs_state) {
111             for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
112                 h->uni_vmovups(h->ptr[h->rsp + i * vlen],
113                         Vmm(preserved_vec_idxs[idx_off + i]));
114             h->sub(h->rsp, idx_off * vlen);
115         }
116
117         assign_regs();
118     }
119 }
120
121 template <cpu_isa_t isa>
122 void jit_uni_eltwise_injector_f32<isa>::injector_postamble() {
123     if (save_vecs_state) {
124         for (size_t i = 0; i < preserved_vecs_count; ++i)
125             h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
126                     h->ptr[h->rsp + i * vlen]);
127         h->add(h->rsp, preserved_vecs_count * vlen);
128
129         h->pop(p_table);
130     }
131 }
132
133 template <cpu_isa_t isa>
134 void jit_uni_eltwise_injector_f32<isa>::assign_regs() {
135     vmm_mask = Vmm(preserved_vec_idxs[0]);
136     vmm_aux0 = Vmm(preserved_vec_idxs[0]);
137     vmm_aux1 = Vmm(preserved_vec_idxs[1]);
138     vmm_aux2 = Vmm(preserved_vec_idxs[2]);
139     vmm_aux3 = Vmm(preserved_vec_idxs[3]);
140
141     p_table = Xbyak::Reg64(table_reg_idx);
142     k_mask = Xbyak::Opmask(opmask_idx);
143 }
144
145 template <cpu_isa_t isa>
146 void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
147     const unsigned char _op_floor = 1;
148
149     h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 10 * vlen]);
150     h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 11 * vlen]);
151     h->uni_vmovups(vmm_aux0, vmm_src);
152     //calculate exp(x)
153     // fx = x * log2ef + 0.5
154     h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + 2 * vlen]);
155     h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_table + 1 * vlen]);
156
157     // tmp = floorf(fx)
158     if (isa == avx512_common) {
159         h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src);
160         h->vcvtdq2ps(vmm_aux1, vmm_aux1);
161
162         unsigned char _cmp_gt_os = 14;
163         Xbyak::Opmask k_mask_tmp = Xbyak::Opmask(2);
164         h->vcmpps(k_mask_tmp, vmm_aux1, vmm_src, _cmp_gt_os);
165         h->vmovups(vmm_aux3 | k_mask_tmp | h->T_z,
166                 h->zword[p_table + 0 * vlen]);
167
168         h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3);
169     } else {
170         h->uni_vroundps(vmm_aux1, vmm_src, _op_floor);
171     }
172
173     //keep fx for further computations
174     h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx
175
176     //x = x - fx * ln2
177     h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, h->ptr[p_table + 3 * vlen]);
178
179     // compute 2^n
180     h->uni_vcvtps2dq(vmm_aux1, vmm_src);
181     h->uni_vpaddd(vmm_aux1, vmm_aux1, h->ptr[p_table + 4 * vlen]);
182     h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx
183
184     // y = p5
185     h->uni_vmovups(vmm_src, h->ptr[p_table + 9 * vlen]);
186     // y = y * x + p4
187     h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 8 * vlen]);
188     // y = y * x + p3
189     h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 7 * vlen]);
190     // y = y * x + p2
191     h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 6 * vlen]);
192     // y = y * x + p1
193     h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 0 * vlen]);
194     // y = y * x + p0
195     h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 5 * vlen]);  //exp(q)
196     // y = y * 2^n
197     h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
198 }
199
200 template <cpu_isa_t isa>
201 void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(
202         const Vmm &vmm_src) {
203     unsigned char _cmp_gt_os = isa == avx512_common ? 14 : 6;
204
205     int alpha_off = 0 * vlen;
206     int zero_off = 1 * vlen;
207
208     h->uni_vmovups(vmm_aux1, vmm_src);
209     if (isa == sse42) {
210         h->movups(vmm_mask, vmm_src);
211         h->mulps(vmm_src, h->ptr[p_table + alpha_off]);
212         h->cmpps(vmm_mask, h->ptr[p_table + zero_off], _cmp_gt_os);
213         h->blendvps(vmm_src, vmm_aux1);
214     } else if (isa == avx2) {
215         h->vmulps(vmm_src, vmm_src, h->ptr[p_table + alpha_off]);
216         h->vcmpgtps(vmm_mask, vmm_aux1, h->ptr[p_table + zero_off]);
217         h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
218     } else if (isa == avx512_common) {
219         h->vmulps(vmm_src, vmm_src, h->ptr[p_table + alpha_off]);
220         h->vcmpps(k_mask, vmm_aux1, h->ptr[p_table + zero_off], _cmp_gt_os);
221         h->vblendmps(vmm_src | k_mask, vmm_src,
222                      vmm_aux1);
223     }
224 }
225
226 template <cpu_isa_t isa>
227 void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector(
228         const Vmm &vmm_src) {
229     int zero_off = 1 * vlen;
230     h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + zero_off]);
231 }
232
233 template <cpu_isa_t isa>
234 void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) {
235     const unsigned char _cmp_gt_os = 6;
236     const unsigned char _cmp_let_os = 2;
237     int alpha_off = 12 * vlen;
238     int zero_off = 13 * vlen;
239
240     // compute exponent
241     h->uni_vmovups(vmm_aux2, vmm_src);
242     exp_compute_vector(vmm_src);
243
244     // alpha * (exp(x) - 1)
245     h->uni_vsubps(vmm_src, vmm_src, h->ptr[p_table + 0 * 32]);
246     h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + alpha_off]);
247
248     // combine with mask
249     if (isa == sse42) {
250         h->pxor(vmm_mask, vmm_mask);
251         h->cmpps(vmm_mask,  vmm_aux2, _cmp_let_os);
252         h->blendvps(vmm_src, vmm_aux2);
253     } else if (isa == avx2) {
254         h->uni_vcmpgtps(vmm_mask, vmm_aux2, h->ptr[p_table + zero_off]);
255         h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask);
256     } else if (isa == avx512_common) {
257         h->vcmpps(k_mask, vmm_aux2, h->ptr[p_table + zero_off], _cmp_gt_os);
258         h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2);
259     }
260 }
261
262 template <cpu_isa_t isa>
263 void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(
264         const Vmm &vmm_src) {
265     // compute exp(2x)
266     h->uni_vaddps(vmm_src, vmm_src, vmm_src);
267     exp_compute_vector(vmm_src);
268     // dup exp(2x)
269     h->uni_vmovups(vmm_aux0, vmm_src);
270     // (exp(2x) - 1)
271     h->uni_vsubps(vmm_src, vmm_src, h->ptr[p_table + 0 * vlen]);
272     // (exp(2x) + 1)
273     h->uni_vaddps(vmm_aux0, vmm_aux0, h->ptr[p_table + 0 * vlen]);
274     // y = (exp(2x) - 1) / (exp(2x) + 1)
275     h->uni_vdivps(vmm_src, vmm_src, vmm_aux0);
276 }
277
// Emits x = x * x for every lane of vmm_src. Uses neither the constant table
// nor any auxiliary register.
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::square_compute_vector(
        const Vmm &vmm_src) {
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);
}
283
284 template <cpu_isa_t isa>
285 void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) {
286     // compute abs(x) = _mm_and_ps(x, 01111..111));
287     h->uni_vandps(vmm_src, vmm_src, h->ptr[p_table + 0*vlen]);
288 }
289
290 template <cpu_isa_t isa>
291 void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(
292         const Vmm &vmm_src) {
293     if (isa == avx512_common) {
294         unsigned char _cmp_gt_os = 6;
295
296         h->vcmpps(k_mask, vmm_src, h->ptr[p_table + 0 * vlen], _cmp_gt_os);
297         h->uni_vsqrtps(vmm_aux1, vmm_src);
298         h->uni_vmovups(vmm_src, h->ptr[p_table + 0*vlen]);
299         h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
300     } else {
301         h->uni_vmovups(vmm_mask, vmm_src);
302         h->uni_vcmpgtps(vmm_mask, vmm_mask, h->ptr[p_table + 0*vlen]);
303         h->uni_vsqrtps(vmm_aux1, vmm_src);
304         h->uni_vmovups(vmm_src, h->ptr[p_table + 0*vlen]);
305         h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
306     }
307 }
308
309 template <cpu_isa_t isa>
310 void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector(
311         const Vmm &vmm_src) {
312     // compute x = alpha * x + beta;
313     h->uni_vmovups(vmm_aux0, h->ptr[p_table + 0*vlen]);
314     h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 1*vlen]);
315 }
316
317 template <cpu_isa_t isa>
318 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector(
319         const Vmm &vmm_src) {
320     // compute bounded relu */
321     h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 1*vlen]);
322     h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 0*vlen]);
323 }
324
325 template <cpu_isa_t isa>
326 void jit_uni_eltwise_injector_f32<isa>::clamp_compute_vector(
327         const Vmm &vmm_src) {
328     h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 1*vlen]);
329     h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 0*vlen]);
330 }
331
332 template <cpu_isa_t isa>
333 void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
334         const Vmm &vmm_src) {
335     const unsigned char _op_floor = 1;
336     // duplicate src
337     h->uni_vmovups(vmm_aux2, vmm_src);
338
339     h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 24 * vlen]);
340     h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 25 * vlen]);
341     h->uni_vmovups(vmm_aux1, vmm_src);
342     // calculate exp(x)
343     // fx = x * log2ef + 0.5
344     h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + 2 * vlen]);
345     h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_table + 1 * vlen]);
346
347     // tmp = floorf(fx)
348     if (isa == avx512_common) {
349         h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src);
350         h->vcvtdq2ps(vmm_aux0, vmm_aux0);
351
352         unsigned char _cmp_gt_os = 14;
353         h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_gt_os);
354         h->vmovups(vmm_aux3 | k_mask | h->T_z, h->ptr[p_table + 0 * vlen]);
355
356         h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3);
357     } else {
358         h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);
359     }
360
361     // keep fx for further computations
362     h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
363     // calculation fx * ln2
364     h->uni_vmulps(vmm_aux0, vmm_aux0, h->ptr[p_table + 3 * vlen]);
365     // x = x - fx * ln2
366     h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
367     // y = p5
368     h->uni_vmovups(vmm_aux3, h->ptr[p_table + 22 * vlen]);
369     // y = y * x + p4
370     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 21 * vlen]);
371     // y = y * x + p3
372     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 20 * vlen]);
373     // y = y * x + p2
374     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 19 * vlen]);
375     // y = y * x + p1
376     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 0 * vlen]);
377     // y = y * x + p0
378     h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 17 * vlen]);
379
380     // compute 2^(-n)
381     if (isa == avx512_common) {
382         h->vmulps(vmm_aux1, vmm_src, h->ptr[p_table + 23 * vlen]);
383         h->vcvtps2dq(vmm_aux1, vmm_aux1);
384     } else {
385         h->uni_vcvtps2dq(vmm_aux1, vmm_src);
386         h->uni_vpsignd(vmm_aux1, vmm_aux1, h->ptr[p_table + 23 * vlen]);
387     }
388
389     h->uni_vpaddd(vmm_aux1, vmm_aux1, h->ptr[p_table + 4 * vlen]);
390     h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
391     // calculate ln(1 + y)
392     h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
393     // x = y; y is free; keep x for further computations
394     h->uni_vmovups(vmm_src, vmm_aux3);
395     // frexp()
396     h->uni_vpsrld(vmm_src, vmm_src, 23);
397     h->uni_vcvtdq2ps(vmm_src, vmm_src);
398     // got n. where n is x = 2^n * y. y = 0.5 .. 1
399     h->uni_vsubps(vmm_src, vmm_src, h->ptr[p_table + 5 * vlen]);
400
401     h->uni_vandps(vmm_aux3, vmm_aux3, h->ptr[p_table + 6 * vlen]);
402     // got y. (mantisa)  0.5 < y < 1
403     h->uni_vorps(vmm_aux3, vmm_aux3, h->ptr[p_table + 7 * vlen]);
404     // y  = y - 1
405     h->uni_vsubps(vmm_aux3, vmm_aux3, h->ptr[p_table + 0 * vlen]);
406     // y = p8
407     h->uni_vmovups(vmm_aux1, h->ptr[p_table + 16 * vlen]);
408     // y = y * x + p7
409     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 15 * vlen]);
410     // y = y * x + p6
411     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 14 * vlen]);
412     // y = y * x + p5
413     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 13 * vlen]);
414     // y = y * x + p4
415     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 12 * vlen]);
416     // y = y * x + p3
417     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 11 * vlen]);
418     // y = y * x + p2
419     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 10 * vlen]);
420     // y = y * x + p1
421     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 9 * vlen]);
422     // y = y * x + p0 ; p0 = 0
423     h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 8 * vlen]);
424     //calculate ln(2) * n
425     h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + 3 * vlen]);
426     h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
427     h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);
428
429     // get vmm_mask = src > max logf
430     h->uni_vmovups(vmm_mask, vmm_aux2);
431     if (isa == avx512_common) {
432         unsigned char _cmp_gt_os = 6;
433         // y = (x < max log f) ? soft_relu(x) : x
434         h->vcmpps(k_mask, vmm_mask, h->ptr[p_table + 24 * vlen], _cmp_gt_os);
435         h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
436     } else {
437         // y = (x < max log f) ? soft_relu(x) : x
438         h->uni_vcmpgtps(vmm_mask, vmm_mask, h->ptr[p_table + 24 * vlen]);
439         h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
440     }
441
442     h->uni_vmovups(vmm_src, vmm_aux1);
443 }
444
445 template <cpu_isa_t isa>
446 void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector(
447         const Vmm &vmm_src) {
448     exp_compute_vector(vmm_src);
449     // dup exp(x)
450     h->uni_vmovups(vmm_aux0, vmm_src);
451     // (exp(x) + 1)
452     h->uni_vaddps(vmm_aux0, vmm_aux0, h->ptr[p_table + 0 * vlen]);
453     // y = exp(x) / (exp(x) + 1)
454     h->uni_vdivps(vmm_src, vmm_src, vmm_aux0);
455 }
456
457 template <cpu_isa_t isa>
458 void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
459     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
460         h->dd(float2int(alpha));
461     }
462     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
463         h->dd(0);
464     }
465 }
466
467 template <cpu_isa_t isa>
468 void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() {
469     const unsigned int cvals[] = {
470             0x3f800000, // [0] 1.0f
471             0x3f000000, // [1] 0.5f
472             0x3fb8aa3b, // [2] log2ef = 1.44269502f
473             0x3f317218, // [3] ln2f =   0.69314718f
474             0x0000007f, // [4] 0x7f
475             // exp(x) polynom
476             0x3f800001, // [5] p0 = 1.0000001f
477             0x3efffe85, // [6] p2 = 0.4999887f
478             0x3e2aaa3e, // [7] p3 = 0.16666505f
479             0x3d2bb1b1, // [8] p4 = 0.041917507f
480             0x3c091ec1, // [9] p5 = 0.008369149f
481             0x42b0c0a5, //[10] max logf = 88.3762589f
482             0xc1766666  //[11] min logf = -14.5f
483     };
484
485     for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
486         for (size_t d = 0; d < vlen / sizeof(float); ++d) {
487             h->dd(cvals[i]);
488         }
489     }
490     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
491         h->dd(float2int(alpha));
492     }
493     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
494         h->dd(0);
495     }
496 }
497
498 template <cpu_isa_t isa>
499 void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() {
500     const unsigned int cvals[] = {
501             0x3f800000, // [0] 1.0f
502             0x3f000000, // [1] 0.5f
503             0x3fb8aa3b, // [2] log2ef = 1.44269502f
504             0x3f317218, // [3] ln2f =   0.69314718f
505             0x0000007f, // [4] 0x7f
506             0x42fc0000, // [5] 126
507             0x807fffff, // [6] and with (to get 0.5 * mantissa)
508             0x3f000000, // [7] or with (to get 0.5 * mantissa)
509             // ln(1 + x) polynomial
510             0xb2b4637d, // [8]  p0 = 0.0000000244f
511             0x3f7fff8e, // [9]  p1 = 0.9999976971f
512             0xbf001759, //[10]  p2 = -0.5002478215f
513             0x3ea70608, //[11]  p3 = 0.3272714505f
514             0xbea3d7bf, //[12]  p4 = -0.3153830071f
515             0xbe361d04, //[13]  p5 = -0.1701777461f
516             0xbfa8f1e6, //[14]  p6 = -1.3254635147f
517             0xbfe1e812, //[15]  p7 = -1.7971917960f
518             0xbfc4d30e, //[16]  p8 = -1.5652673123f
519             // exp(x) polynomial
520             0x3f800001, //[17]  p0 = 1.0000001f
521             0x3f800000, //[18]  p1 = 1.0f
522             0x3efffe85, //[19]  p2 = 0.4999887f
523             0x3e2aaa3e, //[20]  p3 = 0.16666505f
524             0x3d2bb1b1, //[21]  p4 = 0.041917507f
525             0x3c091ec1, //[22]  p5 = 0.008369149f
526             0xbf800000, //[23] is required for sign changing
527             0x42b0c0a5, //[24] max logf = 88.3762589f
528             0xc1766666  //[25] min logf = -14.5f
529     };
530
531     for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
532         for (size_t d = 0; d < vlen / sizeof(float); ++d) {
533             h->dd(cvals[i]);
534         }
535     }
536 }
537
538 template <cpu_isa_t isa>
539 void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() {
540     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
541         h->dd(0x7fffffff);
542     }
543 }
544
545 template <cpu_isa_t isa>
546 void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() {
547     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
548         h->dd(0);
549     }
550 }
551
552 template <cpu_isa_t isa>
553 void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() {
554     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
555         h->dd(float2int(alpha));
556     }
557     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
558         h->dd(float2int(beta));
559     }
560 }
561
562 template <cpu_isa_t isa>
563 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() {
564     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
565         h->dd(float2int(alpha));
566     }
567     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
568         h->dd(0);
569     }
570 }
571
572 template <cpu_isa_t isa>
573 void jit_uni_eltwise_injector_f32<isa>::clamp_prepare_table() {
574     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
575         h->dd(float2int(alpha));
576     }
577     for (size_t d = 0; d < vlen / sizeof(float); ++d) {
578         h->dd(float2int(beta));
579     }
580 }
581
582 template <cpu_isa_t isa>
583 int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t elt_alg) {
584     switch (elt_alg) {
585         case alg_kind::eltwise_relu: return (alpha == 0.f) ? 0 : 2;
586         case alg_kind::eltwise_elu: return 4;
587         case alg_kind::eltwise_tanh: return 4;
588         case alg_kind::eltwise_square: return 0;
589         case alg_kind::eltwise_abs: return 0;
590         case alg_kind::eltwise_sqrt: return 2;
591         case alg_kind::eltwise_linear: return 1;
592         case alg_kind::eltwise_bounded_relu: return 0;
593         case alg_kind::eltwise_soft_relu: return 4;
594         case alg_kind::eltwise_logistic: return 4;
595         case alg_kind::eltwise_clamp: return 0;
596         default: assert(!"unsupported eltwise algorithm");
597     }
598
599     return 0;
600 }
601
602 template <cpu_isa_t isa>
603 void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
604         size_t end_idx) {
605     h->mov(p_table, l_table);
606
607     for (size_t idx = start_idx; idx < end_idx; idx++) {
608         switch (elt_alg) {
609             case alg_kind::eltwise_relu:
610                 if (alpha == 0.f)
611                     relu_zero_ns_compute_vector(Vmm(idx));
612                 else
613                     relu_compute_vector(Vmm(idx));
614                 break;
615             case alg_kind::eltwise_elu:
616                 elu_compute_vector(Vmm(idx)); break;
617             case alg_kind::eltwise_tanh:
618                 tanh_compute_vector(Vmm(idx)); break;
619             case alg_kind::eltwise_square:
620                 square_compute_vector(Vmm(idx)); break;
621             case alg_kind::eltwise_abs:
622                 abs_compute_vector(Vmm(idx)); break;
623             case alg_kind::eltwise_sqrt:
624                 sqrt_compute_vector(Vmm(idx)); break;
625             case alg_kind::eltwise_linear:
626                 linear_compute_vector(Vmm(idx)); break;
627             case alg_kind::eltwise_bounded_relu:
628                 bounded_relu_compute_vector(Vmm(idx)); break;
629             case alg_kind::eltwise_soft_relu:
630                 soft_relu_compute_vector(Vmm(idx)); break;
631             case alg_kind::eltwise_logistic:
632                 logistic_compute_vector(Vmm(idx)); break;
633             case alg_kind::eltwise_clamp:
634                 clamp_compute_vector(Vmm(idx)); break;
635             default: assert(!"unsupported eltwise algorithm");
636         }
637     }
638 }
639
640 template <cpu_isa_t isa>
641 void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
642         size_t end_idx) {
643     assert(start_idx < vecs_count);
644     assert(end_idx <= vecs_count);
645     assert(start_idx < end_idx);
646
647     injector_preamble(start_idx, end_idx);
648     compute_body(start_idx_tail, end_idx);
649     injector_preamble_tail(start_idx);
650     compute_body(start_idx, start_idx_tail);
651     injector_postamble();
652 }
653
654 template <cpu_isa_t isa>
655 void jit_uni_eltwise_injector_f32<isa>::compute_vector(size_t idx) {
656     compute_vector_range(idx, idx + 1);
657 }
658
659 template <cpu_isa_t isa>
660 void jit_uni_eltwise_injector_f32<isa>::prepare_table() {
661     h->align(64);
662     h->L(l_table);
663
664     switch (elt_alg) {
665         case alg_kind::eltwise_relu:
666             relu_prepare_table(); break;
667         case alg_kind::eltwise_elu:
668         case alg_kind::eltwise_tanh:
669         case alg_kind::eltwise_logistic:
670             elu_prepare_table(); break;
671         case alg_kind::eltwise_soft_relu:
672             soft_relu_prepare_table(); break;
673         case alg_kind::eltwise_abs:
674             abs_prepare_table(); break;
675         case alg_kind::eltwise_sqrt:
676             sqrt_prepare_table(); break;
677         case alg_kind::eltwise_linear:
678             linear_prepare_table(); break;
679         case alg_kind::eltwise_bounded_relu:
680             bounded_relu_prepare_table(); break;
681         case alg_kind::eltwise_square:
682             break;
683         case alg_kind::eltwise_clamp:
684             clamp_prepare_table(); break;
685         default: assert(!"unsupported eltwise algorithm");
686     }
687 }
688
// Explicit instantiations of the injector for the three ISAs handled in this
// file. The member functions above are defined in this translation unit, so
// these instantiations are what makes them available to other units.
template struct jit_uni_eltwise_injector_f32<avx512_common>;
template struct jit_uni_eltwise_injector_f32<avx2>;
template struct jit_uni_eltwise_injector_f32<sse42>;
692
693
// Argument block handed to the generated kernels. The JIT code reads the
// fields by byte offset through GET_OFF(field) (an offsetof on this struct),
// so the field set and order are part of the kernel ABI.
struct jit_args {
    const float *from;           // input data pointer
    const float *for_comparison; // only loaded by the relu kernel when is_bwd()
    const float *to;             // output data pointer
    size_t work_amount;          // number of elements left to process
};
700
701 struct jit_uni_eltwise_kernel_f32 : public c_compatible {
702     const eltwise_desc_t &desc_;
703
704     void (*ker_)(const jit_args *);
705     void operator()(const jit_args *args) { assert(ker_); ker_(args); }
706
707     jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc)
708         : desc_(desc), ker_(nullptr) {}
709     virtual ~jit_uni_eltwise_kernel_f32() {}
710
711 protected:
712     bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; }
713 };
714
715 /* jit kernels */
716 namespace {
717
718 template <cpu_isa_t isa>
719 struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32,
720     public jit_generator
721 {
722     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32)
723
724     void compute_step(bool vectorize, const int uf, const int shift) {
725         for (int i = 0; i < uf; i++) {
726             if (vectorize) {
727                 uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]);
728                 if (is_bwd())
729                     uni_vmovups(Vmm(uf + i + 1),
730                                 ptr[reg_for_comparison + i * shift]);
731             } else {
732                 movss(Xmm(i + 1), ptr[reg_from + i * shift]);
733                 if (is_bwd())
734                     movss(Xmm(uf + i + 1),
735                           ptr[reg_for_comparison + i * shift]);
736             }
737         }
738
739         if (isa == sse42) {
740             for (int i = 0; i < uf; i++) {
741                 movups(Vmm(2 * uf + i + 1), Vmm(i + 1));
742                 mulps(Vmm(2 * uf + i + 1), vmm_ns);
743
744                 Vmm mask = Vmm(0);
745                 if (is_bwd()) {
746                     movups(mask, Vmm(uf + i + 1));
747                     cmpps(mask, vmm_zero, _cmp_nle_us);
748                 } else {
749                     movups(mask, Vmm(i + 1));
750                     cmpps(mask, vmm_zero, _cmp_nle_us);
751                 }
752                 blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1));
753             }
754         } else {
755             for (int i = 0; i < uf; i++) {
756                 vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns);
757                 if (isa == avx2) {
758                     if (is_bwd())
759                         vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero);
760                     else
761                         vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero);
762
763                     vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1),
764                               Vmm(i + 1), vmm_mask);
765
766                 } else {
767                     if (is_bwd())
768                         vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us);
769                     else
770                         vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us);
771                     vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1),
772                               Vmm(i + 1));
773                 }
774             }
775         }
776
777         for (int i = 0; i < uf; i++) {
778             if (vectorize) {
779                 uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1));
780             } else {
781                 movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1));
782             }
783         }
784     }
785
786     jit_uni_relu_kernel_f32(const eltwise_desc_t &desc)
787         : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
788         assert(desc.alg_kind == alg_kind::eltwise_relu);
789         assert(isa == sse42 || isa == avx2 || isa == avx512_common);
790
791         Reg64 param = abi_param1;
792
793         const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
794         const int loop_dec[] = {simd_w, 1};
795         const int uf[] = {1, 1};
796         const int shift[] = {cpu_isa_traits<isa>::vlen, sizeof(float)};
797         const bool loop_vectorize[] = {true, false};
798
799         this->preamble();
800
801         mov(reg_from, ptr[param + GET_OFF(from)]);
802         if (is_bwd())
803             mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]);
804         mov(reg_to, ptr[param + GET_OFF(to)]);
805         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
806
807         mov(imm_addr64, float2int(desc.alpha));
808         movq(xmm_ns, imm_addr64);
809         uni_vbroadcastss(vmm_ns, xmm_ns);
810
811         uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
812
813         Label loop_label[3];
814
815         for (int id = 0; id < 2; id++) {
816             L(loop_label[id]);
817             cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
818             jle(loop_label[id + 1], T_NEAR);
819
820             compute_step(loop_vectorize[id], uf[id], shift[id]);
821
822             add(reg_from, uf[id] * shift[id]);
823             add(reg_to, uf[id] * shift[id]);
824             if (is_bwd())
825                 add(reg_for_comparison, uf[id] * shift[id]);
826
827             sub(reg_work_amount, uf[id] * loop_dec[id]);
828             jmp(loop_label[id]);
829         }
830
831         L(loop_label[2]);
832         this->postamble();
833
834         ker_ = (decltype(ker_))this->getCode();
835     }
836
837 private:
838     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
839                                              isa == avx2, Ymm, Zmm>::type;
840
841     Reg64 reg_from = rax;
842     Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from;
843     Reg64 reg_to = r8;
844     Reg64 reg_work_amount = rsi;
845     Reg64 imm_addr64 = rbx;
846
847     Xmm xmm_ns = Xmm(14);
848
849     Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14);
850     Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15);
851
852     Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12);
853     Opmask k_mask = Opmask(1);
854 };
855
856 template <cpu_isa_t isa>
857 struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
858     public jit_generator {
859     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32)
860
861     jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc)
862         : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
863
864         eltwise_injector = new jit_uni_eltwise_injector_f32<isa>(this,
865                 desc.alg_kind, desc.alpha, desc.beta, false, 9, 1);
866
867         using namespace alg_kind;
868
869         assert(is_bwd() == false);
870         assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
871                     eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
872                     eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));
873
874         preamble();
875
876         Reg64 param = abi_param1;
877         mov(reg_from, ptr[param + GET_OFF(from)]);
878         mov(reg_to, ptr[param + GET_OFF(to)]);
879         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
880
881         cmp(reg_work_amount, simd_w);
882         jl("reminder_loop_start", T_NEAR);
883
884         L("vectorized_loop_start");
885
886         uni_vmovups(vmm_src, ptr[reg_from]);
887         eltwise_injector->compute_vector(vmm_src.getIdx());
888         uni_vmovups(ptr[reg_to], vmm_src);
889
890         add(reg_from, vlen);
891         add(reg_to, vlen);
892
893         sub(reg_work_amount, simd_w);
894         cmp(reg_work_amount, simd_w);
895         jge("vectorized_loop_start", T_NEAR);
896
897         L("vectorized_loop_end");
898
899         L("reminder_loop_start");
900
901         cmp(reg_work_amount, 0);
902         jle("reminder_loop_end", T_NEAR);
903
904         movss(xmm_src, ptr[reg_from]);
905         eltwise_injector->compute_vector(xmm_src.getIdx());
906         movss(ptr[reg_to], xmm_src);
907
908         add(reg_from, sizeof(float));
909         add(reg_to, sizeof(float));
910
911         dec(reg_work_amount);
912         jmp("reminder_loop_start", T_NEAR);
913
914         L("reminder_loop_end");
915
916         postamble();
917
918         eltwise_injector->prepare_table();
919
920         ker_ = (decltype(ker_))this->getCode();
921     }
922
923     ~jit_uni_kernel_fwd_f32() {
924         delete eltwise_injector;
925     }
926
927 private:
928     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
929                 isa == avx2, Ymm, Zmm>::type;
930
931     const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
932     const int vlen   = cpu_isa_traits<isa>::vlen;
933
934     Reg64 reg_from = rax;
935     Reg64 reg_to = r8;
936     Reg64 reg_work_amount = rsi;
937     Reg64 imm_addr64 = rbx;
938
939     Xmm xmm_src = Xmm(1);
940     Vmm vmm_src = Vmm(1);
941
942     jit_uni_eltwise_injector_f32<isa>* eltwise_injector;
943 };
944
945 } /* namespace */
946
947 template <cpu_isa_t isa>
948 status_t jit_uni_eltwise_fwd_t<isa>::pd_t::init() {
949     using namespace alg_kind;
950
951     assert(engine()->kind() == engine_kind::cpu);
952     bool ok = true && mayiuse(isa)
953         && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
954                 prop_kind::forward_inference)
955         && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
956         && !has_zero_dim_memory()
957         && utils::implication(isa > avx2, utils::one_of(desc()->alg_kind,
958                 eltwise_relu, eltwise_elu))
959         && utils::implication(isa == sse42 || isa == avx2, utils::one_of(
960                     desc()->alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
961                     eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
962                     eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic))
963         && memory_desc_wrapper(src_pd()).is_dense()
964         && attr()->has_default_values();
965
966     return ok ? status::success : status::unimplemented;
967 }
968
969 template <cpu_isa_t isa>
970 jit_uni_eltwise_fwd_t<isa>::jit_uni_eltwise_fwd_t(const pd_t *pd,
971         const input_vector &inputs, const output_vector &outputs)
972     : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr) {
973     const auto &desc = *conf_.desc();
974     switch (desc.alg_kind) {
975     case alg_kind::eltwise_relu:
976         kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
977     default:
978         kernel_ = new jit_uni_kernel_fwd_f32<isa>(desc);
979     }
980 }
981
982 template <cpu_isa_t isa>
983 jit_uni_eltwise_fwd_t<isa>::~jit_uni_eltwise_fwd_t()
984 { delete kernel_; }
985
986 template <cpu_isa_t isa>
987 void jit_uni_eltwise_fwd_t<isa>::execute_forward() {
988     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
989     auto dst = reinterpret_cast<data_t *>(this->memory(0));
990
991     const memory_desc_wrapper data_d(conf_.src_pd());
992
993     const size_t nelems = data_d.nelems();
994
995     src += data_d.blocking_desc().offset_padding;
996     dst += data_d.blocking_desc().offset_padding;
997
998     parallel(0, [&](const int ithr, const int nthr) {
999         size_t start{0}, end{0};
1000
1001         const int cache_line = 16;
1002
1003         balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1004         start = nstl::min(nelems, start * cache_line);
1005         end = nstl::min(nelems, end * cache_line);
1006
1007         auto arg = jit_args();
1008         arg.from = &src[start];
1009         arg.for_comparison = &src[start];
1010         arg.to = &dst[start];
1011         arg.work_amount = end - start;
1012         if (arg.work_amount)
1013             (*kernel_)(&arg);
1014     });
1015 }
1016
1017 template <cpu_isa_t isa>
1018 status_t jit_uni_eltwise_bwd_t<isa>::pd_t::init() {
1019     assert(engine()->kind() == engine_kind::cpu);
1020
1021     bool ok = true
1022         && desc()->prop_kind == prop_kind::backward_data
1023         && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu)
1024         && src_pd()->desc()->data_type == data_type::f32
1025         && !has_zero_dim_memory()
1026         && mayiuse(isa)
1027         && memory_desc_wrapper(src_pd()).is_dense()
1028         && memory_desc_wrapper(diff_dst_pd()) == memory_desc_wrapper(src_pd())
1029         && attr()->has_default_values();
1030
1031     return ok ? status::success : status::unimplemented;
1032 }
1033
1034 template <cpu_isa_t isa>
1035 jit_uni_eltwise_bwd_t<isa>::jit_uni_eltwise_bwd_t(const pd_t *pd,
1036         const input_vector &inputs, const output_vector &outputs)
1037     : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr) {
1038     const auto &desc = *conf_.desc();
1039     switch (desc.alg_kind) {
1040     case alg_kind::eltwise_relu:
1041         kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1042     default: assert(!"unknown eltwise alg_kind");
1043     }
1044 }
1045
1046 template <cpu_isa_t isa>
1047 jit_uni_eltwise_bwd_t<isa>::~jit_uni_eltwise_bwd_t()
1048 { delete kernel_; }
1049
1050 template <cpu_isa_t isa>
1051 void jit_uni_eltwise_bwd_t<isa>::execute_backward() {
1052     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
1053     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
1054     auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
1055
1056     const memory_desc_wrapper data_d(conf_.src_pd());
1057     const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
1058
1059     const size_t nelems = data_d.nelems();
1060
1061     src += data_d.blocking_desc().offset_padding;
1062     diff_dst += diff_data_d.blocking_desc().offset_padding;
1063     diff_src += diff_data_d.blocking_desc().offset_padding;
1064
1065     parallel(0, [&](const int ithr, const int nthr) {
1066         size_t start{0}, end{0};
1067
1068         const int cache_line = 16;
1069
1070         balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1071         start = nstl::min(nelems, start * cache_line);
1072         end = nstl::min(nelems, end * cache_line);
1073
1074         auto arg = jit_args();
1075         arg.from = &diff_dst[start];
1076         arg.to = &diff_src[start];
1077         arg.for_comparison = &src[start];
1078         arg.work_amount = end - start;
1079         if (arg.work_amount)
1080             (*kernel_)(&arg);
1081     });
1082 }
1083
1084 template struct jit_uni_eltwise_fwd_t<sse42>;
1085 template struct jit_uni_eltwise_bwd_t<sse42>;
1086 template struct jit_uni_eltwise_fwd_t<avx2>;
1087 template struct jit_uni_eltwise_bwd_t<avx2>;
1088 template struct jit_uni_eltwise_fwd_t<avx512_common>;
1089 template struct jit_uni_eltwise_bwd_t<avx512_common>;
1090
1091 }
1092 }
1093 }