#define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP
#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
#include "cpu_convolution_pd.hpp"
#include "cpu_engine.hpp"
-#include "scratchpad.hpp"
#include "mkldnn_thread.hpp"
#include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
namespace impl {
namespace cpu {
-namespace winograd {
+namespace winograd_avx512_common {
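+/* books all temporary buffers needed by the Winograd kernels in the
+ * primitive's scratchpad registry; sizes depend only on the jit conf,
+ * so booking can happen as early as pd creation */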
+inline void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_winograd_conf_t &jcp) {
+ using namespace memory_tracking::names;
-struct winograd_scratchpad_t {
- public:
- winograd_scratchpad_t(const jit_conv_winograd_conf_t &jcp)
- {
- get_scratchpad_size_(jcp);
- allocate_scratchpad_(jcp);
- }
-
- ~winograd_scratchpad_t() {
- if (scratchpad_ != nullptr)
- delete scratchpad_;
- }
-
- char *U_ptr() {
- /* buffer for wei transform U*/
- return scratchpad_->get() + U_offset_;
- }
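+ /* element counts of the three transform buffers: U holds the wei
+ * transform, V the src transform, and M the dst transform */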
+ size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc;
+ size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic
+ * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
+ size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc
+ * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
- char *V_ptr() {
- /* buffer for src transform V*/
- return scratchpad_->get() + V_offset_;
- }
-
- char *M_ptr() {
- /* buffer for dst transform M*/
- return scratchpad_->get() + M_offset_;
- }
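+ /* the sizes above are in elements; book in bytes, aligned to 2M pages */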
+ scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M);
+ scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M);
+ scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M);
- char *bias_ptr() {
- /* buffer for bias update in bwdw*/
- return scratchpad_->get() + bias_offset_;
- }
+ if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) {
+ const int nthr = mkldnn_get_max_threads();
- char *src_transpose_ptr() {
- /* buffer for src transpose in bwdw using qfma*/
- return scratchpad_->get() + src_transpose_offset_;
- }
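+ /* per-thread buffer for the src transpose used by the 4fma kernel in
+ * bwd-weights; zero-sized for any other kernel version */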
+ size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr
+ * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block;
+ scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M);
- int num_threads(){
- return nthreads_;
- }
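+ /* per-thread reduction buffer for the bias update in bwd-weights */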
+ size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0;
+ scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M);
- private:
- inline void get_scratchpad_size_(const jit_conv_winograd_conf_t &jcp) {
- nthreads_ = mkldnn_get_max_threads();
-
- U_sz_ = (size_t)alpha * alpha * jcp.ic * jcp.oc * sizeof(float);
- V_sz_ = (size_t)alpha * alpha * jcp.mb * jcp.ic
- * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding)
- * sizeof(float);
- M_sz_ = (size_t)alpha * alpha * jcp.mb * jcp.oc
- * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding)
- * sizeof(float);
-
- switch (jcp.sched_policy) {
- case WSCHED_DATA_W_SGD:
- V_sz_ = (size_t)nthreads_ * alpha * alpha
- * jcp.nb_tile_block_ur * jcp.tile_block_ur
- * jcp.ic * sizeof(float);
- M_sz_ = (size_t)nthreads_* alpha * alpha
- * jcp.nb_tile_block_ur * jcp.tile_block_ur
- * jcp.oc * sizeof(float);
- break;
- case WSCHED_WEI_SDGt_W:
- U_sz_ = (size_t)nthreads_ * U_sz_;
- V_sz_ = (size_t)nthreads_ * alpha * alpha
- * (jcp.nb_tile_block_ur * jcp.tile_block_ur
- + jcp.tile_4fma_padding)
- * jcp.ic * sizeof(float);
- M_sz_ = (size_t)nthreads_ * alpha * alpha
- * (jcp.nb_tile_block_ur * jcp.tile_block_ur
- + jcp.tile_4fma_padding)
- * jcp.oc * sizeof(float);
- bias_sz_ = nthreads_ * jcp.oc * sizeof(float);
- break;
- case WSCHED_WEI_SDGtWo:
- U_sz_ = (size_t)nthreads_ * alpha * alpha
- * jcp.oc_block * jcp.oc_simd_block * jcp.ic * sizeof(float);
- M_sz_ = (size_t)nthreads_ * alpha * alpha
- * (jcp.nb_tile_block_ur * jcp.tile_block_ur
- + jcp.tile_4fma_padding)
- * jcp.oc_simd_block * jcp.oc_block * sizeof(float);
- bias_sz_ = nthreads_ * jcp.oc * sizeof(float);
- break;
- case WSCHED_WEI_S_D_Giot_W:
- U_sz_ = (size_t)(nthreads_ + 1) * alpha * alpha
- * jcp.ic * jcp.oc * sizeof(float);
- V_sz_ = (size_t)alpha * alpha
- * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding)
- * jcp.ic * jcp.mb * sizeof(float);
- M_sz_ = (size_t)alpha * alpha
- * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding)
- * jcp.oc * jcp.mb * sizeof(float);
- bias_sz_ = nthreads_ * jcp.oc * sizeof(float);
- src_transpose_sz_ = jcp.ver == ver_4fma
- ? ((size_t)nthreads_ * alpha * alpha
- * jcp.tile_4fma
- * jcp.ic_simd_block * sizeof(float))
- : 0;
- break;
- case WSCHED_WEI_S_D_G_W:
- src_transpose_sz_ = jcp.ver == ver_4fma
- ? ((size_t)nthreads_ * alpha * alpha
- * jcp.tile_4fma
- * jcp.ic_simd_block * sizeof(float))
- : 0;
- bias_sz_ = jcp.with_bias ? nthreads_ * jcp.oc * sizeof(float) : 0;
- break;
- default:
- break;
- }
- }
-
- inline void allocate_scratchpad_(const jit_conv_winograd_conf_t &jcp) {
- const size_t page_size = PAGE_2M;
- U_offset_ = 0;
- V_offset_ = utils::rnd_up(U_sz_, page_size);
- M_offset_ = V_offset_ + utils::rnd_up(V_sz_, page_size);
- scratchpad_sz_ = M_offset_ + M_sz_;
- if (src_transpose_sz_) {
- src_transpose_offset_ = M_offset_
- + utils::rnd_up(M_sz_, page_size);
- scratchpad_sz_ = src_transpose_offset_ + src_transpose_sz_;
- }
- if (bias_sz_) {
- bias_offset_ = src_transpose_sz_
- ? src_transpose_offset_
- + utils::rnd_up(src_transpose_sz_, page_size)
- : M_offset_ + utils::rnd_up(M_sz_, page_size);
- scratchpad_sz_ = bias_offset_ + bias_sz_;
- }
- scratchpad_ = create_scratchpad(scratchpad_sz_);
- }
-
- scratchpad_t *scratchpad_;
- int nthreads_;
- size_t scratchpad_sz_ = 0, U_sz_ = 0, V_sz_ = 0, M_sz_ = 0,
- bias_sz_ = 0, src_transpose_sz_ = 0;
- size_t U_offset_ = 0;
- size_t V_offset_ = 0;
- size_t M_offset_ = 0;
- size_t bias_offset_ = 0;
- size_t src_transpose_offset_ = 0; // only relevant for bwdw using qfma
-};
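+ /* extra space for a zero-padded copy of the bias when the oc
+ * dimension is padded */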
+ size_t padded_bias_sz =
+ jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0;
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz);
+ }
+}
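+/* at execute time the buffers booked above are fetched back through the
+ * scratchpad grantor, e.g. (sketch):
+ * auto U = scratchpad.get<float>(memory_tracking::names::key_wino_U); */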
}
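+/* common base for the fwd and bwd-data implementations; it owns the JIT
+ * data kernel, while all temporary buffers live in the scratchpad
+ * passed to _execute_data_W_S_G_D() */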
template <bool is_fwd>
_jit_avx512_common_convolution_winograd_t(
const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr)
- : kernel_(nullptr), scratchpad_(nullptr), attr_(attr) {
+ : kernel_(nullptr), attr_(attr) {
kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp);
- scratchpad_ = new winograd::winograd_scratchpad_t(jcp);
}
- ~_jit_avx512_common_convolution_winograd_t() {
- delete kernel_;
- delete scratchpad_;
- };
+ ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; }
protected:
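+ /* runs the W_S_G_D schedule (wei/src transform, gemm, dst transform);
+ * bias_ptr may be nullptr, and all temporaries come from the
+ * caller-provided scratchpad grantor */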
void _execute_data_W_S_G_D(const int MB, float *inp_ptr, float *out_ptr,
- float *wei_ptr, float *bias_ptr = NULL);
- void _execute_data_W_SGD(const int MB, float *inp_ptr, float *out_ptr,
- float *wei_ptr, float *bias_ptr = NULL);
+ float *wei_ptr, float *bias_ptr,
+ const memory_tracking::grantor_t &scratchpad) const;
_jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_;
- // Buffer required to store transforms in the frequency domain
- winograd::winograd_scratchpad_t *scratchpad_;
const primitive_attr_t *attr_;
};
-template <bool with_relu>
-struct _jit_avx512_common_convolution_winograd_fwd_t
+struct jit_avx512_common_convolution_winograd_fwd_t
: _jit_avx512_common_convolution_winograd_t<true>
, public cpu_primitive_t
{
- struct pd_t : public _cpu_convolution_fwd_pd_t<with_relu> {
- pd_t(engine_t *engine, const typename pd_t::base_desc_t *adesc,
+ struct pd_t : public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
const primitive_attr_t *attr,
const typename pd_t::base_class *hint_fwd_pd)
- : _cpu_convolution_fwd_pd_t<with_relu>(engine, adesc, attr,
- hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
, jcp_() {}
DECLARE_COMMON_PD_T(
JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""),
- _jit_avx512_common_convolution_winograd_fwd_t<with_relu>);
+ jit_avx512_common_convolution_winograd_fwd_t);
virtual status_t init() override
{
using namespace prop_kind;
assert(this->engine()->kind() == engine_kind::cpu);
bool ok = true && this->set_default_params() == status::success
- && utils::one_of(this->cdesc_().prop_kind, forward_training,
+ && utils::one_of(this->desc()->prop_kind, forward_training,
forward_inference)
- && this->cdesc_().alg_kind == alg_kind::convolution_winograd
+ && utils::one_of(this->desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
&& !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
- this->cdesc_().src_desc.data_type,
- this->cdesc_().weights_desc.data_type,
- this->cdesc_().dst_desc.data_type)
+ this->desc()->src_desc.data_type,
+ this->desc()->weights_desc.data_type,
+ this->desc()->dst_desc.data_type)
&& IMPLICATION(this->with_bias(), data_type::f32
- == this->cdesc_().bias_desc.data_type)
+ == this->desc()->bias_desc.data_type)
&& mkldnn_thr_syncable();
+
if (!ok)
return status::unimplemented;
- return jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf(
- jcp_, this->cdesc_(), *this->src_pd_.desc(),
- *this->weights_pd_.desc(), *this->dst_pd_.desc(),
- *this->attr(), with_relu, this->negative_slope());
+ status_t status =
+ jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf(
+ jcp_, *this->desc(), *this->src_pd_.desc(),
+ *this->weights_pd_.desc(), *this->dst_pd_.desc(),
+ *this->attr());
+ if (status != status::success) return status;
+
+ auto scratchpad = this->scratchpad_registry().registrar();
+ winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
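+ /* booking happens here, at pd init, so the scratchpad can be sized
+ * before any execution */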
+
+ if (status == status::success
+ && this->desc()->alg_kind == alg_kind::convolution_auto)
+ CHECK(this->set_alg_kind(alg_kind::convolution_winograd));
+
+ return status;
}
jit_conv_winograd_conf_t jcp_;
};
- _jit_avx512_common_convolution_winograd_fwd_t(const pd_t *pd,
+ jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : _jit_avx512_common_convolution_winograd_t<true>(pd->jcp_, pd->attr())
- , cpu_primitive_t(&conf_, inputs, outputs)
- , conf_(*pd) {}
+ : _jit_avx512_common_convolution_winograd_t<true>(apd->jcp_, apd->attr())
+ , cpu_primitive_t(apd, inputs, outputs, true) {}
- ~_jit_avx512_common_convolution_winograd_fwd_t(){};
+ ~jit_avx512_common_convolution_winograd_fwd_t(){};
typedef typename prec_traits<data_type::f32>::type data_t;
- virtual void execute(event_t *e)
+ virtual void execute(event_t *e) const
{
float *src = (float *)this->input_memory(0);
float *dst = (float *)this->memory();
float *weights = (float *)this->input_memory(1);
float *bias = (float *)this->input_memory(2);
- switch ((conf_.jcp_).sched_policy) {
- case WSCHED_DATA_W_S_G_D:
- this->_execute_data_W_S_G_D(conf_.MB(), src, dst, weights, bias);
- break;
- case WSCHED_DATA_W_SGD:
- this->_execute_data_W_SGD(conf_.MB(), src, dst, weights, bias);
- break;
- default:
- break;
- }
+ this->_execute_data_W_S_G_D(pd()->MB(), src, dst, weights, bias,
+ this->scratchpad());
+
e->set_state(event_t::ready);
}
private:
- pd_t conf_;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};
-using jit_avx512_common_convolution_winograd_fwd_t
- = _jit_avx512_common_convolution_winograd_fwd_t<false>;
-using jit_avx512_common_convolution_winograd_relu_t
- = _jit_avx512_common_convolution_winograd_fwd_t<true>;
-
struct jit_avx512_common_convolution_winograd_bwd_data_t
: _jit_avx512_common_convolution_winograd_t<false>,
public cpu_primitive_t {
assert(this->engine()->kind() == engine_kind::cpu);
bool ok = true && this->set_default_params() == status::success
&& utils::one_of(this->desc()->prop_kind, backward_data)
- && this->desc()->alg_kind == alg_kind::convolution_winograd
+ && utils::one_of(this->desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
&& !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->desc()->diff_src_desc.data_type,
this->desc()->weights_desc.data_type,
this->desc()->diff_dst_desc.data_type)
&& mkldnn_thr_syncable();
+
if (!ok)
return status::unimplemented;
- return jit_avx512_common_conv_winograd_bwd_data_kernel_f32::
- init_conf(jcp_, *this->desc(), *this->diff_src_pd_.desc(),
- *this->weights_pd_.desc(),
- *this->diff_dst_pd_.desc());
+ status_t status =
+ jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf(
+ jcp_, *this->desc(), *this->diff_src_pd_.desc(),
+ *this->weights_pd_.desc(), *this->diff_dst_pd_.desc());
+ if (status != status::success) return status;
+
+ auto scratchpad = this->scratchpad_registry().registrar();
+ winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
+
+ if (status == status::success
+ && this->desc()->alg_kind == alg_kind::convolution_auto)
+ CHECK(this->set_alg_kind(alg_kind::convolution_winograd));
+
+ return status;
}
jit_conv_winograd_conf_t jcp_;
};
- jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *pd,
+ jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : _jit_avx512_common_convolution_winograd_t<false>(pd->jcp_, pd->attr())
- , cpu_primitive_t(&conf_, inputs, outputs)
- , conf_(*pd) {}
+ : _jit_avx512_common_convolution_winograd_t<false>(apd->jcp_, apd->attr())
+ , cpu_primitive_t(apd, inputs, outputs, true) {}
~jit_avx512_common_convolution_winograd_bwd_data_t(){};
typedef typename prec_traits<data_type::f32>::type data_t;
- virtual void execute(event_t *e)
+ virtual void execute(event_t *e) const
{
+ assert(pd()->desc()->prop_kind == prop_kind::backward_data
+ && "invalid prop_kind");
+
float *diff_dst = (float *)this->input_memory(0);
float *diff_src = (float *)this->memory();
float *weights = (float *)this->input_memory(1);
- if (conf_.desc()->prop_kind == prop_kind::backward_data) {
- switch ((conf_.jcp_).sched_policy) {
- case WSCHED_DATA_W_S_G_D:
- this->_execute_data_W_S_G_D(conf_.MB(), diff_dst, diff_src, weights, NULL);
- break;
-
- case WSCHED_DATA_W_SGD:
- this->_execute_data_W_SGD(conf_.MB(), diff_dst, diff_src, weights, NULL);
- break;
-
- default:
- break;
- }
- } else {
- assert(!"invalid prop_kind");
- }
+ this->_execute_data_W_S_G_D(pd()->MB(), diff_dst, diff_src, weights, nullptr,
+ this->scratchpad());
e->set_state(event_t::ready);
}
private:
- pd_t conf_;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};
struct jit_avx512_common_convolution_winograd_bwd_weights_t
assert(this->engine()->kind() == engine_kind::cpu);
bool ok = true && this->set_default_params() == status::success
&& utils::one_of(this->desc()->prop_kind, backward_weights)
- && this->desc()->alg_kind == alg_kind::convolution_winograd
+ && utils::one_of(this->desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
&& !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->desc()->src_desc.data_type,
if (!ok)
return status::unimplemented;
- return jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::
- init_conf(jcp_, *this->desc(), *this->src_pd_.desc(),
- *this->diff_dst_pd_.desc(),
- *this->diff_weights_pd_.desc());
+ status_t status =
+ jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::
+ init_conf(jcp_, *this->desc(), *this->src_pd_.desc(),
+ *this->diff_dst_pd_.desc(),
+ *this->diff_weights_pd_.desc());
+ if (status != status::success) return status;
+
+ auto scratchpad = this->scratchpad_registry().registrar();
+ winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
+
+ if (status == status::success
+ && this->desc()->alg_kind == alg_kind::convolution_auto)
+ CHECK(this->set_alg_kind(alg_kind::convolution_winograd));
+
+ return status;
}
jit_conv_winograd_conf_t jcp_;
};
- jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *pd,
+ jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs)
- , conf_(*pd)
- , kernel_(nullptr)
- , scratchpad_(nullptr)
- , padded_bias_(nullptr)
+ : cpu_primitive_t(apd, inputs, outputs, true), kernel_(nullptr)
{
- auto jcp = conf_.jcp_;
kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32(
- jcp);
- scratchpad_ = new winograd::winograd_scratchpad_t(jcp);
- if (conf_.want_padded_bias())
- padded_bias_ = (float *)malloc(sizeof(float) * jcp.oc, 64);
+ pd()->jcp_);
}
~jit_avx512_common_convolution_winograd_bwd_weights_t()
- {
- delete kernel_;
- delete scratchpad_;
- free(padded_bias_);
- };
+ { delete kernel_; }
typedef typename prec_traits<data_type::f32>::type data_t;
- virtual void execute(event_t *e)
+ virtual void execute(event_t *e) const
{
- if (conf_.desc()->prop_kind == prop_kind::backward_weights) {
- const auto &jcp = kernel_->jcp;
- switch (jcp.sched_policy) {
- case WSCHED_WEI_S_D_G_W:
- _execute_backward_weights_S_D_G_W();
- break;
- case WSCHED_WEI_S_D_Giot_W:
- _execute_backward_weights_S_D_Giot_W();
- break;
- case WSCHED_WEI_SDGtWo:
- _execute_backward_weights_SDGtWo();
- break;
- case WSCHED_WEI_SDGt_W:
- _execute_backward_weights_SDGt_W();
- break;
- default:
- assert(!"Unknown Winograd schedule policy!");
- break;
- }
- }
- else
- assert(!"invalid prop_kind");
+ assert(pd()->desc()->prop_kind == prop_kind::backward_weights
+ && "invalid prop_kind");
+ _execute_backward_weights_S_D_G_W(scratchpad());
e->set_state(event_t::ready);
}
private:
- void _execute_backward_weights_S_D_G_W();
- void _execute_backward_weights_S_D_Giot_W();
- void _execute_backward_weights_SDGtWo();
- void _execute_backward_weights_SDGt_W();
- void _maybe_execute_diff_bias_copy();
+ void _execute_backward_weights_S_D_G_W(
+ const memory_tracking::grantor_t &scratchpad) const;
+ void _maybe_execute_diff_bias_copy(
+ const memory_tracking::grantor_t &scratchpad) const;
- pd_t conf_;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_;
-
- // Buffer required to store transforms in the frequency domain
- winograd::winograd_scratchpad_t *scratchpad_;
-
- float *padded_bias_;
};
void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]);