using namespace mkldnn::impl::primitive_kind;
template <impl::data_type_t data_type>
-void gemm_inner_product_fwd_t<data_type>::execute_forward() {
+void gemm_inner_product_fwd_t<data_type>::execute_forward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<data_t*>(this->memory());
- const int MB = conf_.MB();
- const int OC = conf_.OC();
- const int IC = conf_.IC_total_padded();
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC_total_padded();
- bool wei_tr = !utils::one_of(conf_.weights_pd()->desc()->format,
+ bool wei_tr = !utils::one_of(pd()->weights_pd()->desc()->format,
hwio, dhwio, io);
- const auto &post_ops = conf_.attr()->post_ops_;
+ const auto &post_ops = pd()->attr()->post_ops_;
const bool do_relu = post_ops.len_ == 1;
float alpha = 1.0, beta = 0.0;
}
template <impl::data_type_t data_type>
-void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data() {
+void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data() const {
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<data_t*>(this->memory());
- const int MB = conf_.MB();
- const int OC = conf_.OC();
- const int IC = conf_.IC_total_padded();
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC_total_padded();
- bool wei_tr = utils::one_of(conf_.weights_pd()->desc()->format,
+ bool wei_tr = utils::one_of(pd()->weights_pd()->desc()->format,
hwio, dhwio, io);
float alpha = 1.0, beta = 0.0;
}
template <impl::data_type_t data_type>
-void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights() {
+void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_bias_d(conf_.diff_weights_pd(1));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
diff_dst += diff_dst_d.blocking_desc().offset_padding;
- const int MB = conf_.MB();
- const int OC = conf_.OC();
- const int IC = conf_.IC_total_padded();
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC_total_padded();
- bool wei_tr = utils::one_of(conf_.diff_weights_pd()->desc()->format,
+ bool wei_tr = utils::one_of(pd()->diff_weights_pd()->desc()->format,
hwio, dhwio, io);
float alpha = 1.0, beta = 0.0;