Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_eltwise.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "math_utils.hpp"
22 #include "mkldnn_thread.hpp"
23
24 #include "ref_eltwise.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace alg_kind;
31 using namespace math;
32
33 template <typename T, typename A> inline T relu_fwd(T s, A alpha) {
34     return s > 0 ? s : (T)(s * alpha);
35 }
36 template <typename T, typename A> inline T relu_bwd(T dd, T s, A alpha) {
37     return s > 0 ? dd : (T)(dd * alpha);
38 }
39
40 template <typename T> T tanh_fwd(T s) {
41     const float e = ::expf((float)(2 * s)); /* maybe replace with -2*s? */
42     return (T)((e - 1) / (e + 1));
43 }
44 template <typename T> T tanh_bwd(T dd, T s) {
45     const float e = ::expf((float)(2 * s)); /* maybe replace with -2*s? */
46     const float th = (e - 1.f) / (e + 1.f);
47     return (T)(dd * (1 - th * th));
48 }
49
50 template <typename T, typename A> T elu_fwd(T s, A alpha) {
51     return s > 0 ? s : (T)(alpha * (::expf((float)s) - 1.f));
52 }
53 template <typename T, typename A> T elu_bwd(T dd, T s, A alpha) {
54     return (T)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
55 }
56
57 template <typename T>
58 T square_fwd(T s) {
59     return s * s;
60 }
61
62 template <typename T>
63 T square_bwd(T dd, T s) {
64     return dd * 2*s;
65 }
66
67 template <typename T>
68 T abs_fwd(T s) {
69     return s > 0 ? s : -s;
70 }
71
72 template <typename T>
73 T abs_bwd(T dd, T s) {
74     return s > 0 ? dd : s < 0 ? -dd : 0;
75 }
76
77 template <typename T>
78 T sqrt_fwd(T s) {
79     return s > 0 ? (T)(::sqrtf((float)(s))) : 0;
80 }
81
82 template <typename T>
83 T sqrt_bwd(T dd, T s) {
84     return s > 0
85         ? (T)(dd / (2 * ::sqrtf((float)(s))))
86         : 0;
87 }
88
89 template <typename T, typename A>
90 T linear_fwd(T s, A alpha, A beta) {
91     return (T)(alpha * s + beta);
92 }
93
94 template <typename T, typename A>
95 T linear_bwd(T dd, T s, A alpha, A beta) {
96     (void) s;
97     (void) beta;
98     return (T)(dd * alpha);
99 }
100
101 template <typename T, typename A>
102 T bounded_relu_fwd(T s, A alpha) {
103     s = s > 0 ? s : 0;
104     return s > alpha ? (T)(alpha) : s;
105 }
106
107 template <typename T, typename A>
108 T bounded_relu_bwd(T dd, T s, A alpha) {
109     return dd * (0 < s && s < alpha ? 1 : 0);
110 }
111
112 template <typename T>
113 T soft_relu_fwd(T s) {
114     return (T)(::logf(1 + ::expf((float)s)));
115 }
116
117 template <typename T>
118 T soft_relu_bwd(T dd, T s) {
119     return (T)(dd / (1 + ::expf((float)(-s))));
120 }
121
122 template <typename T>
123 T logistic_fwd(T s) {
124     T v = (T)(::expf((float)s));
125     return v / (v + 1);
126 }
127
128 template <typename T>
129 T logistic_bwd(T dd, T s) {
130     T v = (T)(::expf((float)(-s)));
131     return dd * v / ((v + 1) * (v + 1));
132 }
133
134 template <typename T, typename A>
135 T clamp_fwd(T s, A alpha, A beta) {
136     return s > alpha ? (T)(alpha) : s < beta ? (T)(beta) : s;
137 }
138
139 template <typename T, typename A>
140 T clamp_bwd(T dd, T s, A alpha, A beta) {
141     return dd * (beta < s && s < alpha ? 1 : 0);
142 }
143
144 ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(const alg_kind_t alg_, const float alpha_, const float beta_)
145         : alg(alg_), alpha(alpha_), beta(beta_) {
146     using namespace alg_kind;
147
148     assert(utils::one_of(alg, eltwise_tanh, eltwise_elu,
149                          eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
150                          eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic, eltwise_clamp));
151 }
152
153 float ref_eltwise_scalar_fwd_t::compute_scalar(float s) {
154     switch (alg) {
155         case eltwise_relu:   return relu_fwd(s, alpha);
156         case eltwise_tanh:   return tanh_fwd(s);
157         case eltwise_elu:    return elu_fwd(s, alpha);
158         case eltwise_square: return square_fwd(s);
159         case eltwise_abs:    return abs_fwd(s);
160         case eltwise_sqrt:   return sqrt_fwd(s);
161         case eltwise_linear: return linear_fwd(s, alpha, beta);
162         case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha);
163         case eltwise_soft_relu: return soft_relu_fwd(s);
164         case eltwise_logistic: return logistic_fwd(s);
165         case eltwise_clamp: return clamp_fwd(s, alpha, beta);
166         default: assert(!"unknown eltwise alg_kind");
167     }
168
169     return 0.0f;
170 }
171
172 template <impl::data_type_t data_type>
173 void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded() {
174     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
175     auto dst = reinterpret_cast<data_t*>(this->memory(0));
176
177     const memory_desc_wrapper data_d(conf_.src_pd());
178     const blocking_desc_t &blk = data_d.blocking_desc();
179     const int block = blk.block_dims[1];
180
181     const int MB = conf_.MB();
182     const int C = conf_.C() / block;
183     const int C_PADDED = blk.padding_dims[1] / block;
184     const int tail = conf_.C() % block;
185     const int SP = conf_.D() * conf_.H() * conf_.W();
186     const auto alg_kind = conf_.desc()->alg_kind;
187     const float alpha = conf_.desc()->alpha;
188     const float beta = conf_.desc()->beta;
189
190     auto ker = [=] (data_t &d, data_t s) {
191         switch (alg_kind) {
192             case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
193             case eltwise_bounded_relu:
194                 d = bounded_relu_fwd(s, alpha); break;
195             case eltwise_soft_relu: d = soft_relu_fwd(s); break;
196             case eltwise_logistic: d = logistic_fwd(s); break;
197             case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
198             default: assert(!"unknown eltwise alg_kind");
199         }
200     };
201
202     // FIXME: integer overflow?
203
204 #   pragma omp parallel for collapse(3) schedule(static)
205     for (int n = 0; n < MB; ++n) {
206         for (int c = 0; c < C_PADDED; ++c) {
207             for (int sp = 0; sp < SP; ++sp) {
208                 auto d_off = (n*C_PADDED*SP + c*SP + sp) * block;
209                 if (c < C) {
210                     for (int v = 0; v < block; v++)
211                         ker(dst[d_off + v], src[d_off + v]);
212                 } else {
213                     for (int v = 0; v < tail; v++)
214                         ker(dst[d_off + v], src[d_off + v]);
215                 }
216             }
217         }
218     }
219 }
220
221 template <impl::data_type_t data_type>
222 void ref_eltwise_fwd_t<data_type>::execute_forward_generic() {
223     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
224     auto dst = reinterpret_cast<data_t*>(this->memory(0));
225
226     const memory_desc_wrapper data_d(conf_.src_pd());
227
228     const int MB = conf_.MB();
229     const int C = conf_.C();
230     const int D = conf_.D();
231     const int H = conf_.H();
232     const int W = conf_.W();
233     const auto alg_kind = conf_.desc()->alg_kind;
234     const float alpha = conf_.desc()->alpha;
235     const float beta = conf_.desc()->beta;
236     const bool is_3d = conf_.desc()->data_desc.ndims == 5;
237
238 #   pragma omp parallel for collapse(5) schedule(static)
239     for (int n = 0; n < MB; ++n) {
240         for (int c = 0; c < C; ++c) {
241             for (int id = 0; id < D; ++id)
242             for (int h = 0; h < H; ++h)
243             for (int w = 0; w < W; ++w) {
244                 auto d_off = is_3d
245                     ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w);
246                 data_t s = src[d_off];
247                 data_t &d = dst[d_off];
248                 switch (alg_kind) {
249                 case eltwise_relu: d = relu_fwd(s, alpha); break;
250                 case eltwise_tanh: d = tanh_fwd(s); break;
251                 case eltwise_elu: d = elu_fwd(s, alpha); break;
252                 case eltwise_square: d = square_fwd(s); break;
253                 case eltwise_abs: d = abs_fwd(s); break;
254                 case eltwise_sqrt: d = sqrt_fwd(s); break;
255                 case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
256                 case eltwise_bounded_relu:
257                     d = bounded_relu_fwd(s, alpha); break;
258                 case eltwise_soft_relu: d = soft_relu_fwd(s); break;
259                 case eltwise_logistic: d = logistic_fwd(s); break;
260                 case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
261                 default: assert(!"unknown eltwise alg_kind");
262                 }
263             }
264         }
265     }
266 }
267
268 template <impl::data_type_t data_type>
269 void ref_eltwise_fwd_t<data_type>::execute_forward_dense() {
270     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
271     auto dst = reinterpret_cast<data_t*>(this->memory(0));
272
273     const memory_desc_wrapper data_d(conf_.src_pd());
274
275     const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
276     const auto alg_kind = conf_.desc()->alg_kind;
277     const float alpha = conf_.desc()->alpha;
278     const float beta  = conf_.desc()->beta;
279
280     src += data_d.blocking_desc().offset_padding;
281     dst += data_d.blocking_desc().offset_padding;
282
283     if (alg_kind == eltwise_relu) {
284         // a fast path for relu as the most popular activation
285 #       pragma omp parallel for schedule(static)
286         for (ptrdiff_t e = 0; e < nelems; ++e)
287             dst[e] = relu_fwd(src[e], alpha);
288         return;
289     }
290
291 #   pragma omp parallel for schedule(static)
292     for (ptrdiff_t e = 0; e < nelems; ++e) {
293         const data_t s = src[e];
294         data_t &d = dst[e];
295
296         switch (alg_kind) {
297         case eltwise_tanh: d = tanh_fwd(s); break;
298         case eltwise_elu: d = elu_fwd(s, alpha); break;
299         case eltwise_square: d = square_fwd(s); break;
300         case eltwise_abs: d = abs_fwd(s); break;
301         case eltwise_sqrt: d = sqrt_fwd(s); break;
302         case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
303         case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break;
304         case eltwise_soft_relu: d = soft_relu_fwd(s); break;
305         case eltwise_logistic: d = logistic_fwd(s); break;
306         case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
307         default: assert(!"unknown eltwise alg_kind");
308         }
309     }
310 }
311
312 template <impl::data_type_t data_type>
313 void ref_eltwise_bwd_t<data_type>::execute_backward_generic() {
314     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
315     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
316     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
317
318     const memory_desc_wrapper data_d(conf_.src_pd());
319     const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
320
321     const int MB = conf_.MB();
322     const int C = conf_.C();
323     const int D = conf_.D();
324     const int H = conf_.H();
325     const int W = conf_.W();
326     const auto alg_kind = conf_.desc()->alg_kind;
327     const float alpha = conf_.desc()->alpha;
328     const float beta = conf_.desc()->beta;
329     const bool is_3d = conf_.desc()->data_desc.ndims == 5;
330
331 #   pragma omp parallel for collapse(5) schedule(static)
332     for (int n = 0; n < MB; ++n) {
333         for (int c = 0; c < C; ++c) {
334             for (int d = 0; d < D; ++d)
335             for (int h = 0; h < H; ++h)
336             for (int w = 0; w < W; ++w) {
337                 auto data_off = is_3d
338                     ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w);
339                 auto diff_data_off = is_3d
340                     ? diff_data_d.off(n, c, d, h, w)
341                     : diff_data_d.off(n, c, h, w);
342                 data_t s = src[data_off];
343                 data_t dd = diff_dst[diff_data_off];
344                 data_t &ds = diff_src[diff_data_off];
345                 switch (alg_kind) {
346                 case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
347                 case eltwise_tanh: ds = tanh_bwd(dd, s); break;
348                 case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
349                 case eltwise_square: ds = square_bwd(dd, s); break;
350                 case eltwise_abs: ds = abs_bwd(dd, s); break;
351                 case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
352                 case eltwise_linear:
353                     ds = linear_bwd(dd, s, alpha, beta); break;
354                 case eltwise_bounded_relu:
355                     ds = bounded_relu_bwd(dd, s, alpha); break;
356                 case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
357                 case eltwise_logistic: ds = logistic_bwd(dd, s); break;
358                 case eltwise_clamp: ds = clamp_bwd(dd, s, alpha, beta); break;
359                 default: assert(!"unknown eltwise alg_kind");
360                 }
361             }
362         }
363     }
364 }
365
366 template <impl::data_type_t data_type>
367 void ref_eltwise_bwd_t<data_type>::execute_backward_dense() {
368     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
369     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
370     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
371
372     const memory_desc_wrapper data_d(conf_.src_pd());
373     const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
374
375     const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
376     const auto alg_kind = conf_.desc()->alg_kind;
377     const float alpha = conf_.desc()->alpha;
378     const float beta = conf_.desc()->beta;
379
380     src += data_d.blocking_desc().offset_padding;
381     diff_dst += diff_data_d.blocking_desc().offset_padding;
382     diff_src += diff_data_d.blocking_desc().offset_padding;
383
384 #   pragma omp parallel for schedule(static)
385     for (ptrdiff_t e = 0; e < nelems; ++e) {
386         const data_t dd = diff_dst[e];
387         const data_t s = src[e];
388         data_t &ds = diff_src[e];
389
390         switch (alg_kind) {
391         case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
392         case eltwise_tanh: ds = tanh_bwd(dd, s); break;
393         case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
394         case eltwise_square: ds = square_bwd(dd, s); break;
395         case eltwise_abs: ds = abs_bwd(dd, s); break;
396         case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
397         case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break;
398         case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break;
399         case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
400         case eltwise_logistic: ds = logistic_bwd(dd, s); break;
401         case eltwise_clamp: ds = clamp_bwd(dd, s, alpha, beta); break;
402         default: assert(!"unknown eltwise alg_kind");
403         }
404     }
405 }
406
407 template struct ref_eltwise_fwd_t<data_type::f32>;
408 template struct ref_eltwise_fwd_t<data_type::s32>;
409 template struct ref_eltwise_fwd_t<data_type::s16>;
410 template struct ref_eltwise_fwd_t<data_type::s8>;
411 template struct ref_eltwise_fwd_t<data_type::u8>;
412
413 template struct ref_eltwise_bwd_t<data_type::f32>;
414 template struct ref_eltwise_bwd_t<data_type::s32>;
415 template struct ref_eltwise_bwd_t<data_type::s16>;
416
417 }
418 }
419 }
420
421 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s