#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP
#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
#include "cpu_convolution_pd.hpp"
#include "cpu_engine.hpp"
-#include "scratchpad.hpp"
#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp"
namespace impl {
namespace cpu {
-namespace winograd {
-
-struct winograd_scratchpad_avx512_core_t {
- public:
- winograd_scratchpad_avx512_core_t(const jit_conv_winograd_conf_t &jcp)
- {
- get_scratchpad_size_(jcp);
- allocate_scratchpad_(jcp);
- }
-
- ~winograd_scratchpad_avx512_core_t() {
- if (scratchpad_ != nullptr)
- delete scratchpad_;
- }
-
- char *U_ptr() {
- /* buffer for wei transform U*/
- return scratchpad_->get() + U_offset_;
- }
-
- char *V_ptr() {
- /* buffer for src transform V*/
- return scratchpad_->get() + V_offset_;
- }
-
- char *M_ptr() {
- /* buffer for dst transform M*/
- return scratchpad_->get() + M_offset_;
- }
-
- char *bias_ptr() {
- /* buffer for bias update in bwdw*/
- return scratchpad_->get() + bias_offset_;
- }
-
- int num_threads(){
- return nthreads_;
- }
-
- private:
- inline void get_scratchpad_size_(const jit_conv_winograd_conf_t &jcp) {
- nthreads_ = mkldnn_get_max_threads();
-
- U_sz_ = size_t(alpha) * alpha * jcp.ic * jcp.oc * sizeof(float);
- V_sz_ = size_t(alpha) * alpha * jcp.mb * jcp.ic
- * jcp.itiles * jcp.jtiles
- * sizeof(float);
- M_sz_ = size_t(alpha) * alpha * jcp.mb * jcp.oc
- * jcp.itiles * jcp.jtiles
- * sizeof(float);
-
- switch (jcp.sched_policy) {
- case WSCHED_DATA_W_SGD:
- V_sz_ = nthreads_ * alpha * alpha
- * jcp.nb_tile_block_ur * jcp.tile_block_ur
- * jcp.ic * sizeof(float);
- M_sz_ = nthreads_* alpha * alpha
- * jcp.nb_tile_block_ur * jcp.tile_block_ur
- * jcp.oc * sizeof(float);
- break;
- case WSCHED_WEI_SDGtWo:
- nthreads_ = nstl::min(mkldnn_get_max_threads(), jcp.tile_block);
-
- U_sz_ = nthreads_
- * (alpha * alpha * jcp.oc * (jcp.ic / jcp.nb_ic)
- + jcp.ic * jcp.oc * jcp.kh * jcp.kw)
- * sizeof(float);
- M_sz_ = nthreads_ * alpha * alpha
- * (jcp.ntiles / jcp.tile_block)
- * (jcp.oc / jcp.nb_oc) * sizeof(float);
- V_sz_ = nthreads_ * alpha * alpha
- * (jcp.ntiles / jcp.tile_block)
- * (jcp.ic / jcp.nb_ic)
- * sizeof(float);
- bias_sz_ = nthreads_ * jcp.oc * sizeof(float);
- break;
- case WSCHED_WEI_S_D_Giot_W:
- U_sz_ = (nthreads_ + 1) * alpha * alpha * jcp.ic * jcp.oc
- * sizeof(float);
- M_sz_ = size_t(alpha) * alpha * jcp.oc * jcp.ntiles * sizeof(float);
- V_sz_ = size_t(alpha) * alpha * jcp.ic * jcp.ntiles * sizeof(float);
- bias_sz_ = nthreads_ * jcp.oc * sizeof(float);
- break;
- default:
- break;
- }
- }
+namespace winograd_avx512_core {
+/* Books the Winograd working buffers for one convolution configuration.
+ *
+ * Three buffers are always booked: U (weights transform), V (source
+ * transform) and M (destination transform). For the two weight-update
+ * schedules (WSCHED_WEI_SDGtWo, WSCHED_WEI_S_D_Giot_W) a fourth buffer of
+ * nthr * oc floats is booked for the bias reduction.
+ *
+ * Sizes are computed as element counts and converted to bytes
+ * (sizeof(float)) only when booking. PAGE_2M is passed as the third
+ * argument of book() -- presumably the requested alignment; confirm against
+ * memory_tracking::registrar_t.
+ *
+ * @param scratchpad registrar that receives the bookings
+ * @param jcp        Winograd configuration; the code reads sched_policy,
+ *                   nthr, mb, ic, oc, kh, kw, itiles, jtiles, ntiles,
+ *                   tile_block, tile_block_ur, nb_tile_block_ur, nb_ic and
+ *                   nb_oc
+ */
+inline void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+        const jit_conv_winograd_conf_t &jcp) {
+    using namespace utils;
+    using namespace memory_tracking::names;
+
+    // Defaults, used when no schedule-specific case below overrides them.
+    // The leading size_t cast widens every product before it can overflow.
+    size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc;
+    size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic * jcp.itiles
+            * jcp.jtiles;
+    size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc * jcp.itiles
+            * jcp.jtiles;
+
+    switch (jcp.sched_policy) {
+    case WSCHED_DATA_W_SGD:
+        // V and M are sized per thread; U keeps the default size.
+        V_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur
+                * jcp.tile_block_ur * jcp.ic;
+        M_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur
+                * jcp.tile_block_ur * jcp.oc;
+        break;
+    case WSCHED_WEI_SDGtWo:
+        // Both terms of the sum are widened individually: a cast applied
+        // only to jcp.nthr would leave the parenthesised sum to be evaluated
+        // at the operands' own (presumably int) width first.
+        U_sz = (size_t)jcp.nthr
+                * ((size_t)alpha * alpha * jcp.oc * (jcp.ic / jcp.nb_ic)
+                        + (size_t)jcp.ic * jcp.oc * jcp.kh * jcp.kw);
+        M_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block)
+                * (jcp.oc / jcp.nb_oc);
+        V_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block)
+                * (jcp.ic / jcp.nb_ic);
+        break;
+    case WSCHED_WEI_S_D_Giot_W:
+        // One U-sized slice per thread plus one extra slice.
+        U_sz = (size_t)(jcp.nthr + 1) * alpha * alpha * jcp.ic * jcp.oc;
+        M_sz = (size_t)alpha * alpha * jcp.oc * jcp.ntiles;
+        V_sz = (size_t)alpha * alpha * jcp.ic * jcp.ntiles;
+        break;
+    default: break;
+    }
- inline void allocate_scratchpad_(const jit_conv_winograd_conf_t &jcp) {
- const size_t page_size = PAGE_2M;
- U_offset_ = 0;
- V_offset_ = utils::rnd_up(U_sz_, page_size);
- M_offset_ = V_offset_ + utils::rnd_up(V_sz_, page_size);
- scratchpad_sz_ = M_offset_ + M_sz_;
- if (bias_sz_) {
- bias_offset_ = M_offset_ + utils::rnd_up(M_sz_, page_size);
- scratchpad_sz_ = bias_offset_ + bias_sz_;
- }
- scratchpad_ = create_scratchpad(scratchpad_sz_);
- }
+
+    // Element counts become byte sizes here.
+    scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M);
+    scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M);
+    scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M);
- scratchpad_t *scratchpad_;
- size_t nthreads_;
- size_t scratchpad_sz_ = 0, U_sz_ = 0, V_sz_ = 0, M_sz_ = 0,
- bias_sz_ = 0;
- size_t U_offset_ = 0;
- size_t V_offset_ = 0;
- size_t M_offset_ = 0;
- size_t bias_offset_ = 0;
-};
+
+    // Only the weight-update schedules book a bias-reduction buffer.
+    if (one_of(jcp.sched_policy, WSCHED_WEI_SDGtWo, WSCHED_WEI_S_D_Giot_W)) {
+        size_t br_sz = (size_t)jcp.nthr * jcp.oc;
+        scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M);
+    }
+}
}
template <bool is_fwd>
_jit_avx512_core_fp32_wino_conv_4x3_t(
const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr)
- : kernel_(nullptr), scratchpad_(nullptr), attr_(attr) {
+ : kernel_(nullptr), attr_(attr) { // stores attr and creates the data kernel from jcp; no scratchpad member is set up here
kernel_ = new _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(jcp);
- scratchpad_ = new winograd::winograd_scratchpad_avx512_core_t(jcp);
}
- ~_jit_avx512_core_fp32_wino_conv_4x3_t() {
- delete kernel_;
- delete scratchpad_;
- };
+ ~_jit_avx512_core_fp32_wino_conv_4x3_t() { delete kernel_; } // releases kernel_; this destructor frees nothing else
protected:
void weight_transform_data(const jit_conv_winograd_conf_t &jcp,
- float *wp, float *twp);
+ float *wp, float *twp) const;
void input_transform_data(int image,
const jit_conv_winograd_conf_t &jcp,
- float *inp, float *tinp);
+ float *inp, float *tinp) const;
void input_transform_tileblock_data(int tile_block,
const jit_conv_winograd_conf_t &jcp,
- float *inp, float *tinp);
+ float *inp, float *tinp) const;
void output_transform_data(int image,
const jit_conv_winograd_conf_t &jcp,
- const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias);
+ const post_ops_t &p_ops, float *toutp, float *pout_b,
+ float *bias) const;
void output_transform_tileblock_data(int tile_block,
const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
- float *toutp, float *outp, float *bias);
+ float *toutp, float *outp, float *bias) const;
void _execute_data_W_S_G_D(const int MB, float *inp_ptr, float *out_ptr,
- float *wei_ptr, float *bias_ptr = NULL);
+ float *wei_ptr, float *bias_ptr,
+ const memory_tracking::grantor_t &scratchpad) const;
void _execute_data_W_SGD(const int MB, float *inp_ptr, float *out_ptr,
- float *wei_ptr, float *bias_ptr = NULL);
+ float *wei_ptr, float *bias_ptr,
+ const memory_tracking::grantor_t &scratchpad) const;
_jit_avx512_core_fp32_wino_conv_4x3_data_kernel *kernel_;
- // Buffer required to store transforms in the frequency domain
- winograd::winograd_scratchpad_avx512_core_t *scratchpad_;
const primitive_attr_t *attr_;
};
-template <bool with_relu>
-struct _jit_avx512_core_fp32_wino_conv_4x3_fwd_t
+struct jit_avx512_core_fp32_wino_conv_4x3_fwd_t
: _jit_avx512_core_fp32_wino_conv_4x3_t<true>
, public cpu_primitive_t
{
- struct pd_t : public _cpu_convolution_fwd_pd_t<with_relu> {
- pd_t(engine_t *engine, const typename pd_t::base_desc_t *adesc,
+ struct pd_t : public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
const primitive_attr_t *attr,
const typename pd_t::base_class *hint_fwd_pd)
- : _cpu_convolution_fwd_pd_t<with_relu>(engine, adesc, attr,
- hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
, jcp_() {}
DECLARE_COMMON_PD_T(
JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""),
- _jit_avx512_core_fp32_wino_conv_4x3_fwd_t<with_relu>);
+ jit_avx512_core_fp32_wino_conv_4x3_fwd_t);
virtual status_t init() override
{
using namespace prop_kind;
assert(this->engine()->kind() == engine_kind::cpu);
bool ok = true && this->set_default_params() == status::success
- && utils::one_of(this->cdesc_().prop_kind, forward_training,
+ && utils::one_of(this->desc()->prop_kind, forward_training, // forward propagation only (training or inference)
forward_inference)
- && this->cdesc_().alg_kind == alg_kind::convolution_winograd
+ && utils::one_of(this->desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd) // both the auto and the explicit winograd algorithm kinds are accepted
&& utils::everyone_is(data_type::f32,
- this->cdesc_().src_desc.data_type,
- this->cdesc_().weights_desc.data_type,
- this->cdesc_().dst_desc.data_type)
+ this->desc()->src_desc.data_type,
+ this->desc()->weights_desc.data_type,
+ this->desc()->dst_desc.data_type) // src, weights and dst must all be f32
&& IMPLICATION(this->with_bias(), data_type::f32
- == this->cdesc_().bias_desc.data_type)
+ == this->desc()->bias_desc.data_type) // the bias type is checked only when a bias is present
&& mkldnn_thr_syncable();
if (!ok)
return status::unimplemented;
- return jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf(jcp_,
- this->cdesc_(), this->src_pd_,
- this->weights_pd_, this->dst_pd_,
- *this->attr(), with_relu, this->negative_slope());
+ status_t status =
+ jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf(jcp_,
+ *this->desc(), this->src_pd_, this->weights_pd_,
+ this->dst_pd_, *this->attr()); // fills jcp_ from the descriptor, memory descriptors and attributes
+ if (status != status::success) return status; // on failure nothing is booked and the algorithm kind is left unchanged
+
+ auto scratchpad = this->scratchpad_registry().registrar();
+ winograd_avx512_core::init_scratchpad(scratchpad, jcp_); // books the buffers sized from the jcp_ just initialised
+ if (status == status::success // always true here: every other status returned above
+ && this->desc()->alg_kind == alg_kind::convolution_auto)
+ CHECK(this->set_alg_kind(alg_kind::convolution_winograd)); // an auto request is resolved to winograd once this implementation is chosen
+
+ return status;
}
jit_conv_winograd_conf_t jcp_;
if (this->dst_pd_.desc()->format == any)
CHECK(this->dst_pd_.set_format(nChw16c));
if (this->weights_pd_.desc()->format == any
- && (this->cdesc_().prop_kind != mkldnn_forward_inference))
+ && (this->desc()->prop_kind != mkldnn_forward_inference))
CHECK(this->weights_pd_.set_format(
this->with_groups() ? gOIhw16i16o : OIhw16i16o));
if (this->bias_pd_.desc()->format == any)
}
};
- _jit_avx512_core_fp32_wino_conv_4x3_fwd_t(const pd_t *pd,
+ jit_avx512_core_fp32_wino_conv_4x3_fwd_t(const pd_t *apd, // builds the forward primitive from the descriptor apd
const input_vector &inputs, const output_vector &outputs)
- : _jit_avx512_core_fp32_wino_conv_4x3_t<true>(pd->jcp_, pd->attr())
- , cpu_primitive_t(&conf_, inputs, outputs)
- , conf_(*pd) {}
+ : _jit_avx512_core_fp32_wino_conv_4x3_t<true>(apd->jcp_, apd->attr()) // the winograd base receives the jcp_ and attributes read from apd
+ , cpu_primitive_t(apd, inputs, outputs, true) // NOTE(review): the meaning of the trailing `true` is not visible in this file - confirm against cpu_primitive_t
+ {}
- ~_jit_avx512_core_fp32_wino_conv_4x3_fwd_t(){};
+ ~jit_avx512_core_fp32_wino_conv_4x3_fwd_t(){};
typedef typename prec_traits<data_type::f32>::type data_t;
- virtual void execute(event_t *e)
+ virtual void execute(event_t *e) const
{
float *src = (float *)this->input_memory(0);
float *dst = (float *)this->memory();
float *weights = (float *)this->input_memory(1);
float *bias = (float *)this->input_memory(2);
+ auto scratchpad = this->scratchpad();
- switch ((conf_.jcp_).sched_policy) {
+ switch ((pd()->jcp_).sched_policy) {
case WSCHED_DATA_W_S_G_D:
- this->_execute_data_W_S_G_D(conf_.MB(), src, dst, weights, bias);
+ this->_execute_data_W_S_G_D(pd()->MB(), src, dst, weights, bias, scratchpad);
break;
case WSCHED_DATA_W_SGD:
- this->_execute_data_W_SGD(conf_.MB(), src, dst, weights, bias);
+ this->_execute_data_W_SGD(pd()->MB(), src, dst, weights, bias, scratchpad);
break;
default:
break;
}
private:
- pd_t conf_;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } // typed accessor: casts the descriptor returned by primitive_t::pd() to this primitive's pd_t
};
-using jit_avx512_core_fp32_wino_conv_4x3_fwd_t
- = _jit_avx512_core_fp32_wino_conv_4x3_fwd_t<false>;
-using jit_avx512_core_fp32_wino_conv_4x3_relu_t
- = _jit_avx512_core_fp32_wino_conv_4x3_fwd_t<true>;
-
struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t
: _jit_avx512_core_fp32_wino_conv_4x3_t<false>,
public cpu_primitive_t {
assert(this->engine()->kind() == engine_kind::cpu);
bool ok = true && this->set_default_params() == status::success
&& utils::one_of(this->desc()->prop_kind, backward_data)
- && this->desc()->alg_kind == alg_kind::convolution_winograd
+ && utils::one_of(this->desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
&& utils::everyone_is(data_type::f32,
this->desc()->diff_src_desc.data_type,
this->desc()->weights_desc.data_type,
if (!ok)
return status::unimplemented;
- return jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::
- init_conf(jcp_, *this->desc(), *this->diff_src_pd_.desc(),
- *this->weights_pd_.desc(),
- *this->diff_dst_pd_.desc());
+ status_t status =
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::init_conf(
+ jcp_, *this->desc(), *this->diff_src_pd_.desc(),
+ *this->weights_pd_.desc(), *this->diff_dst_pd_.desc());
+ if (status != status::success) return status;
+
+ auto scratchpad = this->scratchpad_registry().registrar();
+ winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
+
+ if (status == status::success
+ && this->desc()->alg_kind == alg_kind::convolution_auto)
+ CHECK(this->set_alg_kind(alg_kind::convolution_winograd));
+
+ return status;
}
jit_conv_winograd_conf_t jcp_;
}
};
- jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t(const pd_t *pd,
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t(const pd_t *apd, // builds the backward-data primitive from the descriptor apd
const input_vector &inputs, const output_vector &outputs)
- : _jit_avx512_core_fp32_wino_conv_4x3_t<false>(pd->jcp_, pd->attr())
- , cpu_primitive_t(&conf_, inputs, outputs)
- , conf_(*pd) {}
+ : _jit_avx512_core_fp32_wino_conv_4x3_t<false>(apd->jcp_, apd->attr()) // the winograd base receives the jcp_ and attributes read from apd
+ , cpu_primitive_t(apd, inputs, outputs, true) // NOTE(review): the meaning of the trailing `true` is not visible in this file - confirm against cpu_primitive_t
+ {}
~jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t(){};
typedef typename prec_traits<data_type::f32>::type data_t;
- virtual void execute(event_t *e)
+ virtual void execute(event_t *e) const
{
float *diff_dst = (float *)this->input_memory(0);
float *diff_src = (float *)this->memory();
float *weights = (float *)this->input_memory(1);
+ auto scratchpad = this->scratchpad();
- if (conf_.desc()->prop_kind == prop_kind::backward_data) {
- switch ((conf_.jcp_).sched_policy) {
+ if (pd()->desc()->prop_kind == prop_kind::backward_data) {
+ switch ((pd()->jcp_).sched_policy) {
case WSCHED_DATA_W_S_G_D:
- this->_execute_data_W_S_G_D(conf_.MB(), diff_dst, diff_src, weights, NULL);
+ this->_execute_data_W_S_G_D(pd()->MB(), diff_dst, diff_src, weights, NULL,
+ scratchpad);
break;
case WSCHED_DATA_W_SGD:
- this->_execute_data_W_SGD(conf_.MB(), diff_dst, diff_src, weights, NULL);
+ this->_execute_data_W_SGD(pd()->MB(), diff_dst, diff_src, weights, NULL,
+ scratchpad);
break;
default:
}
private:
- pd_t conf_;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } // typed accessor: casts the descriptor returned by primitive_t::pd() to this primitive's pd_t
};
struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t
assert(this->engine()->kind() == engine_kind::cpu);
bool ok = true && this->set_default_params() == status::success
&& utils::one_of(this->desc()->prop_kind, backward_weights)
- && this->desc()->alg_kind == alg_kind::convolution_winograd
+ && utils::one_of(this->desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
&& utils::everyone_is(data_type::f32,
this->desc()->src_desc.data_type,
this->desc()->diff_dst_desc.data_type,
if (!ok)
return status::unimplemented;
- return jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
- init_conf(jcp_, *this->desc(), *this->src_pd_.desc(),
- *this->diff_dst_pd_.desc(),
- *this->diff_weights_pd_.desc());
+ status_t status =
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
+ init_conf(jcp_, *this->desc(), *this->src_pd_.desc(),
+ *this->diff_dst_pd_.desc(),
+ *this->diff_weights_pd_.desc());
+ if (status != status::success) return status;
+
+ auto scratchpad = this->scratchpad_registry().registrar();
+ winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
+
+ if (status == status::success
+ && this->desc()->alg_kind == alg_kind::convolution_auto)
+ CHECK(this->set_alg_kind(alg_kind::convolution_winograd));
+
+ return status;
}
jit_conv_winograd_conf_t jcp_;
}
};
- jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t(const pd_t *pd,
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t(const pd_t *apd, // builds the backward-weights primitive from the descriptor apd
const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs)
- , conf_(*pd)
+ : cpu_primitive_t(apd, inputs, outputs, true) // NOTE(review): the meaning of the trailing `true` is not visible in this file - confirm against cpu_primitive_t
, kernel_(nullptr)
- , scratchpad_(nullptr)
{
- auto jcp = conf_.jcp_;
kernel_ = new jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel(
- jcp);
- scratchpad_ = new winograd::winograd_scratchpad_avx512_core_t(jcp);
+ pd()->jcp_); // the kernel is constructed from the jcp_ reached through pd()
}
~jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t()
{
delete kernel_;
- delete scratchpad_;
};
typedef typename prec_traits<data_type::f32>::type data_t;
- virtual void execute(event_t *e)
+ virtual void execute(event_t *e) const
{
- if (conf_.desc()->prop_kind == prop_kind::backward_weights) {
+ if (pd()->desc()->prop_kind == prop_kind::backward_weights) {
const auto &jcp = kernel_->jcp;
switch (jcp.sched_policy) {
case WSCHED_WEI_SDGtWo:
- _execute_backward_weights_SDGtWo();
+ _execute_backward_weights_SDGtWo(scratchpad());
break;
case WSCHED_WEI_S_D_Giot_W:
- _execute_backward_weights_S_D_Giot_W();
+ _execute_backward_weights_S_D_Giot_W(scratchpad());
break;
default:
assert(jcp.sched_policy != WSCHED_INVALID);
}
private:
- void _execute_backward_weights_SDGtWo();
- void _execute_backward_weights_S_D_Giot_W();
+ void _execute_backward_weights_SDGtWo(
+ const memory_tracking::grantor_t &scratchpad) const;
+ void _execute_backward_weights_S_D_Giot_W(
+ const memory_tracking::grantor_t &scratchpad) const;
- pd_t conf_;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } // typed accessor: casts the descriptor returned by primitive_t::pd() to this primitive's pd_t
jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel *kernel_;
-
- // Buffer required to store transforms in the frequency domain
- winograd::winograd_scratchpad_avx512_core_t *scratchpad_;
};
}
}