}
};
- ref_shuffle_t(const pd_t *pd, const input_vector &inputs,
+ ref_shuffle_t(const pd_t *apd, const input_vector &inputs,
const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
+ : cpu_primitive_t(apd, inputs, outputs)
{
- const int axis_size = conf_.axis_size();
- const int group_size = conf_.group_size();
- const int transpose_row = conf_.is_fwd() ? group_size
+ const int axis_size = pd()->axis_size();
+ const int group_size = pd()->group_size();
+ const int transpose_row = pd()->is_fwd() ? group_size
: axis_size / group_size;
- const int transpose_col = conf_.is_fwd() ? axis_size / group_size
+ const int transpose_col = pd()->is_fwd() ? axis_size / group_size
: group_size;
rev_transposed_ = (int *)malloc(axis_size * sizeof(int), 64);
parallel_nd(transpose_col, transpose_row, [&](int i, int j) {
typedef typename typesize_traits<data_type_size>::type data_t;
- virtual void execute(event_t *e) {
+ virtual void execute(event_t *e) const {
using namespace memory_format;
- switch (conf_.data_pd()->desc()->format) {
+ switch (pd()->data_pd()->desc()->format) {
case nCdhw16c: execute_<nCdhw16c>(); break;
case nChw16c: execute_<nChw16c>(); break;
case nCdhw8c: execute_<nCdhw8c>(); break;
}
private:
- template<memory_format_t fmt>void execute_();
- pd_t conf_;
+ template<memory_format_t fmt>void execute_() const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
int *rev_transposed_;
};