using namespace alg_kind;
using namespace math;
-ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(const alg_kind_t alg_, const float alpha_, const float beta_)
- : alg(alg_), alpha(alpha_), beta(beta_) {
- using namespace alg_kind;
-
- assert(utils::one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
- eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
- eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic, eltwise_clamp));
+ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha,
+ float beta): alg_(alg), alpha_(alpha), beta_(beta) {
+ assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
+ eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+ eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
+ eltwise_clamp, eltwise_exp, eltwise_not));
}
+ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
+ const post_ops_t::entry_t::eltwise_t &eltwise)
+ : ref_eltwise_scalar_fwd_t(eltwise.alg, eltwise.alpha, eltwise.beta) {}
+
float ref_eltwise_scalar_fwd_t::compute_scalar(float s) {
- switch (alg) {
- case eltwise_relu: return relu_fwd(s, alpha);
- case eltwise_tanh: return tanh_fwd(s);
- case eltwise_elu: return elu_fwd(s, alpha);
+ switch (alg_) {
+ case eltwise_relu: return relu_fwd(s, alpha_);
+ case eltwise_tanh: return tanh_fwd(s);
+ case eltwise_elu: return elu_fwd(s, alpha_);
case eltwise_square: return square_fwd(s);
- case eltwise_abs: return abs_fwd(s);
- case eltwise_sqrt: return sqrt_fwd(s);
- case eltwise_linear: return linear_fwd(s, alpha, beta);
- case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha);
+ case eltwise_abs: return abs_fwd(s);
+ case eltwise_sqrt: return sqrt_fwd(s);
+ case eltwise_linear: return linear_fwd(s, alpha_, beta_);
+ case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha_);
case eltwise_soft_relu: return soft_relu_fwd(s);
case eltwise_logistic: return logistic_fwd(s);
- case eltwise_clamp: return clamp_fwd(s, alpha, beta);
+ case eltwise_clamp: return clamp_fwd(s, alpha_, beta_);
+ case eltwise_exp: return exp_fwd(s);
+ case eltwise_not: return not_fwd(s);
default: assert(!"unknown eltwise alg_kind");
}
- return 0.0f;
+ return 0.f;
}
template <impl::data_type_t data_type>
-void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded() {
+void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto dst = reinterpret_cast<data_t*>(this->memory(0));
- const memory_desc_wrapper data_d(conf_.src_pd());
+ const memory_desc_wrapper data_d(pd()->src_pd());
const blocking_desc_t &blk = data_d.blocking_desc();
const int block = blk.block_dims[1];
- const int MB = conf_.MB();
- const int C = conf_.C() / block;
+ const int MB = pd()->MB();
+ const int C = pd()->C() / block;
const int C_PADDED = blk.padding_dims[1] / block;
- const int tail = conf_.C() % block;
- const int SP = conf_.D() * conf_.H() * conf_.W();
- const auto alg_kind = conf_.desc()->alg_kind;
- const float alpha = conf_.desc()->alpha;
- const float beta = conf_.desc()->beta;
+ const int tail = pd()->C() % block;
+ const int SP = pd()->D() * pd()->H() * pd()->W();
+ const auto alg_kind = pd()->desc()->alg_kind;
+ const float alpha = pd()->desc()->alpha;
+ const float beta = pd()->desc()->beta;
auto ker = [=] (data_t &d, data_t s) {
switch (alg_kind) {
case eltwise_soft_relu: d = soft_relu_fwd(s); break;
case eltwise_logistic: d = logistic_fwd(s); break;
case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
+ case eltwise_exp: d = exp_fwd(s); break;
+ case eltwise_not: d = not_fwd(s); break;
default: assert(!"unknown eltwise alg_kind");
}
};
}
template <impl::data_type_t data_type>
-void ref_eltwise_fwd_t<data_type>::execute_forward_generic() {
+void ref_eltwise_fwd_t<data_type>::execute_forward_generic() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto dst = reinterpret_cast<data_t*>(this->memory(0));
/* fast return */
- if (conf_.has_zero_dim_memory()) return;
+ if (pd()->has_zero_dim_memory()) return;
- const memory_desc_wrapper data_d(conf_.src_pd());
+ const memory_desc_wrapper data_d(pd()->src_pd());
- const int MB = conf_.MB();
- const int C = conf_.C();
- const int D = conf_.D();
- const int H = conf_.H();
- const int W = conf_.W();
- const auto alg_kind = conf_.desc()->alg_kind;
- const float alpha = conf_.desc()->alpha;
- const float beta = conf_.desc()->beta;
- const bool is_3d = conf_.desc()->data_desc.ndims == 5;
+ const int MB = pd()->MB();
+ const int C = pd()->C();
+ const int D = pd()->D();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const auto alg_kind = pd()->desc()->alg_kind;
+ const float alpha = pd()->desc()->alpha;
+ const float beta = pd()->desc()->beta;
+ const bool is_3d = pd()->desc()->data_desc.ndims == 5;
parallel_nd(MB, C, D, H, W,
[&](int n, int c, int id, int h, int w) {
case eltwise_soft_relu: d = soft_relu_fwd(s); break;
case eltwise_logistic: d = logistic_fwd(s); break;
case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
+ case eltwise_exp: d = exp_fwd(s); break;
+ case eltwise_not: d = not_fwd(s); break;
default: assert(!"unknown eltwise alg_kind");
}
});
}
template <impl::data_type_t data_type>
-void ref_eltwise_fwd_t<data_type>::execute_forward_dense() {
+void ref_eltwise_fwd_t<data_type>::execute_forward_dense() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto dst = reinterpret_cast<data_t*>(this->memory(0));
- const memory_desc_wrapper data_d(conf_.src_pd());
+ const memory_desc_wrapper data_d(pd()->src_pd());
const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
- const auto alg_kind = conf_.desc()->alg_kind;
- const float alpha = conf_.desc()->alpha;
- const float beta = conf_.desc()->beta;
+ const auto alg_kind = pd()->desc()->alg_kind;
+ const float alpha = pd()->desc()->alpha;
+ const float beta = pd()->desc()->beta;
src += data_d.blocking_desc().offset_padding;
dst += data_d.blocking_desc().offset_padding;
case eltwise_soft_relu: d = soft_relu_fwd(s); break;
case eltwise_logistic: d = logistic_fwd(s); break;
case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
+ case eltwise_exp: d = exp_fwd(s); break;
+ case eltwise_not: d = not_fwd(s); break;
default: assert(!"unknown eltwise alg_kind");
}
});
}
template <impl::data_type_t data_type>
-void ref_eltwise_bwd_t<data_type>::execute_backward_generic() {
+void ref_eltwise_bwd_t<data_type>::execute_backward_generic() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
/* fast return */
- if (conf_.has_zero_dim_memory()) return;
+ if (pd()->has_zero_dim_memory()) return;
- const memory_desc_wrapper data_d(conf_.src_pd());
- const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
+ const memory_desc_wrapper data_d(pd()->src_pd());
+ const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
- const int MB = conf_.MB();
- const int C = conf_.C();
- const int D = conf_.D();
- const int H = conf_.H();
- const int W = conf_.W();
- const auto alg_kind = conf_.desc()->alg_kind;
- const float alpha = conf_.desc()->alpha;
- const float beta = conf_.desc()->beta;
- const bool is_3d = conf_.desc()->data_desc.ndims == 5;
+ const int MB = pd()->MB();
+ const int C = pd()->C();
+ const int D = pd()->D();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const auto alg_kind = pd()->desc()->alg_kind;
+ const float alpha = pd()->desc()->alpha;
+ const float beta = pd()->desc()->beta;
+ const bool is_3d = pd()->desc()->data_desc.ndims == 5;
parallel_nd(MB, C, D, H, W,
[&](int n, int c, int d, int h, int w) {
case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
case eltwise_logistic: ds = logistic_bwd(dd, s); break;
case eltwise_clamp: ds = clamp_bwd(dd, s, alpha, beta); break;
+ case eltwise_exp: ds = exp_bwd(dd, s); break;
default: assert(!"unknown eltwise alg_kind");
}
});
}
template <impl::data_type_t data_type>
-void ref_eltwise_bwd_t<data_type>::execute_backward_dense() {
+void ref_eltwise_bwd_t<data_type>::execute_backward_dense() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
- const memory_desc_wrapper data_d(conf_.src_pd());
- const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
+ const memory_desc_wrapper data_d(pd()->src_pd());
+ const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
- const auto alg_kind = conf_.desc()->alg_kind;
- const float alpha = conf_.desc()->alpha;
- const float beta = conf_.desc()->beta;
+ const auto alg_kind = pd()->desc()->alg_kind;
+ const float alpha = pd()->desc()->alpha;
+ const float beta = pd()->desc()->beta;
src += data_d.blocking_desc().offset_padding;
diff_dst += diff_data_d.blocking_desc().offset_padding;
case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
case eltwise_logistic: ds = logistic_bwd(dd, s); break;
case eltwise_clamp: ds = clamp_bwd(dd, s, alpha, beta); break;
+ case eltwise_exp: ds = exp_bwd(dd, s); break;
default: assert(!"unknown eltwise alg_kind");
}
});