updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_eltwise.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "math_utils.hpp"
22 #include "mkldnn_thread.hpp"
23
24 #include "bfloat16_utils.hpp"
25
26 #include "ref_eltwise.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace alg_kind;
33 using namespace math;
34
35 ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha,
36         float beta): alg_(alg), alpha_(alpha), beta_(beta) {
37     assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
38                 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
39                 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
40                 eltwise_clamp, eltwise_exp, eltwise_not));
41 }
42
43 ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
44         const post_ops_t::entry_t::eltwise_t &eltwise)
45     : ref_eltwise_scalar_fwd_t(eltwise.alg, eltwise.alpha, eltwise.beta) {}
46
47 float ref_eltwise_scalar_fwd_t::compute_scalar(float s) {
48     switch (alg_) {
49         case eltwise_relu: return relu_fwd(s, alpha_);
50         case eltwise_tanh: return tanh_fwd(s);
51         case eltwise_elu: return elu_fwd(s, alpha_);
52         case eltwise_square: return square_fwd(s);
53         case eltwise_abs: return abs_fwd(s);
54         case eltwise_sqrt: return sqrt_fwd(s);
55         case eltwise_linear: return linear_fwd(s, alpha_, beta_);
56         case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha_);
57         case eltwise_soft_relu: return soft_relu_fwd(s);
58         case eltwise_logistic: return logistic_fwd(s);
59         case eltwise_clamp: return clamp_fwd(s, alpha_, beta_);
60         case eltwise_exp: return exp_fwd(s);
61         case eltwise_not: return not_fwd(s);
62         default: assert(!"unknown eltwise alg_kind");
63     }
64
65     return 0.f;
66 }
67
68 template <impl::data_type_t data_type>
69 void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded() const {
70     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
71     auto dst = reinterpret_cast<data_t*>(this->memory(0));
72
73     const memory_desc_wrapper data_d(pd()->src_pd());
74     const blocking_desc_t &blk = data_d.blocking_desc();
75     const int block = blk.block_dims[1];
76
77     const int MB = pd()->MB();
78     const int C = pd()->C() / block;
79     const int C_PADDED = blk.padding_dims[1] / block;
80     const int tail = pd()->C() % block;
81     const int SP = pd()->D() * pd()->H() * pd()->W();
82     const auto alg_kind = pd()->desc()->alg_kind;
83     const float alpha = pd()->desc()->alpha;
84     const float beta = pd()->desc()->beta;
85
86     auto ker = [=] (data_t &d, data_t s) {
87         switch (alg_kind) {
88             case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
89             case eltwise_bounded_relu:
90                 d = bounded_relu_fwd(s, alpha); break;
91             case eltwise_soft_relu: d = soft_relu_fwd(s); break;
92             case eltwise_logistic: d = logistic_fwd(s); break;
93             case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
94             case eltwise_exp: d = exp_fwd(s); break;
95             case eltwise_not: d = not_fwd(s); break;
96             default: assert(!"unknown eltwise alg_kind");
97         }
98     };
99
100     // FIXME: integer overflow?
101
102     parallel_nd(MB, C_PADDED, SP,
103         [&](int n, int c, int sp) {
104         auto d_off = (n*C_PADDED*SP + c*SP + sp) * block;
105         if (c < C) {
106             for (int v = 0; v < block; v++)
107                 ker(dst[d_off + v], src[d_off + v]);
108         } else {
109             for (int v = 0; v < tail; v++)
110                 ker(dst[d_off + v], src[d_off + v]);
111         }
112     });
113 }
114
115 template <>
116 void ref_eltwise_fwd_t<data_type::bf16>::execute_forward_nCspBc_padded() const {
117     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
118     auto dst = reinterpret_cast<data_t*>(this->memory(0));
119
120     const memory_desc_wrapper data_d(pd()->src_pd());
121     const blocking_desc_t &blk = data_d.blocking_desc();
122     const int block = blk.block_dims[1];
123
124     const int MB = pd()->MB();
125     const int C = pd()->C() / block;
126     const int C_PADDED = blk.padding_dims[1] / block;
127     const int tail = pd()->C() % block;
128     const int SP = pd()->D() * pd()->H() * pd()->W();
129     const auto alg_kind = pd()->desc()->alg_kind;
130     const float alpha = pd()->desc()->alpha;
131     const float beta = pd()->desc()->beta;
132
133     auto ker = [=] (data_t &d, data_t s) {
134         float s_ = 0.0f, d_ = 0.0f;
135         bf16_cvt_utils::cvt_bfloat16_to_float(&s_, &s);
136         switch (alg_kind) {
137             case eltwise_linear: d_ = linear_fwd(s_, alpha, beta); break;
138             case eltwise_bounded_relu:
139                 d_ = bounded_relu_fwd(s_, alpha); break;
140             case eltwise_soft_relu: d_ = soft_relu_fwd(s_); break;
141             case eltwise_logistic: d_ = logistic_fwd(s_); break;
142             default: assert(!"unknown eltwise alg_kind");
143         }
144         bf16_cvt_utils::cvt_float_to_bfloat16(&d, &d_);
145     };
146
147     // FIXME: integer overflow?
148
149     parallel_nd(MB, C_PADDED, SP,
150         [&](int n, int c, int sp) {
151         auto d_off = (n*C_PADDED*SP + c*SP + sp) * block;
152         if (c < C) {
153             for (int v = 0; v < block; v++)
154                 ker(dst[d_off + v], src[d_off + v]);
155         } else {
156             for (int v = 0; v < tail; v++)
157                 ker(dst[d_off + v], src[d_off + v]);
158         }
159     });
160 }
161
162 template <impl::data_type_t data_type>
163 void ref_eltwise_fwd_t<data_type>::execute_forward_generic() const {
164     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
165     auto dst = reinterpret_cast<data_t*>(this->memory(0));
166
167     /* fast return */
168     if (pd()->has_zero_dim_memory()) return;
169
170     const memory_desc_wrapper data_d(pd()->src_pd());
171
172     const int MB = pd()->MB();
173     const int C = pd()->C();
174     const int D = pd()->D();
175     const int H = pd()->H();
176     const int W = pd()->W();
177     const auto alg_kind = pd()->desc()->alg_kind;
178     const float alpha = pd()->desc()->alpha;
179     const float beta = pd()->desc()->beta;
180     const bool is_3d = pd()->desc()->data_desc.ndims == 5;
181
182     parallel_nd(MB, C, D, H, W,
183         [&](int n, int c, int id, int h, int w) {
184         auto d_off = is_3d
185             ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w);
186         data_t s = src[d_off];
187         data_t &d = dst[d_off];
188         switch (alg_kind) {
189             case eltwise_relu: d = relu_fwd(s, alpha); break;
190             case eltwise_tanh: d = tanh_fwd(s); break;
191             case eltwise_elu: d = elu_fwd(s, alpha); break;
192             case eltwise_square: d = square_fwd(s); break;
193             case eltwise_abs: d = abs_fwd(s); break;
194             case eltwise_sqrt: d = sqrt_fwd(s); break;
195             case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
196             case eltwise_bounded_relu:
197                 d = bounded_relu_fwd(s, alpha); break;
198             case eltwise_soft_relu: d = soft_relu_fwd(s); break;
199             case eltwise_logistic: d = logistic_fwd(s); break;
200             case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
201             case eltwise_exp: d = exp_fwd(s); break;
202             case eltwise_not: d = not_fwd(s); break;
203             default: assert(!"unknown eltwise alg_kind");
204         }
205     });
206 }
207
208 template <>
209 void ref_eltwise_fwd_t<data_type::bf16>::execute_forward_generic() const {
210     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
211     auto dst = reinterpret_cast<data_t*>(this->memory(0));
212
213     /* fast return */
214     if (pd()->has_zero_dim_memory()) return;
215
216     const memory_desc_wrapper data_d(pd()->src_pd());
217
218     const int MB = pd()->MB();
219     const int C = pd()->C();
220     const int D = pd()->D();
221     const int H = pd()->H();
222     const int W = pd()->W();
223     const auto alg_kind = pd()->desc()->alg_kind;
224     const float alpha = pd()->desc()->alpha;
225     const float beta = pd()->desc()->beta;
226     const bool is_3d = pd()->desc()->data_desc.ndims == 5;
227
228     parallel_nd(MB, C, D, H, W,
229         [&](int n, int c, int id, int h, int w) {
230         auto d_off = is_3d
231             ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w);
232         data_t s = src[d_off];
233         data_t &d = dst[d_off];
234         float s_ = 0.0f, d_ = 0.0f;
235         bf16_cvt_utils::cvt_bfloat16_to_float(&s_, &s);
236         switch (alg_kind) {
237             case eltwise_relu: d_ = relu_fwd(s_, alpha); break;
238             case eltwise_tanh: d_ = tanh_fwd(s_); break;
239             case eltwise_elu: d_ = elu_fwd(s_, alpha); break;
240             case eltwise_square: d_ = square_fwd(s_); break;
241             case eltwise_abs: d_ = abs_fwd(s_); break;
242             case eltwise_sqrt: d_ = sqrt_fwd(s_); break;
243             case eltwise_linear: d_ = linear_fwd(s_, alpha, beta); break;
244             case eltwise_bounded_relu:
245                 d_ = bounded_relu_fwd(s_, alpha); break;
246             case eltwise_soft_relu: d_ = soft_relu_fwd(s_); break;
247             case eltwise_logistic: d_ = logistic_fwd(s_); break;
248             default: assert(!"unknown eltwise alg_kind");
249         }
250         bf16_cvt_utils::cvt_float_to_bfloat16(&d, &d_);
251     });
252 }
253
254 template <impl::data_type_t data_type>
255 void ref_eltwise_fwd_t<data_type>::execute_forward_dense() const {
256     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
257     auto dst = reinterpret_cast<data_t*>(this->memory(0));
258
259     const memory_desc_wrapper data_d(pd()->src_pd());
260
261     const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
262     const auto alg_kind = pd()->desc()->alg_kind;
263     const float alpha = pd()->desc()->alpha;
264     const float beta  = pd()->desc()->beta;
265
266     src += data_d.blocking_desc().offset_padding;
267     dst += data_d.blocking_desc().offset_padding;
268
269     if (alg_kind == eltwise_relu) {
270         // a fast path for relu as the most popular activation
271         parallel_nd(nelems, [&](ptrdiff_t e) {
272             dst[e] = relu_fwd(src[e], alpha);
273         });
274         return;
275     }
276
277     parallel_nd(nelems, [&](ptrdiff_t e) {
278         const data_t s = src[e];
279         data_t &d = dst[e];
280
281         switch (alg_kind) {
282         case eltwise_tanh: d = tanh_fwd(s); break;
283         case eltwise_elu: d = elu_fwd(s, alpha); break;
284         case eltwise_square: d = square_fwd(s); break;
285         case eltwise_abs: d = abs_fwd(s); break;
286         case eltwise_sqrt: d = sqrt_fwd(s); break;
287         case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
288         case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break;
289         case eltwise_soft_relu: d = soft_relu_fwd(s); break;
290         case eltwise_logistic: d = logistic_fwd(s); break;
291         case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
292         case eltwise_exp: d = exp_fwd(s); break;
293         case eltwise_not: d = not_fwd(s); break;
294         default: assert(!"unknown eltwise alg_kind");
295         }
296     });
297 }
298
299 template <>
300 void ref_eltwise_fwd_t<data_type::bf16>::execute_forward_dense() const {
301     auto src = reinterpret_cast<const mkldnn_bfloat16_t *>(this->input_memory(0));
302     auto dst = reinterpret_cast<mkldnn_bfloat16_t *>(this->memory(0));
303
304     const memory_desc_wrapper data_d(pd()->src_pd());
305
306     const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
307     const auto alg_kind = pd()->desc()->alg_kind;
308     const float alpha = pd()->desc()->alpha;
309     const float beta  = pd()->desc()->beta;
310
311     src += data_d.blocking_desc().offset_padding;
312     dst += data_d.blocking_desc().offset_padding;
313
314     if (alg_kind == eltwise_relu) {
315         // a fast path for relu as the most popular activation
316         parallel_nd(nelems, [&](ptrdiff_t e) {
317             float s_ = 0.0f;
318             bf16_cvt_utils::cvt_bfloat16_to_float(&s_, &src[e]);
319             float d_ = relu_fwd(s_, alpha);
320             bf16_cvt_utils::cvt_float_to_bfloat16(&dst[e], &d_);
321         });
322         return;
323     }
324
325     parallel_nd(nelems, [&](ptrdiff_t e) {
326         float s_ = 0.0f, d_ = 0.0f;
327         bf16_cvt_utils::cvt_bfloat16_to_float(&s_, &src[e]);
328         switch (alg_kind) {
329         case eltwise_tanh: d_ = tanh_fwd(s_); break;
330         case eltwise_elu: d_ = elu_fwd(s_, alpha); break;
331         case eltwise_square: d_ = square_fwd(s_); break;
332         case eltwise_abs: d_ = abs_fwd(s_); break;
333         case eltwise_sqrt: d_ = sqrt_fwd(s_); break;
334         case eltwise_linear: d_ = linear_fwd(s_, alpha, beta); break;
335         case eltwise_bounded_relu: d_ = bounded_relu_fwd(s_, alpha); break;
336         case eltwise_soft_relu: d_ = soft_relu_fwd(s_); break;
337         case eltwise_logistic: d_ = logistic_fwd(s_); break;
338         default: assert(!"unknown eltwise alg_kind");
339         }
340         bf16_cvt_utils::cvt_float_to_bfloat16(&dst[e], &d_);
341     });
342 }
343
344 template <impl::data_type_t data_type>
345 void ref_eltwise_bwd_t<data_type>::execute_backward_generic() const {
346     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
347     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
348     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
349
350     /* fast return */
351     if (pd()->has_zero_dim_memory()) return;
352
353     const memory_desc_wrapper data_d(pd()->src_pd());
354     const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
355
356     const int MB = pd()->MB();
357     const int C = pd()->C();
358     const int D = pd()->D();
359     const int H = pd()->H();
360     const int W = pd()->W();
361     const auto alg_kind = pd()->desc()->alg_kind;
362     const float alpha = pd()->desc()->alpha;
363     const float beta = pd()->desc()->beta;
364     const bool is_3d = pd()->desc()->data_desc.ndims == 5;
365
366     parallel_nd(MB, C, D, H, W,
367         [&](int n, int c, int d, int h, int w) {
368         auto data_off = is_3d
369             ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w);
370         auto diff_data_off = is_3d
371             ? diff_data_d.off(n, c, d, h, w)
372             : diff_data_d.off(n, c, h, w);
373         data_t s = src[data_off];
374         data_t dd = diff_dst[diff_data_off];
375         data_t &ds = diff_src[diff_data_off];
376         switch (alg_kind) {
377             case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
378             case eltwise_tanh: ds = tanh_bwd(dd, s); break;
379             case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
380             case eltwise_square: ds = square_bwd(dd, s); break;
381             case eltwise_abs: ds = abs_bwd(dd, s); break;
382             case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
383             case eltwise_linear:
384                 ds = linear_bwd(dd, s, alpha, beta); break;
385             case eltwise_bounded_relu:
386                 ds = bounded_relu_bwd(dd, s, alpha); break;
387             case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
388             case eltwise_logistic: ds = logistic_bwd(dd, s); break;
389             case eltwise_clamp: ds = clamp_bwd(dd, s, alpha, beta); break;
390             case eltwise_exp: ds = exp_bwd(dd, s); break;
391             default: assert(!"unknown eltwise alg_kind");
392         }
393     });
394 }
395
396 template <>
397 void ref_eltwise_bwd_t<data_type::bf16>::execute_backward_generic() const {
398     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
399     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
400     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
401
402     /* fast return */
403     if (pd()->has_zero_dim_memory()) return;
404
405     const memory_desc_wrapper data_d(pd()->src_pd());
406     const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
407
408     const int MB = pd()->MB();
409     const int C = pd()->C();
410     const int D = pd()->D();
411     const int H = pd()->H();
412     const int W = pd()->W();
413     const auto alg_kind = pd()->desc()->alg_kind;
414     const float alpha = pd()->desc()->alpha;
415     const float beta = pd()->desc()->beta;
416     const bool is_3d = pd()->desc()->data_desc.ndims == 5;
417
418     parallel_nd(MB, C, D, H, W,
419         [&](int n, int c, int d, int h, int w) {
420         auto data_off = is_3d
421             ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w);
422         auto diff_data_off = is_3d
423             ? diff_data_d.off(n, c, d, h, w)
424             : diff_data_d.off(n, c, h, w);
425
426         float dd_ = 0.0f, s_ = 0.0f, ds_ = 0.0f;
427         bf16_cvt_utils::cvt_bfloat16_to_float(&dd_, &diff_dst[diff_data_off]);
428         bf16_cvt_utils::cvt_bfloat16_to_float(&s_, &src[data_off]);
429         switch (alg_kind) {
430             case eltwise_relu: ds_ = relu_bwd(dd_, s_, alpha); break;
431             case eltwise_tanh: ds_ = tanh_bwd(dd_, s_); break;
432             case eltwise_elu: ds_ = elu_bwd(dd_, s_, alpha); break;
433             case eltwise_square: ds_ = square_bwd(dd_, s_); break;
434             case eltwise_abs: ds_ = abs_bwd(dd_, s_); break;
435             case eltwise_sqrt: ds_ = sqrt_bwd(dd_, s_); break;
436             case eltwise_linear:
437                 ds_ = linear_bwd(dd_, s_, alpha, beta); break;
438             case eltwise_bounded_relu:
439                 ds_ = bounded_relu_bwd(dd_, s_, alpha); break;
440             case eltwise_soft_relu: ds_ = soft_relu_bwd(dd_, s_); break;
441             case eltwise_logistic: ds_ = logistic_bwd(dd_, s_); break;
442             default: assert(!"unknown eltwise alg_kind");
443         }
444         bf16_cvt_utils::cvt_float_to_bfloat16(&diff_src[diff_data_off], &ds_);
445     });
446 }
447
448 template <impl::data_type_t data_type>
449 void ref_eltwise_bwd_t<data_type>::execute_backward_dense() const {
450     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
451     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
452     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
453
454     const memory_desc_wrapper data_d(pd()->src_pd());
455     const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
456
457     const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
458     const auto alg_kind = pd()->desc()->alg_kind;
459     const float alpha = pd()->desc()->alpha;
460     const float beta = pd()->desc()->beta;
461
462     src += data_d.blocking_desc().offset_padding;
463     diff_dst += diff_data_d.blocking_desc().offset_padding;
464     diff_src += diff_data_d.blocking_desc().offset_padding;
465
466     parallel_nd(nelems, [&](ptrdiff_t e) {
467         const data_t dd = diff_dst[e];
468         const data_t s = src[e];
469         data_t &ds = diff_src[e];
470
471         switch (alg_kind) {
472         case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
473         case eltwise_tanh: ds = tanh_bwd(dd, s); break;
474         case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
475         case eltwise_square: ds = square_bwd(dd, s); break;
476         case eltwise_abs: ds = abs_bwd(dd, s); break;
477         case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
478         case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break;
479         case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break;
480         case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
481         case eltwise_logistic: ds = logistic_bwd(dd, s); break;
482         case eltwise_clamp: ds = clamp_bwd(dd, s, alpha, beta); break;
483         case eltwise_exp: ds = exp_bwd(dd, s); break;
484         default: assert(!"unknown eltwise alg_kind");
485         }
486     });
487 }
488
489 template <>
490 void ref_eltwise_bwd_t<data_type::bf16>::execute_backward_dense() const {
491     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
492     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
493     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
494
495     const memory_desc_wrapper data_d(pd()->src_pd());
496     const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
497
498     const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
499     const auto alg_kind = pd()->desc()->alg_kind;
500     const float alpha = pd()->desc()->alpha;
501     const float beta = pd()->desc()->beta;
502
503     src += data_d.blocking_desc().offset_padding;
504     diff_dst += diff_data_d.blocking_desc().offset_padding;
505     diff_src += diff_data_d.blocking_desc().offset_padding;
506
507     parallel_nd(nelems, [&](ptrdiff_t e) {
508         float dd_ = 0.0f, s_ = 0.0f, ds_ = 0.0f;
509         bf16_cvt_utils::cvt_bfloat16_to_float(&dd_, &diff_dst[e]);
510         bf16_cvt_utils::cvt_bfloat16_to_float(&s_, &src[e]);
511
512         switch (alg_kind) {
513         case eltwise_relu: ds_ = relu_bwd(dd_, s_, alpha); break;
514         case eltwise_tanh: ds_ = tanh_bwd(dd_, s_); break;
515         case eltwise_elu: ds_ = elu_bwd(dd_, s_, alpha); break;
516         case eltwise_square: ds_ = square_bwd(dd_, s_); break;
517         case eltwise_abs: ds_ = abs_bwd(dd_, s_); break;
518         case eltwise_sqrt: ds_ = sqrt_bwd(dd_, s_); break;
519         case eltwise_linear: ds_ = linear_bwd(dd_, s_, alpha, beta); break;
520         case eltwise_bounded_relu: ds_ = bounded_relu_bwd(dd_, s_, alpha); break;
521         case eltwise_soft_relu: ds_ = soft_relu_bwd(dd_, s_); break;
522         case eltwise_logistic: ds_ = logistic_bwd(dd_, s_); break;
523         default: assert(!"unknown eltwise alg_kind");
524         }
525         bf16_cvt_utils::cvt_float_to_bfloat16(&diff_src[e], &ds_);
526     });
527 }
528
529 template struct ref_eltwise_fwd_t<data_type::f32>;
530 template struct ref_eltwise_fwd_t<data_type::bf16>;
531 template struct ref_eltwise_fwd_t<data_type::s32>;
532 template struct ref_eltwise_fwd_t<data_type::s16>;
533 template struct ref_eltwise_fwd_t<data_type::s8>;
534 template struct ref_eltwise_fwd_t<data_type::u8>;
535
536 template struct ref_eltwise_bwd_t<data_type::f32>;
537 template struct ref_eltwise_bwd_t<data_type::bf16>;
538 template struct ref_eltwise_bwd_t<data_type::s32>;
539 template struct ref_eltwise_bwd_t<data_type::s16>;
540
541 }
542 }
543 }
544
545 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s