1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "math_utils.hpp"
22 #include "mkldnn_thread.hpp"
24 #include "ref_eltwise.hpp"
30 using namespace alg_kind;
/* ReLU forward with a negative slope: positive inputs are returned as is,
 * every other input is multiplied by alpha and converted back to T. */
template <typename T, typename A>
inline T relu_fwd(T s, A alpha) {
    if (s > 0)
        return s;
    return static_cast<T>(s * alpha);
}
/* ReLU backward: the incoming gradient dd is passed through where the
 * forward input s was positive, and scaled by alpha everywhere else. */
template <typename T, typename A>
inline T relu_bwd(T dd, T s, A alpha) {
    if (s > 0)
        return dd;
    return static_cast<T>(dd * alpha);
}
/* Hyperbolic tangent forward, evaluated in float and converted back to T.
 *
 * Uses tanhf() rather than (exp(2s) - 1) / (exp(2s) + 1): the exponential
 * form overflows to inf for large s and then yields inf / inf = NaN instead
 * of saturating at 1. It also avoids computing 2 * s in T, which can
 * overflow for integer element types. */
template <typename T> T tanh_fwd(T s) {
    return static_cast<T>(::tanhf(static_cast<float>(s)));
}
/* Hyperbolic tangent backward: dd * (1 - tanh(s)^2).
 *
 * tanh(s) comes from tanhf() so that large |s| saturates cleanly (gradient
 * goes to 0); the exp(2s)-based quotient turns into NaN once exp overflows. */
template <typename T> T tanh_bwd(T dd, T s) {
    const float th = ::tanhf(static_cast<float>(s));
    return static_cast<T>(dd * (1 - th * th));
}
/* ELU forward: identity for positive inputs, alpha * (exp(s) - 1) otherwise.
 * The exponential is evaluated in float. */
template <typename T, typename A> T elu_fwd(T s, A alpha) {
    if (s > 0)
        return s;
    const float exp_s = ::expf(static_cast<float>(s));
    return static_cast<T>(alpha * (exp_s - 1.f));
}
/* ELU backward: dd scaled by the local slope, which is 1 for positive s and
 * alpha * exp(s) otherwise. */
template <typename T, typename A> T elu_bwd(T dd, T s, A alpha) {
    // `auto` keeps the conditional's common arithmetic type, so the product
    // below is evaluated exactly as in a single combined expression.
    const auto slope = s > 0 ? 1 : alpha * ::expf(static_cast<float>(s));
    return static_cast<T>(dd * slope);
}
63 T square_bwd(T dd, T s) {
69 return s > 0 ? s : -s;
/* Absolute-value backward: dd multiplied by sign(s), i.e. dd for positive s,
 * -dd for negative s and 0 when s is neither. */
template <typename T>
T abs_bwd(T dd, T s) {
    if (s > 0)
        return dd;
    if (s < 0)
        return -dd;
    return 0;
}
79 return s > 0 ? (T)(::sqrtf((float)(s))) : 0;
83 T sqrt_bwd(T dd, T s) {
85 ? (T)(dd / (2 * ::sqrtf((float)(s))))
/* Linear forward: alpha * s + beta, converted back to T. */
template <typename T, typename A>
T linear_fwd(T s, A alpha, A beta) {
    const auto affine = alpha * s + beta;
    return static_cast<T>(affine);
}
/* Linear backward: the gradient of alpha * s + beta with respect to s is
 * alpha, so the result is dd * alpha. s and beta do not affect it. */
template <typename T, typename A>
T linear_bwd(T dd, T s, A alpha, A beta) {
    (void)s;
    (void)beta;
    const auto scaled = dd * alpha;
    return static_cast<T>(scaled);
}
/* Bounded ReLU forward: clamps s into [0, alpha].
 *
 * The lower clamp at 0 is required for agreement with bounded_relu_bwd(),
 * whose gradient is non-zero only for 0 < s < alpha; without it negative
 * inputs would pass through unchanged. */
template <typename T, typename A>
T bounded_relu_fwd(T s, A alpha) {
    const T floored = s > 0 ? s : 0;
    return floored > alpha ? static_cast<T>(alpha) : floored;
}
/* Bounded ReLU backward: dd is kept only where s lies strictly between 0 and
 * alpha; elsewhere it is multiplied by 0. */
template <typename T, typename A>
T bounded_relu_bwd(T dd, T s, A alpha) {
    const int pass_through = (0 < s && s < alpha) ? 1 : 0;
    return dd * pass_through;
}
/* Soft ReLU (softplus) forward: log(1 + exp(s)), evaluated in float.
 *
 * For inputs at or above roughly logf(FLT_MAX), expf() overflows to inf and
 * the formula would return inf; in that range log(1 + exp(s)) equals s to
 * float precision, so s is returned directly. */
template <typename T>
T soft_relu_fwd(T s) {
    const float max_logf = 8.872284e+01f; // ~ logf(FLT_MAX)
    const float x = static_cast<float>(s);
    if (x < max_logf)
        return static_cast<T>(::logf(1 + ::expf(x)));
    return s;
}
/* Soft ReLU backward: dd / (1 + exp(-s)), with the exponential evaluated in
 * float. */
template <typename T>
T soft_relu_bwd(T dd, T s) {
    const float exp_neg_s = ::expf(static_cast<float>(-s));
    return static_cast<T>(dd / (1 + exp_neg_s));
}
// Logistic forward helper. The visible statement evaluates exp(s) in float
// and converts the result back to T.
// NOTE(review): the statement that produces the return value is not visible
// in this view — confirm against the full file.
122 template <typename T>
123 T logistic_fwd(T s) {
124 T v = (T)(::expf((float)s));
/* Logistic backward: with v = exp(-s) (evaluated in float, stored as T),
 * returns dd * v / (v + 1)^2. */
template <typename T>
T logistic_bwd(T dd, T s) {
    const T v = static_cast<T>(::expf(static_cast<float>(-s)));
    const auto denom = (v + 1) * (v + 1);
    return dd * v / denom;
}
/* Clamp forward: alpha is the upper bound and beta the lower bound. Values
 * above alpha become alpha, values below beta become beta, anything else is
 * returned unchanged. The upper bound is tested first. */
template <typename T, typename A>
T clamp_fwd(T s, A alpha, A beta) {
    if (s > alpha)
        return static_cast<T>(alpha);
    if (s < beta)
        return static_cast<T>(beta);
    return s;
}
/* Clamp backward: dd is kept only where s lies strictly between beta (lower
 * bound) and alpha (upper bound); elsewhere it is multiplied by 0. */
template <typename T, typename A>
T clamp_bwd(T dd, T s, A alpha, A beta) {
    const int pass_through = (beta < s && s < alpha) ? 1 : 0;
    return dd * pass_through;
}
144 ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(const alg_kind_t alg_, const float alpha_, const float beta_)
145 : alg(alg_), alpha(alpha_), beta(beta_) {
146 using namespace alg_kind;
148 assert(utils::one_of(alg, eltwise_tanh, eltwise_elu,
149 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
150 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic, eltwise_clamp));
// Evaluates the configured eltwise operation on a single float value.
// Every visible case returns the result of the matching *_fwd helper, passing
// the stored alpha (and beta) to the helpers that take them.
// NOTE(review): the statement that opens the dispatch and anything following
// the default label are not visible in this view — confirm against the full
// file.
153 float ref_eltwise_scalar_fwd_t::compute_scalar(float s) {
155 case eltwise_relu: return relu_fwd(s, alpha);
156 case eltwise_tanh: return tanh_fwd(s);
157 case eltwise_elu: return elu_fwd(s, alpha);
158 case eltwise_square: return square_fwd(s);
159 case eltwise_abs: return abs_fwd(s);
160 case eltwise_sqrt: return sqrt_fwd(s);
161 case eltwise_linear: return linear_fwd(s, alpha, beta);
162 case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha);
163 case eltwise_soft_relu: return soft_relu_fwd(s);
164 case eltwise_logistic: return logistic_fwd(s);
165 case eltwise_clamp: return clamp_fwd(s, alpha, beta);
// Any algorithm kind without a case above trips this assert.
166 default: assert(!"unknown eltwise alg_kind");
// Forward eltwise pass for data whose dimension 1 (channels) is split into
// blocks and padded: the channel loop runs over C_PADDED blocks taken from
// padding_dims[1], while C and tail describe the un-padded channel count.
// NOTE(review): several statements of this function (the opening of the
// kernel's dispatch and whatever selects between the two inner loops) are not
// visible in this view — confirm details against the full file.
172 template <impl::data_type_t data_type>
173 void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded() {
174 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
175 auto dst = reinterpret_cast<data_t*>(this->memory(0));
177 const memory_desc_wrapper data_d(conf_.src_pd());
178 const blocking_desc_t &blk = data_d.blocking_desc();
// Block size along dimension 1.
179 const int block = blk.block_dims[1];
181 const int MB = conf_.MB();
// C: number of complete channel blocks; C_PADDED: number of blocks in the
// padded dimension; tail: channels left over after the last complete block.
182 const int C = conf_.C() / block;
183 const int C_PADDED = blk.padding_dims[1] / block;
184 const int tail = conf_.C() % block;
// Product of the D, H and W extents, iterated as one flat index below.
185 const int SP = conf_.D() * conf_.H() * conf_.W();
186 const auto alg_kind = conf_.desc()->alg_kind;
187 const float alpha = conf_.desc()->alpha;
188 const float beta = conf_.desc()->beta;
// Per-element kernel: stores the selected operation's result for s into d.
// It captures by value, so it works on the alg_kind/alpha/beta copies above.
// Only the algorithm kinds listed here are handled; others hit the assert.
190 auto ker = [=] (data_t &d, data_t s) {
192 case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
193 case eltwise_bounded_relu:
194 d = bounded_relu_fwd(s, alpha); break;
195 case eltwise_soft_relu: d = soft_relu_fwd(s); break;
196 case eltwise_logistic: d = logistic_fwd(s); break;
197 case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
198 default: assert(!"unknown eltwise alg_kind");
// d_off below is built from int operands only, which is what the FIXME is
// about.
202 // FIXME: integer overflow?
204 # pragma omp parallel for collapse(3) schedule(static)
205 for (int n = 0; n < MB; ++n) {
206 for (int c = 0; c < C_PADDED; ++c) {
207 for (int sp = 0; sp < SP; ++sp) {
// Offset of the first element of block (n, c, sp); the same offset is used
// for both src and dst.
208 auto d_off = (n*C_PADDED*SP + c*SP + sp) * block;
// Applies the kernel to all `block` elements of the block.
210 for (int v = 0; v < block; v++)
211 ker(dst[d_off + v], src[d_off + v]);
// Applies the kernel to the first `tail` elements of the block only.
213 for (int v = 0; v < tail; v++)
214 ker(dst[d_off + v], src[d_off + v]);
// Forward eltwise pass that walks every logical (n, c, id, h, w) coordinate
// and obtains the element's offset from data_d.off(); the same offset is used
// to read src and to write dst.
// NOTE(review): some statements (the start of the offset computation and the
// opening of the dispatch) are not visible in this view — confirm against the
// full file.
221 template <impl::data_type_t data_type>
222 void ref_eltwise_fwd_t<data_type>::execute_forward_generic() {
223 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
224 auto dst = reinterpret_cast<data_t*>(this->memory(0));
226 const memory_desc_wrapper data_d(conf_.src_pd());
228 const int MB = conf_.MB();
229 const int C = conf_.C();
230 const int D = conf_.D();
231 const int H = conf_.H();
232 const int W = conf_.W();
233 const auto alg_kind = conf_.desc()->alg_kind;
234 const float alpha = conf_.desc()->alpha;
235 const float beta = conf_.desc()->beta;
// True when the data descriptor has 5 dimensions.
236 const bool is_3d = conf_.desc()->data_desc.ndims == 5;
// All five loops are collapsed into one parallel iteration space.
238 # pragma omp parallel for collapse(5) schedule(static)
239 for (int n = 0; n < MB; ++n) {
240 for (int c = 0; c < C; ++c) {
241 for (int id = 0; id < D; ++id)
242 for (int h = 0; h < H; ++h)
243 for (int w = 0; w < W; ++w) {
// Five-argument off() including the depth index, or four-argument off()
// without it; presumably selected by is_3d (the condition line is not
// visible here).
245 ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w);
246 data_t s = src[d_off];
247 data_t &d = dst[d_off];
// One case per supported algorithm kind; each writes its result into d.
249 case eltwise_relu: d = relu_fwd(s, alpha); break;
250 case eltwise_tanh: d = tanh_fwd(s); break;
251 case eltwise_elu: d = elu_fwd(s, alpha); break;
252 case eltwise_square: d = square_fwd(s); break;
253 case eltwise_abs: d = abs_fwd(s); break;
254 case eltwise_sqrt: d = sqrt_fwd(s); break;
255 case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
256 case eltwise_bounded_relu:
257 d = bounded_relu_fwd(s, alpha); break;
258 case eltwise_soft_relu: d = soft_relu_fwd(s); break;
259 case eltwise_logistic: d = logistic_fwd(s); break;
260 case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
261 default: assert(!"unknown eltwise alg_kind");
// Forward eltwise pass over a flat index range: after advancing src and dst
// by the descriptor's offset_padding, it processes data_d.nelems(true)
// consecutive elements.
// NOTE(review): some statements (the end of the relu branch, the declaration
// of d and the opening of the dispatch) are not visible in this view —
// confirm against the full file.
268 template <impl::data_type_t data_type>
269 void ref_eltwise_fwd_t<data_type>::execute_forward_dense() {
270 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
271 auto dst = reinterpret_cast<data_t*>(this->memory(0));
273 const memory_desc_wrapper data_d(conf_.src_pd());
// Element count, held in ptrdiff_t and used as the flat loop bound.
275 const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
276 const auto alg_kind = conf_.desc()->alg_kind;
277 const float alpha = conf_.desc()->alpha;
278 const float beta = conf_.desc()->beta;
// Both pointers are shifted by the same offset_padding from the source
// descriptor.
280 src += data_d.blocking_desc().offset_padding;
281 dst += data_d.blocking_desc().offset_padding;
283 if (alg_kind == eltwise_relu) {
284 // a fast path for relu as the most popular activation
285 # pragma omp parallel for schedule(static)
286 for (ptrdiff_t e = 0; e < nelems; ++e)
287 dst[e] = relu_fwd(src[e], alpha);
// Loop for the remaining algorithm kinds: one dispatch per element.
291 # pragma omp parallel for schedule(static)
292 for (ptrdiff_t e = 0; e < nelems; ++e) {
293 const data_t s = src[e];
297 case eltwise_tanh: d = tanh_fwd(s); break;
298 case eltwise_elu: d = elu_fwd(s, alpha); break;
299 case eltwise_square: d = square_fwd(s); break;
300 case eltwise_abs: d = abs_fwd(s); break;
301 case eltwise_sqrt: d = sqrt_fwd(s); break;
302 case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
303 case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break;
304 case eltwise_soft_relu: d = soft_relu_fwd(s); break;
305 case eltwise_logistic: d = logistic_fwd(s); break;
306 case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
307 default: assert(!"unknown eltwise alg_kind");
// Backward eltwise pass that walks every logical (n, c, d, h, w) coordinate.
// The forward input is read from src through data_d offsets; the incoming
// gradient diff_dst and the result diff_src both use diff_data_d offsets.
// NOTE(review): some statements (the opening of the dispatch and one case
// label) are not visible in this view — confirm against the full file.
312 template <impl::data_type_t data_type>
313 void ref_eltwise_bwd_t<data_type>::execute_backward_generic() {
314 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
315 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
316 auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
318 const memory_desc_wrapper data_d(conf_.src_pd());
319 const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
321 const int MB = conf_.MB();
322 const int C = conf_.C();
323 const int D = conf_.D();
324 const int H = conf_.H();
325 const int W = conf_.W();
326 const auto alg_kind = conf_.desc()->alg_kind;
327 const float alpha = conf_.desc()->alpha;
328 const float beta = conf_.desc()->beta;
// True when the data descriptor has 5 dimensions.
329 const bool is_3d = conf_.desc()->data_desc.ndims == 5;
// All five loops are collapsed into one parallel iteration space.
331 # pragma omp parallel for collapse(5) schedule(static)
332 for (int n = 0; n < MB; ++n) {
333 for (int c = 0; c < C; ++c) {
334 for (int d = 0; d < D; ++d)
335 for (int h = 0; h < H; ++h)
336 for (int w = 0; w < W; ++w) {
// 5-D data uses the five-argument off() (with the depth index), otherwise
// the four-argument form is used.
337 auto data_off = is_3d
338 ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w);
339 auto diff_data_off = is_3d
340 ? diff_data_d.off(n, c, d, h, w)
341 : diff_data_d.off(n, c, h, w);
// s: forward input, dd: incoming gradient, ds: gradient being written.
342 data_t s = src[data_off];
343 data_t dd = diff_dst[diff_data_off];
344 data_t &ds = diff_src[diff_data_off];
346 case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
347 case eltwise_tanh: ds = tanh_bwd(dd, s); break;
348 case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
349 case eltwise_square: ds = square_bwd(dd, s); break;
350 case eltwise_abs: ds = abs_bwd(dd, s); break;
351 case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
// NOTE(review): the case label for the next statement is not visible in this
// view.
353 ds = linear_bwd(dd, s, alpha, beta); break;
354 case eltwise_bounded_relu:
355 ds = bounded_relu_bwd(dd, s, alpha); break;
356 case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
357 case eltwise_logistic: ds = logistic_bwd(dd, s); break;
358 case eltwise_clamp: ds = clamp_bwd(dd, s, alpha, beta); break;
359 default: assert(!"unknown eltwise alg_kind");
// Backward eltwise pass over a flat index range of data_d.nelems(true)
// elements. src is advanced by the source descriptor's offset_padding;
// diff_dst and diff_src are advanced by the diff descriptor's.
// NOTE(review): the opening of the dispatch is not visible in this view —
// confirm against the full file.
366 template <impl::data_type_t data_type>
367 void ref_eltwise_bwd_t<data_type>::execute_backward_dense() {
368 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
369 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
370 auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
372 const memory_desc_wrapper data_d(conf_.src_pd());
373 const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
// Element count, held in ptrdiff_t and used as the flat loop bound.
375 const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
376 const auto alg_kind = conf_.desc()->alg_kind;
377 const float alpha = conf_.desc()->alpha;
378 const float beta = conf_.desc()->beta;
380 src += data_d.blocking_desc().offset_padding;
381 diff_dst += diff_data_d.blocking_desc().offset_padding;
382 diff_src += diff_data_d.blocking_desc().offset_padding;
384 # pragma omp parallel for schedule(static)
385 for (ptrdiff_t e = 0; e < nelems; ++e) {
// dd: incoming gradient, s: forward input, ds: gradient being written.
386 const data_t dd = diff_dst[e];
387 const data_t s = src[e];
388 data_t &ds = diff_src[e];
391 case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
392 case eltwise_tanh: ds = tanh_bwd(dd, s); break;
393 case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
394 case eltwise_square: ds = square_bwd(dd, s); break;
395 case eltwise_abs: ds = abs_bwd(dd, s); break;
396 case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
397 case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break;
398 case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break;
399 case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
400 case eltwise_logistic: ds = logistic_bwd(dd, s); break;
401 case eltwise_clamp: ds = clamp_bwd(dd, s, alpha, beta); break;
402 default: assert(!"unknown eltwise alg_kind");
// Explicit instantiations of the forward primitive for the f32, s32, s16, s8
// and u8 data types.
407 template struct ref_eltwise_fwd_t<data_type::f32>;
408 template struct ref_eltwise_fwd_t<data_type::s32>;
409 template struct ref_eltwise_fwd_t<data_type::s16>;
410 template struct ref_eltwise_fwd_t<data_type::s8>;
411 template struct ref_eltwise_fwd_t<data_type::u8>;
// Explicit instantiations of the backward primitive; only f32, s32 and s16
// are instantiated here.
413 template struct ref_eltwise_bwd_t<data_type::f32>;
414 template struct ref_eltwise_bwd_t<data_type::s32>;
415 template struct ref_eltwise_bwd_t<data_type::s16>;
421 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s