1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "math_utils.hpp"
22 #include "mkldnn_thread.hpp"
24 #include "ref_eltwise.hpp"
30 using namespace alg_kind;
33 ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha,
34 float beta): alg_(alg), alpha_(alpha), beta_(beta) {
35 assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
36 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
37 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
38 eltwise_clamp, eltwise_exp, eltwise_not));
41 ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
42 const post_ops_t::entry_t::eltwise_t &eltwise)
43 : ref_eltwise_scalar_fwd_t(eltwise.alg, eltwise.alpha, eltwise.beta) {}
45 float ref_eltwise_scalar_fwd_t::compute_scalar(float s) {
47 case eltwise_relu: return relu_fwd(s, alpha_);
48 case eltwise_tanh: return tanh_fwd(s);
49 case eltwise_elu: return elu_fwd(s, alpha_);
50 case eltwise_square: return square_fwd(s);
51 case eltwise_abs: return abs_fwd(s);
52 case eltwise_sqrt: return sqrt_fwd(s);
53 case eltwise_linear: return linear_fwd(s, alpha_, beta_);
54 case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha_);
55 case eltwise_soft_relu: return soft_relu_fwd(s);
56 case eltwise_logistic: return logistic_fwd(s);
57 case eltwise_clamp: return clamp_fwd(s, alpha_, beta_);
58 case eltwise_exp: return exp_fwd(s);
59 case eltwise_not: return not_fwd(s);
60 default: assert(!"unknown eltwise alg_kind");
66 template <impl::data_type_t data_type>
67 void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded() const {
68 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
69 auto dst = reinterpret_cast<data_t*>(this->memory(0));
71 const memory_desc_wrapper data_d(pd()->src_pd());
72 const blocking_desc_t &blk = data_d.blocking_desc();
73 const int block = blk.block_dims[1];
75 const int MB = pd()->MB();
76 const int C = pd()->C() / block;
77 const int C_PADDED = blk.padding_dims[1] / block;
78 const int tail = pd()->C() % block;
79 const int SP = pd()->D() * pd()->H() * pd()->W();
80 const auto alg_kind = pd()->desc()->alg_kind;
81 const float alpha = pd()->desc()->alpha;
82 const float beta = pd()->desc()->beta;
84 auto ker = [=] (data_t &d, data_t s) {
86 case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
87 case eltwise_bounded_relu:
88 d = bounded_relu_fwd(s, alpha); break;
89 case eltwise_soft_relu: d = soft_relu_fwd(s); break;
90 case eltwise_logistic: d = logistic_fwd(s); break;
91 case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
92 case eltwise_exp: d = exp_fwd(s); break;
93 case eltwise_not: d = not_fwd(s); break;
94 default: assert(!"unknown eltwise alg_kind");
98 // FIXME: integer overflow?
100 parallel_nd(MB, C_PADDED, SP,
101 [&](int n, int c, int sp) {
102 auto d_off = (n*C_PADDED*SP + c*SP + sp) * block;
104 for (int v = 0; v < block; v++)
105 ker(dst[d_off + v], src[d_off + v]);
107 for (int v = 0; v < tail; v++)
108 ker(dst[d_off + v], src[d_off + v]);
113 template <impl::data_type_t data_type>
114 void ref_eltwise_fwd_t<data_type>::execute_forward_generic() const {
115 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
116 auto dst = reinterpret_cast<data_t*>(this->memory(0));
119 if (pd()->has_zero_dim_memory()) return;
121 const memory_desc_wrapper data_d(pd()->src_pd());
123 const int MB = pd()->MB();
124 const int C = pd()->C();
125 const int D = pd()->D();
126 const int H = pd()->H();
127 const int W = pd()->W();
128 const auto alg_kind = pd()->desc()->alg_kind;
129 const float alpha = pd()->desc()->alpha;
130 const float beta = pd()->desc()->beta;
131 const bool is_3d = pd()->desc()->data_desc.ndims == 5;
133 parallel_nd(MB, C, D, H, W,
134 [&](int n, int c, int id, int h, int w) {
136 ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w);
137 data_t s = src[d_off];
138 data_t &d = dst[d_off];
140 case eltwise_relu: d = relu_fwd(s, alpha); break;
141 case eltwise_tanh: d = tanh_fwd(s); break;
142 case eltwise_elu: d = elu_fwd(s, alpha); break;
143 case eltwise_square: d = square_fwd(s); break;
144 case eltwise_abs: d = abs_fwd(s); break;
145 case eltwise_sqrt: d = sqrt_fwd(s); break;
146 case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
147 case eltwise_bounded_relu:
148 d = bounded_relu_fwd(s, alpha); break;
149 case eltwise_soft_relu: d = soft_relu_fwd(s); break;
150 case eltwise_logistic: d = logistic_fwd(s); break;
151 case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
152 case eltwise_exp: d = exp_fwd(s); break;
153 case eltwise_not: d = not_fwd(s); break;
154 default: assert(!"unknown eltwise alg_kind");
159 template <impl::data_type_t data_type>
160 void ref_eltwise_fwd_t<data_type>::execute_forward_dense() const {
161 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
162 auto dst = reinterpret_cast<data_t*>(this->memory(0));
164 const memory_desc_wrapper data_d(pd()->src_pd());
166 const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
167 const auto alg_kind = pd()->desc()->alg_kind;
168 const float alpha = pd()->desc()->alpha;
169 const float beta = pd()->desc()->beta;
171 src += data_d.blocking_desc().offset_padding;
172 dst += data_d.blocking_desc().offset_padding;
174 if (alg_kind == eltwise_relu) {
175 // a fast path for relu as the most popular activation
176 parallel_nd(nelems, [&](ptrdiff_t e) {
177 dst[e] = relu_fwd(src[e], alpha);
182 parallel_nd(nelems, [&](ptrdiff_t e) {
183 const data_t s = src[e];
187 case eltwise_tanh: d = tanh_fwd(s); break;
188 case eltwise_elu: d = elu_fwd(s, alpha); break;
189 case eltwise_square: d = square_fwd(s); break;
190 case eltwise_abs: d = abs_fwd(s); break;
191 case eltwise_sqrt: d = sqrt_fwd(s); break;
192 case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
193 case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break;
194 case eltwise_soft_relu: d = soft_relu_fwd(s); break;
195 case eltwise_logistic: d = logistic_fwd(s); break;
196 case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
197 case eltwise_exp: d = exp_fwd(s); break;
198 case eltwise_not: d = not_fwd(s); break;
199 default: assert(!"unknown eltwise alg_kind");
204 template <impl::data_type_t data_type>
205 void ref_eltwise_bwd_t<data_type>::execute_backward_generic() const {
206 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
207 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
208 auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
211 if (pd()->has_zero_dim_memory()) return;
213 const memory_desc_wrapper data_d(pd()->src_pd());
214 const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
216 const int MB = pd()->MB();
217 const int C = pd()->C();
218 const int D = pd()->D();
219 const int H = pd()->H();
220 const int W = pd()->W();
221 const auto alg_kind = pd()->desc()->alg_kind;
222 const float alpha = pd()->desc()->alpha;
223 const float beta = pd()->desc()->beta;
224 const bool is_3d = pd()->desc()->data_desc.ndims == 5;
226 parallel_nd(MB, C, D, H, W,
227 [&](int n, int c, int d, int h, int w) {
228 auto data_off = is_3d
229 ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w);
230 auto diff_data_off = is_3d
231 ? diff_data_d.off(n, c, d, h, w)
232 : diff_data_d.off(n, c, h, w);
233 data_t s = src[data_off];
234 data_t dd = diff_dst[diff_data_off];
235 data_t &ds = diff_src[diff_data_off];
237 case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
238 case eltwise_tanh: ds = tanh_bwd(dd, s); break;
239 case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
240 case eltwise_square: ds = square_bwd(dd, s); break;
241 case eltwise_abs: ds = abs_bwd(dd, s); break;
242 case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
244 ds = linear_bwd(dd, s, alpha, beta); break;
245 case eltwise_bounded_relu:
246 ds = bounded_relu_bwd(dd, s, alpha); break;
247 case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
248 case eltwise_logistic: ds = logistic_bwd(dd, s); break;
249 case eltwise_clamp: ds = clamp_bwd(dd, s, alpha, beta); break;
250 case eltwise_exp: ds = exp_bwd(dd, s); break;
251 default: assert(!"unknown eltwise alg_kind");
256 template <impl::data_type_t data_type>
257 void ref_eltwise_bwd_t<data_type>::execute_backward_dense() const {
258 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
259 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
260 auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
262 const memory_desc_wrapper data_d(pd()->src_pd());
263 const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
265 const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
266 const auto alg_kind = pd()->desc()->alg_kind;
267 const float alpha = pd()->desc()->alpha;
268 const float beta = pd()->desc()->beta;
270 src += data_d.blocking_desc().offset_padding;
271 diff_dst += diff_data_d.blocking_desc().offset_padding;
272 diff_src += diff_data_d.blocking_desc().offset_padding;
274 parallel_nd(nelems, [&](ptrdiff_t e) {
275 const data_t dd = diff_dst[e];
276 const data_t s = src[e];
277 data_t &ds = diff_src[e];
280 case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
281 case eltwise_tanh: ds = tanh_bwd(dd, s); break;
282 case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
283 case eltwise_square: ds = square_bwd(dd, s); break;
284 case eltwise_abs: ds = abs_bwd(dd, s); break;
285 case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
286 case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break;
287 case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break;
288 case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
289 case eltwise_logistic: ds = logistic_bwd(dd, s); break;
290 case eltwise_clamp: ds = clamp_bwd(dd, s, alpha, beta); break;
291 case eltwise_exp: ds = exp_bwd(dd, s); break;
292 default: assert(!"unknown eltwise alg_kind");
297 template struct ref_eltwise_fwd_t<data_type::f32>;
298 template struct ref_eltwise_fwd_t<data_type::s32>;
299 template struct ref_eltwise_fwd_t<data_type::s16>;
300 template struct ref_eltwise_fwd_t<data_type::s8>;
301 template struct ref_eltwise_fwd_t<data_type::u8>;
303 template struct ref_eltwise_bwd_t<data_type::f32>;
304 template struct ref_eltwise_bwd_t<data_type::s32>;
305 template struct ref_eltwise_bwd_t<data_type::s16>;
311 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s