namespace cpu {
using math::saturate;
+using math::get_bias;
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
+template <data_type_t src_type, data_type_t wei_type,
data_type_t dst_type, data_type_t acc_type>
-void _ref_convolution_fwd_t<with_relu, src_type, wei_type, dst_type, acc_type>
- ::execute_forward() {
+void ref_convolution_fwd_t<src_type, wei_type, dst_type, acc_type>
+ ::execute_forward() const {
auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const char *>(this->input_memory(2));
auto dst = reinterpret_cast<dst_data_t *>(this->memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper bias_d(pd()->weights_pd(1));
- const bool with_groups = conf_.with_groups();
+ const bool with_groups = pd()->with_groups();
- const int G = conf_.G();
- const int MB = conf_.MB();
- const int OD = conf_.OD();
- const int OH = conf_.OH();
- const int OW = conf_.OW();
- const int ID = conf_.ID();
- const int IH = conf_.IH();
- const int IW = conf_.IW();
+ const int G = pd()->G();
+ const int MB = pd()->MB();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
- const int OC = conf_.OC() / G;
- const int IC = conf_.IC() / G;
- const int KD = conf_.KD();
- const int KH = conf_.KH();
- const int KW = conf_.KW();
+ const int OC = pd()->OC() / G;
+ const int IC = pd()->IC() / G;
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
- const int KSD = conf_.KSD();
- const int KSH = conf_.KSH();
- const int KSW = conf_.KSW();
+ const int KSD = pd()->KSD();
+ const int KSH = pd()->KSH();
+ const int KSW = pd()->KSW();
- const int KDD = conf_.KDD();
- const int KDH = conf_.KDH();
- const int KDW = conf_.KDW();
+ const int KDD = pd()->KDD();
+ const int KDH = pd()->KDH();
+ const int KDW = pd()->KDW();
- const int padFront = conf_.padFront();
- const int padT = conf_.padT();
- const int padL = conf_.padL();
+ const int padFront = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
- const float nslope = conf_.negative_slope();
+ const bool with_relu = 0; // TODO: change if support post_ops
+ const float nslope = 0.f;
- const int ndims = conf_.cdesc()->src_desc.ndims;
+ const int ndims = pd()->desc()->src_desc.ndims;
- auto ker = [=](acc_data_t &d, int g, int mb, int oc, int od, int oh,
+ auto ker = [=](int g, int mb, int oc, int od, int oh,
int ow) {
+ acc_data_t d = 0;
for (int ic = 0; ic < IC; ++ic)
for (int kd = 0; kd < KD; ++kd)
for (int kh = 0; kh < KH; ++kh)
else
assert(false);
- }
- };
- auto get_bias = [=, &bias](size_t off) -> float {
-# define CASE(dt) case dt: \
- return (float)(*((const prec_traits<dt>::type *)bias + off))
- switch (conf_.cdesc()->bias_desc.data_type) {
- CASE(data_type::s8);
- CASE(data_type::u8);
- CASE(data_type::s32);
- CASE(data_type::f32);
- default: assert(!"unimplemented");
}
-# undef CASE
- return 0;
+ return d;
};
+
parallel_nd(G, MB, OC, OD, OH, OW,
[&](int g, int mb, int oc, int od, int oh, int ow) {
- acc_data_t a = 0;
- ker(a, g, mb, oc, od, oh, ow);
-
- float a_fp = (float)a;
+ float a_fp = ker(g, mb, oc, od, oh, ow);
if (bias)
- a_fp += get_bias(bias_d.off(g*OC + oc));
+ a_fp += get_bias(bias, bias_d.off(g * OC + oc),
+ pd()->desc()->bias_desc.data_type);
if (with_relu && a_fp < 0)
a_fp *= nslope;
if (data_traits<dst_data_t>::data_type != data_type::f32) {
- switch (conf_.attr()->round_mode_) {
+ switch (pd()->attr()->round_mode_) {
case round_mode::down: a_fp = floorf(a_fp); break;
case round_mode::nearest: a_fp = nearbyintf(a_fp); break;
}
template <data_type_t diff_src_type, data_type_t wei_type,
data_type_t diff_dst_type, data_type_t acc_type>
void ref_convolution_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
- acc_type>::execute_backward_data() {
+ acc_type>::execute_backward_data() const {
auto diff_dst = reinterpret_cast<const diff_dst_data_t*>(
this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t*>(this->input_memory(1));
auto bias = reinterpret_cast<const char *>(this->input_memory(2));
auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper bias_d(pd()->weights_pd(1));
- const bool with_groups = conf_.with_groups();
+ const bool with_groups = pd()->with_groups();
- const int G = conf_.G();
- const int MB = conf_.MB();
- const int OD = conf_.OD();
- const int OH = conf_.OH();
- const int OW = conf_.OW();
- const int ID = conf_.ID();
- const int IH = conf_.IH();
- const int IW = conf_.IW();
+ const int G = pd()->G();
+ const int MB = pd()->MB();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
- const int OC = conf_.OC() / G;
- const int IC = conf_.IC() / G;
- const int KD = conf_.KD();
- const int KH = conf_.KH();
- const int KW = conf_.KW();
+ const int OC = pd()->OC() / G;
+ const int IC = pd()->IC() / G;
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
- const int KSD = conf_.KSD();
- const int KSH = conf_.KSH();
- const int KSW = conf_.KSW();
+ const int KSD = pd()->KSD();
+ const int KSH = pd()->KSH();
+ const int KSW = pd()->KSW();
- const int KDD = conf_.KDD();
- const int KDH = conf_.KDH();
- const int KDW = conf_.KDW();
+ const int KDD = pd()->KDD();
+ const int KDH = pd()->KDH();
+ const int KDW = pd()->KDW();
- const int padFront = conf_.padFront();
- const int padT = conf_.padT();
- const int padL = conf_.padL();
+ const int padFront = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
- const int ndims = conf_.cdesc()->diff_src_desc.ndims;
+ const int ndims = pd()->desc()->diff_src_desc.ndims;
- auto ker = [=](acc_data_t &d, int g, int mb, int ic, int id, int ih,
+ auto ker = [=](int g, int mb, int ic, int id, int ih,
int iw) {
+ acc_data_t d = 0;
for (int oc = 0; oc < OC; ++oc)
for (int kd = 0; kd < KD; ++kd)
for (int kh = 0; kh < KH; ++kh)
assert(false);
}
}
+ return d;
};
- auto get_bias = [=, &bias](size_t off) -> acc_data_t {
-# define CASE(dt) case dt: \
- return (acc_data_t)(*((const prec_traits<dt>::type *)bias + off))
- switch (conf_.desc()->bias_desc.data_type) {
- CASE(data_type::s8);
- CASE(data_type::u8);
- CASE(data_type::s32);
- CASE(data_type::f32);
- default: assert(!"unimplemented");
- }
-# undef CASE
- return 0;
- };
+
parallel_nd(G, MB, IC, ID, IH, IW,
[&](int g, int mb, int ic, int id, int ih, int iw) {
auto ds_idx = (ndims == 5)
: (ndims == 4)
? diff_src_d.off(mb, g*IC + ic, ih, iw)
: diff_src_d.off(mb, g*IC + ic, iw);
- acc_data_t a = bias
- ? get_bias(bias_d.off(g*IC + ic))
- : (acc_data_t)0;
- ker(a, g, mb, ic, id, ih, iw);
+ float a = bias
+ ? get_bias(bias, bias_d.off(g * IC + ic),
+ pd()->desc()->bias_desc.data_type)
+ : 0;
+ a += ker(g, mb, ic, id, ih, iw);
diff_src[ds_idx] = saturate<diff_src_data_t>(a);
});
}
template <data_type_t src_type, data_type_t diff_wei_type,
data_type_t diff_dst_type, data_type_t acc_type>
void ref_convolution_bwd_weights_t<src_type, diff_wei_type, diff_dst_type,
- acc_type>::execute_backward_weights() {
+ acc_type>::execute_backward_weights() const {
auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
auto diff_dst = reinterpret_cast<const diff_dst_data_t *>(
this->input_memory(1));
auto diff_weights = reinterpret_cast<diff_wei_data_t*>(this->memory(0));
auto diff_bias = reinterpret_cast<diff_wei_data_t *>(this->memory(1));
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
- const memory_desc_wrapper diff_bias_d(conf_.diff_weights_pd(1));
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
+ const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
- const bool with_groups = conf_.with_groups();
+ const bool with_groups = pd()->with_groups();
- const int G = conf_.G();
- const int MB = conf_.MB();
- const int OD = conf_.OD();
- const int OH = conf_.OH();
- const int OW = conf_.OW();
- const int ID = conf_.ID();
- const int IH = conf_.IH();
- const int IW = conf_.IW();
+ const int G = pd()->G();
+ const int MB = pd()->MB();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
- const int OC = conf_.OC() / G;
- const int IC = conf_.IC() / G;
- const int KD = conf_.KD();
- const int KH = conf_.KH();
- const int KW = conf_.KW();
+ const int OC = pd()->OC() / G;
+ const int IC = pd()->IC() / G;
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
- const int KSD = conf_.KSD();
- const int KSH = conf_.KSH();
- const int KSW = conf_.KSW();
+ const int KSD = pd()->KSD();
+ const int KSH = pd()->KSH();
+ const int KSW = pd()->KSW();
- const int KDD = conf_.KDD();
- const int KDH = conf_.KDH();
- const int KDW = conf_.KDW();
+ const int KDD = pd()->KDD();
+ const int KDH = pd()->KDH();
+ const int KDW = pd()->KDW();
- const int padFront = conf_.padFront();
- const int padT = conf_.padT();
- const int padL = conf_.padL();
+ const int padFront = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
- const int ndims = conf_.cdesc()->src_desc.ndims;
+ const int ndims = pd()->desc()->src_desc.ndims;
auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) {
for (int mb = 0; mb < MB; ++mb)
parallel_nd(G, OC, [&](int g, int oc) {
if (diff_bias) {
+ // XXX: loss of precision when bias is a float...
acc_data_t db = 0;
ker_bias(db, g, oc);
diff_bias[diff_bias_d.off(g*OC+oc)]
using namespace data_type;
-template struct _ref_convolution_fwd_t<false, f32>;
-template struct _ref_convolution_fwd_t<true, f32>;
-template struct _ref_convolution_fwd_t<false, s16, s16, s32, s32>;
-template struct _ref_convolution_fwd_t<true, s16, s16, s32, s32>;
-
-template struct _ref_convolution_fwd_t<false, u8, s8, f32, s32>;
-template struct _ref_convolution_fwd_t<true, u8, s8, f32, s32>;
-template struct _ref_convolution_fwd_t<false, u8, s8, s32, s32>;
-template struct _ref_convolution_fwd_t<true, u8, s8, s32, s32>;
-template struct _ref_convolution_fwd_t<false, u8, s8, s8, s32>;
-template struct _ref_convolution_fwd_t<true, u8, s8, s8, s32>;
-template struct _ref_convolution_fwd_t<false, u8, s8, u8, s32>;
-template struct _ref_convolution_fwd_t<true, u8, s8, u8, s32>;
+template struct ref_convolution_fwd_t<f32>;
+template struct ref_convolution_fwd_t<s16, s16, s32, s32>;
+
+template struct ref_convolution_fwd_t<u8, s8, f32, s32>;
+template struct ref_convolution_fwd_t<u8, s8, s32, s32>;
+template struct ref_convolution_fwd_t<u8, s8, s8, s32>;
+template struct ref_convolution_fwd_t<u8, s8, u8, s32>;
template struct ref_convolution_bwd_data_t<f32, f32, f32, f32>;
template struct ref_convolution_bwd_data_t<s32, s16, s16, s32>;