template <cpu_isa_t isa>
jit_uni_lrn_fwd_t<isa>::jit_uni_lrn_fwd_t(
- const pd_t *pd,
+ const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), ker_(nullptr)
+ : cpu_primitive_t(apd, inputs, outputs), ker_(nullptr)
, ker_first_(nullptr), ker_last_(nullptr)
{
using namespace alg_kind;
- const int C = conf_.C();
- const int H = conf_.H();
- const int W = conf_.W();
- const int ls = conf_.desc()->local_size;
- float A = conf_.desc()->lrn_alpha / ls;
- float K = conf_.desc()->lrn_k;
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const int ls = pd()->desc()->local_size;
+ float A = pd()->desc()->lrn_alpha / ls;
+ float K = pd()->desc()->lrn_k;
- auto pk = conf_.desc()->prop_kind;
- auto ak = conf_.desc()->alg_kind;
- auto dfmt = conf_.src_pd()->desc()->format;
+ auto pk = pd()->desc()->prop_kind;
+ auto ak = pd()->desc()->alg_kind;
+ auto dfmt = pd()->src_pd()->desc()->format;
if (dfmt == nChw8c && ls == 5 && ak == lrn_across_channels) {
ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
{ delete ker_; delete ker_first_; delete ker_last_; }
template <cpu_isa_t isa>
-void jit_uni_lrn_fwd_t<isa>::execute_forward() {
+void jit_uni_lrn_fwd_t<isa>::execute_forward() const {
using namespace alg_kind;
auto src = reinterpret_cast<const data_t*>(this->input_memory(0));
auto dst = reinterpret_cast<data_t*>(this->memory(0));
auto ws = reinterpret_cast<data_t*>(this->memory(1));
- const int N = conf_.MB();
- const int C = conf_.C();
- const int HW = conf_.H() * conf_.W();
- const int ls = conf_.desc()->local_size;
+ const int N = pd()->MB();
+ const int C = pd()->C();
+ const int HW = pd()->H() * pd()->W();
+ const int ls = pd()->desc()->local_size;
- auto ak = conf_.desc()->alg_kind;
- auto dfmt = conf_.src_pd()->desc()->format;
+ auto ak = pd()->desc()->alg_kind;
+ auto dfmt = pd()->src_pd()->desc()->format;
if (dfmt == nChw8c && ls == 5 && ak == lrn_across_channels) {
parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
}
template <cpu_isa_t isa>
-jit_uni_lrn_bwd_t<isa>::jit_uni_lrn_bwd_t(const pd_t *pd,
+jit_uni_lrn_bwd_t<isa>::jit_uni_lrn_bwd_t(const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
+ : cpu_primitive_t(apd, inputs, outputs)
, ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr)
{
using namespace alg_kind;
- const int C = conf_.C();
- const int H = conf_.H();
- const int W = conf_.W();
- const int ls = conf_.desc()->local_size;
- float A = conf_.desc()->lrn_alpha / ls;
- float B = conf_.desc()->lrn_beta;
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const int ls = pd()->desc()->local_size;
+ float A = pd()->desc()->lrn_alpha / ls;
+ float B = pd()->desc()->lrn_beta;
int use_h_parallelizm = 0;// XXX
if (C / VECTOR_LENGTH == 1) {
}
template <cpu_isa_t isa>
-void jit_uni_lrn_bwd_t<isa>::execute_backward() {
+void jit_uni_lrn_bwd_t<isa>::execute_backward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
auto ws = reinterpret_cast<const data_t*>(this->input_memory(2));
auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
- const int N = conf_.MB();
- const int C = conf_.C();
- const int H = conf_.H();
- const int W = conf_.W();
+ const int N = pd()->MB();
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
int use_h_parallelizm = 0; // XXX
if (use_h_parallelizm) {