namespace cpu {
using math::saturate;
+using math::get_bias;
// Reference forward inner product.
//
// For each (mb, oc) pair this computes an accumulated value from src and
// weights, adds the bias when a bias buffer is present, optionally scales
// negative results by a post-op alpha, and writes saturate<dst_data_t>(...)
// of the result into dst.
//
// NOTE(review): this text appears to be a diff view -- lines starting with
// '-' are the removed side and lines starting with '+' the added side of the
// same function. Some inner lines (e.g. the innermost loop of the spatial
// kernel) are not visible here.
template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type,
data_type_t acc_type>
void ref_inner_product_fwd_t<src_type, wei_type, dst_type, acc_type>
- ::execute_forward() {
+ ::execute_forward() const {
// Inputs 0, 1 and 2 are read as src, weights and bias; the output buffer is
// dst. bias stays a raw char pointer because its element type is selected at
// run time from the bias descriptor's data_type.
auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const char *>(this->input_memory(2));
auto dst = reinterpret_cast<dst_data_t *>(this->memory());
// Descriptor wrappers used below to translate logical indices into offsets.
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper bias_d(pd()->weights_pd(1));
- const int MB = conf_.MB();
- const int OC = conf_.OC();
- const int IC = conf_.IC();
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC();
// A 4- or 5-dimensional src selects the spatial kernel; 5 dimensions sets
// is_3d.
const bool src_has_spatial = utils::one_of(src_d.ndims(), 4, 5);
const bool is_3d = src_d.ndims() == 5;
- const auto &post_ops = conf_.attr()->post_ops_;
+ const auto &post_ops = pd()->attr()->post_ops_;
// Exactly one post-op entry enables the negative-value scaling; nslope is
// that entry's eltwise alpha, or 0 when there is no such entry.
// NOTE(review): the kind of the post-op is not checked here -- presumably it
// is validated when the primitive descriptor is created; confirm.
const bool do_relu = post_ops.len_ == 1;
const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;
// Kernel for sources with spatial dimensions. The old form accumulated into
// a caller-supplied reference; the new form starts from zero and returns the
// sum. It loops over ic, kd and kh; the innermost body is not visible in
// this view.
- auto ker_has_spatial = [=](acc_data_t &d, int mb, int oc) {
- const int KD = conf_.KD();
- const int KH = conf_.KH();
- const int KW = conf_.KW();
+ auto ker_has_spatial = [=](int mb, int oc) {
+ acc_data_t d = 0;
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
for (int ic = 0; ic < IC; ++ic) {
for (int kd = 0; kd < KD; ++kd) {
for (int kh = 0; kh < KH; ++kh) {
}
}
}
+ return d;
};
// Kernel for sources without spatial dimensions: sums
// src[mb, ic] * weights[oc, ic] over ic, accumulating in acc_data_t.
- auto ker_no_spatial = [=](acc_data_t &d, int mb, int oc) {
+ auto ker_no_spatial = [=](int mb, int oc) {
+ acc_data_t d = 0;
for (int ic = 0; ic < IC; ++ic) {
d += (acc_data_t)src[src_d.off(mb, ic)]
* weights[weights_d.off(oc, ic)];
}
- };
-
- auto get_bias = [=, &bias](size_t off) -> acc_data_t {
-# define CASE(dt) case dt: \
- return (acc_data_t)(*((const prec_traits<dt>::type *)bias + off))
- switch (conf_.desc()->bias_desc.data_type) {
- CASE(data_type::s8);
- CASE(data_type::u8);
- CASE(data_type::s32);
- CASE(data_type::f32);
- default: assert(!"unimplemented");
- }
-# undef CASE
- return 0;
+ return d;
};
// The MB x OC index space is handed to parallel_nd; each invocation writes
// only dst[dst_d.off(mb, oc)].
parallel_nd(MB, OC, [&](int mb, int oc) {
- acc_data_t a = bias ? get_bias(bias_d.off(oc)) : (acc_data_t)0;
- if (src_has_spatial) {
- ker_has_spatial(a, mb, oc);
- } else {
- ker_no_spatial(a, mb, oc);
- }
- if (do_relu && a < (acc_data_t)0) {
- float ds = (float)a * nslope;
- dst[dst_d.off(mb, oc)] = saturate<dst_data_t>(ds);
- } else {
- dst[dst_d.off(mb, oc)] = saturate<dst_data_t>(a);
- }
// The added side reads the bias through math::get_bias (imported above),
// sums bias and kernel result in a float, scales negative values by nslope
// and saturates once at the end.
// NOTE(review): the removed side added bias and products in acc_data_t and
// only converted to float on the scaled path. With a float running value,
// an integer acc_data_t result is converted to float before saturation,
// which may lose low-order bits for large magnitudes -- confirm this is
// acceptable for integer instantiations.
+ float a = bias
+ ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type)
+ : 0;
+ if (src_has_spatial)
+ a += ker_has_spatial(mb, oc);
+ else
+ a += ker_no_spatial(mb, oc);
+ if (do_relu && a < (acc_data_t)0)
+ a *= nslope;
+ dst[dst_d.off(mb, oc)] = saturate<dst_data_t>(a);
});
}
using namespace data_type;
template <data_type_t diff_src_type, data_type_t wei_type,
data_type_t diff_dst_type, data_type_t acc_type>
void ref_inner_product_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
- acc_type>::execute_backward_data() {
+ acc_type>::execute_backward_data() const {
auto diff_dst = reinterpret_cast<const diff_dst_data_t *>(
this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
- const int MB = conf_.MB();
- const int OC = conf_.OC();
- const int IC = conf_.IC();
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC();
const bool diff_src_has_spatial = utils::one_of(diff_src_d.ndims(), 4, 5);
parallel_nd(MB, IC, [&](int mb, int ic) {
if (diff_src_has_spatial) {
- const int KD = conf_.KD();
- const int KH = conf_.KH();
- const int KW = conf_.KW();
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
for (int kd = 0; kd < KD; ++kd)
for (int kh = 0; kh < KH; ++kh)
for (int kw = 0; kw < KW; ++kw) {
// Explicit instantiation; following the template parameter order
// <diff_src_type, wei_type, diff_dst_type, acc_type> this is
// diff_src = s32, weights = s16, diff_dst = s16, accumulator = s32.
template struct ref_inner_product_bwd_data_t<s32, s16, s16, s32>;
template <impl::data_type_t data_type>
-void ref_inner_product_bwd_weights_t<data_type>::execute_backward_weights() {
+void ref_inner_product_bwd_weights_t<data_type>::execute_backward_weights() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_weights = reinterpret_cast<data_t*>(this->memory(0));
auto diff_bias = reinterpret_cast<data_t*>(this->memory(1));
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
- const memory_desc_wrapper diff_bias_d(conf_.diff_weights_pd(1));
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
+ const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
- const int MB = conf_.MB();
- const int OC = conf_.OC();
- const int IC = conf_.IC();
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC();
const bool src_has_spatial = utils::one_of(src_d.ndims(), 4 ,5);
parallel_nd(OC, IC, [&](int oc, int ic) {
if (src_has_spatial) {
- const int KD = conf_.KD();
- const int KH = conf_.KH();
- const int KW = conf_.KW();
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
for (int kd = 0; kd < KD; ++kd) {
for (int kh = 0; kh < KH; ++kh) {
for (int kw = 0; kw < KW; ++kw) {