template <impl::data_type_t data_type>
template <mkldnn_memory_format_t fmt>
-void ref_lrn_fwd_t<data_type>::execute_forward() {
+void ref_lrn_fwd_t<data_type>::execute_forward() const {
using namespace alg_kind;
using namespace memory_format;
auto dst = reinterpret_cast<data_t*>(this->memory(0));
auto ws = reinterpret_cast<data_t*>(this->memory(1));
- const memory_desc_wrapper data_d(conf_.src_pd());
- const memory_desc_wrapper ws_d(conf_.workspace_pd());
+ const memory_desc_wrapper data_d(pd()->src_pd());
+ const memory_desc_wrapper ws_d(pd()->workspace_pd());
MAYBE_UNUSED(ws_d);
- const int C = conf_.C();
- const int H = conf_.H();
- const int W = conf_.W();
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
const size_t stride_mb = data_d.blocking_desc().strides[0][0];
- const bool across_channels = conf_.desc()->alg_kind == lrn_across_channels;
+ const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels;
constexpr int blksize = fmt == nChw16c ? 16 : 8;
auto data_off = [&](int mb, int c, int h, int w) -> size_t {
};
auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) {
- const float alpha = static_cast<float>(conf_.desc()->lrn_alpha);
- const float beta = static_cast<float>(conf_.desc()->lrn_beta);
- const float k = static_cast<float>(conf_.desc()->lrn_k);
+ const float alpha = static_cast<float>(pd()->desc()->lrn_alpha);
+ const float beta = static_cast<float>(pd()->desc()->lrn_beta);
+ const float k = static_cast<float>(pd()->desc()->lrn_k);
- const int size = conf_.desc()->local_size;
+ const int size = pd()->desc()->local_size;
const int half_size = (size - 1) / 2;
float sum = 0;
d[0] = static_cast<data_t>(src[off] * fast_negative_powf(sum, beta));
};
- const int MB = conf_.MB();
+ const int MB = pd()->MB();
if (fmt == nChw16c || fmt == nChw8c) {
parallel_nd(MB, utils::div_up(C, blksize), H, W,
[&](int mb, int c_blk, int h, int w) {
template <impl::data_type_t data_type>
template <mkldnn_memory_format_t fmt>
-void ref_lrn_bwd_t<data_type>::execute_backward() {
+void ref_lrn_bwd_t<data_type>::execute_backward() const {
using namespace alg_kind;
using namespace memory_format;
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
- const memory_desc_wrapper data_d(conf_.src_pd());
- const memory_desc_wrapper diff_data_d(conf_.diff_dst_pd());
+ const memory_desc_wrapper data_d(pd()->src_pd());
+ const memory_desc_wrapper diff_data_d(pd()->diff_dst_pd());
MAYBE_UNUSED(diff_data_d);
- const int MB = conf_.MB();
- const int C = conf_.C();
- const int H = conf_.H();
- const int W = conf_.W();
+ const int MB = pd()->MB();
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
const size_t stride_mb = data_d.blocking_desc().strides[0][0];
constexpr int blksize = fmt == nChw16c ? 16 : 8;
- const float alpha = static_cast<float>(conf_.desc()->lrn_alpha);
- const float beta = static_cast<float>(conf_.desc()->lrn_beta);
- const float k = static_cast<float>(conf_.desc()->lrn_k);
- const int kernel_size = conf_.desc()->local_size;
+ const float alpha = static_cast<float>(pd()->desc()->lrn_alpha);
+ const float beta = static_cast<float>(pd()->desc()->lrn_beta);
+ const float k = static_cast<float>(pd()->desc()->lrn_k);
+ const int kernel_size = pd()->desc()->local_size;
const int half_ksize = (kernel_size - 1) / 2;
auto data_off = [&](int mb, int c, int h, int w) -> size_t {
}
}
-template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nChw16c>();
-template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nChw8c>();
-template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nchw>();
-template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nhwc>();
-template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::any>();
-template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nChw16c>();
-template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nChw8c>();
-template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nchw>();
-template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nhwc>();
-template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::any>();
+template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nChw16c>() const;
+template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nChw8c>() const;
+template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nchw>() const;
+template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nhwc>() const;
+template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::any>() const;
+template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nChw16c>() const;
+template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nChw8c>() const;
+template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nchw>() const;
+template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nhwc>() const;
+template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::any>() const;
}
}