* limitations under the License.
*******************************************************************************/
-#include "mkldnn_types.h"
#include "c_types_map.hpp"
-#include "jit_avx512_common_convolution.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
+#include "jit_avx512_common_convolution.hpp"
+
namespace mkldnn {
namespace impl {
namespace cpu {
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace nstl;
ker(&p);
}
// Weight-tensor offset helper: grouped convolutions carry an extra leading
// group dimension in the weights md, plain convolutions do not.
// (This hunk renames the accessor: conf_ -> pd().)
#define wht_blk_off(d, g, ...) \
- (conf_.with_groups() \
+ (pd()->with_groups() \
? (d).blk_off((g), __VA_ARGS__) \
: (d).blk_off(__VA_ARGS__))
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
// Forward convolution helper: when the descriptor padded the output
// channels up to the vector block size, copy the user-supplied bias into a
// scratchpad buffer, zero the padded tail (jcp_.oc - oc_without_padding
// entries) so the kernel never reads garbage, and redirect `bias` to the
// padded copy. No-op when the pd does not want a padded bias.
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
+prepare_padded_bias(const dst_data_t *&bias) const {
+    if (!pd()->wants_padded_bias()) return;
+
+    auto padded_bias = scratchpad().template get<dst_data_t>(
+            key_conv_padded_bias);
+    utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding);
+    utils::array_set(padded_bias + pd()->jcp_.oc_without_padding,
+            (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding);
+    bias = padded_bias;
+}
+
+template <data_type_t src_type, data_type_t wei_type,
data_type_t dst_type>
-void _jit_avx512_common_convolution_fwd_t
- <with_relu, src_type, wei_type, dst_type>::execute_forward_1d()
+void jit_avx512_common_convolution_fwd_t
+ <src_type, wei_type, dst_type>::execute_forward_1d() const
{
auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<dst_data_t *>(this->memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+ prepare_padded_bias(bias);
- const auto &jcp = kernel_->jcp;
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+
+ const auto &jcp = pd()->jcp_;
assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
else
nthr = mkldnn_get_max_threads();
- if (conf_.want_padded_bias()) {
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- padded_bias_[oc] = bias[oc];
- bias = padded_bias_;
- }
parallel(nthr, [&](const int ithr, const int nthr) {
int start{0}, end{0}, start_copy;
balance211(work_amount, nthr, ithr, start, end);
int ocb = occ * jcp.nb_oc_blocking;
int g_ocb = g * jcp.nb_oc + ocb;
int g_oc = g_ocb * jcp.oc_block;
- int g_icb = g * jcp.nb_ic;
+ int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
int ow_s = owb * jcp.ow_block;
int iw_s = ow_s * jcp.stride_w;
});
}
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
+template <data_type_t src_type, data_type_t wei_type,
data_type_t dst_type>
-void _jit_avx512_common_convolution_fwd_t
- <with_relu, src_type, wei_type, dst_type>::execute_forward_2d()
+void jit_avx512_common_convolution_fwd_t
+ <src_type, wei_type, dst_type>::execute_forward_2d() const
{
auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<dst_data_t *>(this->memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+ prepare_padded_bias(bias);
- const auto &jcp = kernel_->jcp;
- const int MB = conf_.MB();
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+
+ const auto &jcp = pd()->jcp_;
+ const int MB = pd()->MB();
assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
else
nthr = mkldnn_get_max_threads();
- if (conf_.want_padded_bias()) {
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- padded_bias_[oc] = bias[oc];
- bias = padded_bias_;
- }
-
parallel(nthr, [&](const int ithr, const int nthr) {
int start{0}, end{0}, start_copy;
balance211(work_amount, nthr, ithr, start, end);
int ocb = occ * jcp.nb_oc_blocking;
int g_ocb = g * jcp.nb_oc + ocb;
int g_oc = g_ocb * jcp.oc_block;
- int g_icb = g * jcp.nb_ic;
+ int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
int work_rem = end - start;
});
}
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
+template <data_type_t src_type, data_type_t wei_type,
data_type_t dst_type>
-void _jit_avx512_common_convolution_fwd_t
- <with_relu, src_type, wei_type, dst_type>::execute_forward_3d()
+void jit_avx512_common_convolution_fwd_t
+ <src_type, wei_type, dst_type>::execute_forward_3d() const
{
auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<dst_data_t *>(this->memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+ prepare_padded_bias(bias);
- const auto &jcp = kernel_->jcp;
- const int MB = conf_.MB();
- assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper bias_d(pd()->weights_pd(1));
- if (conf_.want_padded_bias()) {
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- padded_bias_[oc] = bias[oc];
- bias = padded_bias_;
- }
+ const auto &jcp = pd()->jcp_;
+ const int MB = pd()->MB();
+ assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
parallel(0, [&](const int ithr, const int nthr) {
int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
int ocb = occ * jcp.nb_oc_blocking;
int g_ocb = g * jcp.nb_oc + ocb;
int g_oc = g_ocb * jcp.oc_block;
- int g_icb = g * jcp.nb_ic;
+ int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
int work_rem = end - start;
int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
});
}
// Explicit instantiations of the forward primitive (f32 and s16/s16->s32).
// The leading `with_relu` bool template parameter is dropped by this patch,
// halving the number of instantiations (relu handling no longer selects a
// distinct template specialization here).
-template struct _jit_avx512_common_convolution_fwd_t<false, data_type::f32>;
-template struct _jit_avx512_common_convolution_fwd_t<true, data_type::f32>;
-template struct _jit_avx512_common_convolution_fwd_t<false, data_type::s16,
- data_type::s16, data_type::s32>;
-template struct _jit_avx512_common_convolution_fwd_t<true, data_type::s16,
+template struct jit_avx512_common_convolution_fwd_t<data_type::f32>;
+template struct jit_avx512_common_convolution_fwd_t<data_type::s16,
data_type::s16, data_type::s32>;
template <data_type_t diff_dst_type, data_type_t wei_type,
data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
- diff_src_type>::execute_backward_data_1d() {
+ diff_src_type>::execute_backward_data_1d() const {
auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
const auto &jcp = kernel_->jcp;
template <data_type_t diff_dst_type, data_type_t wei_type,
data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
- diff_src_type>::execute_backward_data_2d() {
+ diff_src_type>::execute_backward_data_2d() const {
auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
const auto &jcp = kernel_->jcp;
- const int MB = conf_.MB();
+ const int MB = pd()->MB();
parallel(0, [&](const int ithr, const int nthr) {
int start{0}, end{0}, start_copy;
template <data_type_t diff_dst_type, data_type_t wei_type,
data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
- diff_src_type>::execute_backward_data_3d() {
+ diff_src_type>::execute_backward_data_3d() const {
auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
const auto &jcp = kernel_->jcp;
- const int MB = conf_.MB();
+ const int MB = pd()->MB();
parallel(0, [&](const int ithr, const int nthr) {
int start{0}, end{0}, start_copy;
data_type_t diff_weights_type>
jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
diff_weights_type>::
-jit_avx512_common_convolution_bwd_weights_t(const pd_t *pd,
+jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr)
+ : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr)
, trans_kernel_(nullptr), trans_dst_kernel_(nullptr), acc_ker_(nullptr)
- , reducer_bias_(nullptr), padded_bias_(nullptr), tr_src_(nullptr)
- , tr_diff_dst_(nullptr), ws_reduction_(nullptr), tr_src_bctx_(nullptr)
- , tr_diff_dst_bctx_(nullptr)
+ , reducer_bias_(nullptr)
{
- const auto &j = conf_.jcp_;
- kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j);
+ const auto &j = pd()->jcp_;
- balance();
+ nthr_ = j.nthr;
+ nthr_mb_ = j.nthr_mb;
+ nthr_g_ = j.nthr_g;
+ nthr_oc_b_ = j.nthr_oc_b;
+ nthr_ic_b_ = j.nthr_ic_b;
+
+ kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j);
if (utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
trans_kernel_ = create_trans_src(&j);
if (utils::one_of(j.ver, ver_4vnni, ver_vnni))
trans_dst_kernel_ = create_trans_dst(&j);
- if (j.is_1stconv) {
- const int tr_src_size =
- nthr_ / nthr_oc_b_ * j.ih * j.stride_w * j.tr_ld;
- tr_src_ = (src_data_t *)malloc(tr_src_size * sizeof(src_data_t), 64);
- } else {
- // XXX: See the comment about tr_iw and guarding elements in
- // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
- const int max_nthr = nthr_mb_ * j.ngroups * j.nb_ic;
- const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw;
- const int tr_src_size = max_nthr * min_tr_src_size_per_thr
- + j.tr_src_num_guard_elems;
- tr_src_ = (src_data_t *)malloc(tr_src_size * sizeof(src_data_t), 64);
- /* to avoid NaNs in computations we zero tail num_guard_elems for
- * each possible thread group */
- for (int ithr = 1; ithr <= max_nthr; ++ithr) {
- src_data_t *ts = &tr_src_[ithr * min_tr_src_size_per_thr];
- for (int i = 0; i < j.tr_src_num_guard_elems; ++i)
- ts[i] = 0;
- }
- }
-
- /* prepare synchronization contexts */
- if (nthr_oc_b_ > 1) {
- const int tr_src_bctx_size = nthr_ / nthr_oc_b_;
- tr_src_bctx_ = (simple_barrier::ctx_t *)malloc(
- tr_src_bctx_size * sizeof(simple_barrier::ctx_t), 64);
- for (int i = 0; i < tr_src_bctx_size; ++i)
- simple_barrier::ctx_init(&tr_src_bctx_[i]);
- }
-
- if (utils::one_of(j.ver, ver_4vnni, ver_vnni)) {
- const size_t tr_diff_dst_size =
- nthr_mb_ * j.ngroups * j.nb_oc * j.oc_block * j.tr_ow * j.oh;
- tr_diff_dst_ = (diff_dst_data_t *)malloc(
- tr_diff_dst_size * sizeof(diff_dst_data_t), 64);
-
- /* prepare synchronization contexts */
- if (nthr_ic_b_ > 1) {
- const size_t tr_diff_dst_bctx_size = nthr_ / nthr_ic_b_;
- tr_diff_dst_bctx_ = (simple_barrier::ctx_t *)malloc(
- tr_diff_dst_bctx_size * sizeof(simple_barrier::ctx_t),
- 64);
- for (size_t i = 0; i < tr_diff_dst_bctx_size; ++i)
- simple_barrier::ctx_init(&tr_diff_dst_bctx_[i]);
- }
- }
}
- if (nthr_mb_ > 1) {
- const int wei_size = j.ngroups * j.oc * j.ic * j.kh * j.kw * j.kd;
- const int bia_size = j.ngroups * j.oc;
- ws_reduction_ = (diff_weights_data_t *)malloc((nthr_mb_ - 1)
- * (wei_size + bia_size) * sizeof(diff_weights_data_t), 64);
+ if (nthr_mb_ > 1)
acc_ker_ = new cpu_accumulator_1d_t<diff_weights_type>();
- simple_barrier::ctx_init(&reduction_bctx_);
- }
- if (conf_.with_bias()) {
- const size_t max_buffer_size = nthr_ * 3 * 5 * 5 * 16 * 16;
- reducer_bias_ = new cpu_reducer_t<diff_weights_type>(reduce_balancer_t(
- nthr_, j.oc_block, j.ngroups * j.nb_oc, j.mb,
- max_buffer_size));
- if (conf_.want_padded_bias())
- padded_bias_ = (diff_weights_data_t *)
- malloc(sizeof(diff_weights_data_t) * j.oc, 64);
- }
+ reducer_bias_ =
+ new cpu_reducer_t<diff_weights_type>(pd()->reducer_bia_conf_);
}
template <data_type_t src_type, data_type_t diff_dst_type,
const diff_weights_data_t *diff_weights;
diff_weights_data_t *diff_bias;
+ const memory_tracking::grantor_t scratchpad;
+
+ src_data_t *tr_src;
+ simple_barrier::ctx_t *tr_src_bctx;
+
+ diff_dst_data_t *tr_diff_dst;
+ simple_barrier::ctx_t *tr_diff_dst_bctx;
+
+ diff_weights_data_t *wei_bia_reduction;
+ simple_barrier::ctx_t *wei_bia_reduction_bctx;
+
int ithr;
int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb;
int ithr_but_oc;
int ic_b_start = 0, ic_b_end = 0, ic_b_work;
thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self,
- int ithr): ithr(ithr) {
-
+ int ithr): scratchpad(self->scratchpad()), ithr(ithr) {
src = reinterpret_cast<const src_data_t *>(self->input_memory(0));
diff_dst = reinterpret_cast<const diff_dst_data_t *>(
self->input_memory(1));
diff_weights = reinterpret_cast<diff_weights_data_t *>(self->memory(0));
- diff_bias = self->conf_.want_padded_bias()
- ? self->padded_bias_
+ diff_bias = self->pd()->wants_padded_bias()
+ ? scratchpad.template get<diff_weights_data_t>(
+ key_conv_padded_bias)
: reinterpret_cast<diff_weights_data_t *>(self->memory(1));
+ tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
+ tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ key_conv_tr_src_bctx);
+
+ tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
+ key_conv_tr_diff_dst);
+ tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ key_conv_tr_diff_dst_bctx);
+
+ wei_bia_reduction = scratchpad.template get<diff_weights_data_t>(
+ key_conv_wei_bia_reduction);
+ wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ key_conv_wei_bia_reduction_bctx);
+
ithr_ic_b = ithr % self->nthr_ic_b_;
ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
template <data_type_t src_type, data_type_t diff_dst_type,
data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::compute_diff_weights(const thread_info_t *ti) {
- const memory_desc_wrapper src_d(conf_.src_pd(0));
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+ diff_weights_type>::compute_diff_weights(const thread_info_t *ti) const {
+ const memory_desc_wrapper src_d(pd()->src_pd(0));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
const auto &jcp = kernel_->jcp;
const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd;
diff_weights_data_t *diff_wei = ti->ithr_mb == 0
? (diff_weights_data_t*)ti->diff_weights
- : (diff_weights_data_t*)ws_reduction_ + (ti->ithr_mb - 1) * wei_size;
+ : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
diff_weights_data_t *diff_bia = ti->ithr_mb == 0
? (diff_weights_data_t*)ti->diff_bias
- : (diff_weights_data_t*)ws_reduction_ + (nthr_mb_ - 1) * wei_size
+ : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
+ (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
// TODO: use memory descriptor with the same fmt as src (or use a macro :))
const int _ic = g * jcp.nb_ic + ic_b;
src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)];
- src_data_t *tr_src1 = &tr_src_[tr_src_off(ti->ithr_mb, _ic, j)];
+ src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)];
assert(jcp.ic_block == 16);
const int src_stride = jcp.iw * jcp.ic_block;
const diff_dst_data_t *diff_dst1
= &ti->diff_dst[diff_dst_d.blk_off(img, oc, j)];
diff_dst_data_t *tr_diff_dst1
- = &tr_diff_dst_[tr_diff_dst_off(img, oc, j)];
+ = &ti->tr_diff_dst[tr_diff_dst_off(img, oc, j)];
assert(jcp.ic_block == 16);
if (jcp.is_1stconv && jcp.ver == ver_4fma) {
/* prepare contexts */
auto tr_ctx = jit_trans_src_t::ctx_t();
- tr_ctx.tr_src = tr_src_
+ tr_ctx.tr_src = ti->tr_src
+ ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld;
assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1));
balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end);
tr_ctx.tr_src_ih_start = ih_start;
tr_ctx.tr_src_ih_end = ih_end;
- tr_ctx.tr_src_bctx = tr_src_bctx_ + ti->ithr_but_oc;
+ tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc;
auto p = jit_conv_call_s();
p.src = tr_ctx.tr_src;
/* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
using simple_barrier::barrier;
if (nthr_oc_b_ > 1)
- barrier(&tr_src_bctx_[ti->ithr_but_oc], nthr_oc_b_);
+ barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
uker_trans(img);
if (nthr_oc_b_ > 1)
- barrier(&tr_src_bctx_[ti->ithr_but_oc], nthr_oc_b_);
+ barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
}
if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
/* tr_diff_dst[nb_oc][OW][oh][16c][2ow]
* <- diff_dst[nb_oc][oh][ow][16c] */
if (nthr_ic_b_ > 1)
- barrier(&tr_diff_dst_bctx_[ti->ithr_but_ic], nthr_ic_b_);
+ barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
diff_dst_trans(img);
if (nthr_ic_b_ > 1)
- barrier(&tr_diff_dst_bctx_[ti->ithr_but_ic], nthr_ic_b_);
+ barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
}
for (int g = ti->g_start; g < ti->g_end; ++g) {
jit_conv_ker_pipeline(kernel_->jit_ker, p,
(utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
- ? &tr_src_[tr_src_off(ti->ithr_mb, _ic, 0)]
+ ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
: &ti->src[src_d.blk_off(img, _ic)]),
utils::one_of(jcp.ver, ver_4vnni, ver_vnni)
- ? &tr_diff_dst_[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
+ ? &ti->tr_diff_dst[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
: &ti->diff_dst[diff_dst_d.blk_off(img, _oc)],
diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
0, (img == ti->img_start), 0, 0);
const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
jit_conv_ker_pipeline(kernel_->jit_ker, p,
(utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
- ? &tr_src_[tr_src_off(ti->ithr_mb, _ic, 0)]
+ ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
: &ti->src[src_d.blk_off(img + 1, _ic)]),
utils::one_of(jcp.ver, ver_4vnni, ver_vnni)
- ? &tr_diff_dst_[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
+ ? &ti->tr_diff_dst[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
: &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
diff_wei + wht_blk_off(
diff_weights_d, ti->g_start,
template <data_type_t src_type, data_type_t diff_dst_type,
data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) {
- const memory_desc_wrapper src_d(conf_.src_pd(0));
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+ diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) const
+{
+ const memory_desc_wrapper src_d(pd()->src_pd(0));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
const auto &jcp = kernel_->jcp;
const int wei_size
diff_weights_data_t *diff_wei = ti->ithr_mb == 0
? (diff_weights_data_t*)ti->diff_weights
- : (diff_weights_data_t*)ws_reduction_ + (ti->ithr_mb - 1) * wei_size;
+ : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
diff_weights_data_t *diff_bia = ti->ithr_mb == 0
? (diff_weights_data_t*)ti->diff_bias
- : (diff_weights_data_t*)ws_reduction_ + (nthr_mb_ - 1) * wei_size
+ : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
+ (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
template <data_type_t src_type, data_type_t diff_dst_type,
data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) {
- const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+ diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) const {
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
const auto &jcp = kernel_->jcp;
const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
const int bia_size = jcp.ngroups * jcp.oc;
const diff_weights_data_t *diff_bias_ws
- = ws_reduction_ + (nthr_mb_ - 1) * wei_size;
+ = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size;
- /* diff_weights[:] += sum(ws_reduction_[thr_mb][:]) */
- simple_barrier::barrier(&reduction_bctx_, nthr_);
+ /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
+ simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
const int ic_b_kh_work = ti->ic_b_work * jcp.kh;
const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
diff_weights_data_t *d
= (diff_weights_data_t *)ti->diff_weights + off;
diff_weights_data_t *s
- = ws_reduction_ + (thr_mb - 1) * wei_size + off;
+ = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
acc_ker_->accumulate(d, s, acc_size);
template <data_type_t src_type, data_type_t diff_dst_type,
data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti) {
- const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+ diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti) const {
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
const auto &jcp = kernel_->jcp;
const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw
* jcp.kd;
- /* diff_weights[:] += sum(ws_reduction_[thr_mb][:]) */
- simple_barrier::barrier(&reduction_bctx_, nthr_);
+ /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
+ simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
const int ic_b_kh_work = ti->ic_b_work * jcp.kd;
const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
diff_weights_data_t *d
= (diff_weights_data_t *)ti->diff_weights + off;
diff_weights_data_t *s
- = ws_reduction_ + (thr_mb - 1) * wei_size + off;
+ = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
acc_ker_->accumulate(d, s, acc_size);
nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
template <data_type_t src_type, data_type_t diff_dst_type,
data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::compute_diff_bias(const thread_info_t *ti) {
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
+ diff_weights_type>::compute_diff_bias(const thread_info_t *ti) const {
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
auto rb = this->reducer_bias_;
- assert(nthr_ == rb->balancer_.nthr_);
+ assert(nthr_ == rb->balancer().nthr_);
+
+ const auto reducer_bia_scratchpad = memory_tracking::grantor_t(
+ ti->scratchpad, prefix_reducer_bia);
const auto &jcp = kernel_->jcp;
if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return;
- const int b_job_start = rb->balancer_.ithr_job_off(ti->ithr);
- const int b_njobs = rb->balancer_.ithr_njobs(ti->ithr);
+ const int b_job_start = rb->balancer().ithr_job_off(ti->ithr);
+ const int b_njobs = rb->balancer().ithr_njobs(ti->ithr);
if (b_njobs == 0) return;
/* reduction dimension */
int img_start{0}, img_end{0};
- balance211(jcp.mb, rb->balancer_.nthr_per_group_,
- rb->balancer_.id_in_group(ti->ithr), img_start, img_end);
+ balance211(jcp.mb, rb->balancer().nthr_per_group_,
+ rb->balancer().id_in_group(ti->ithr), img_start, img_end);
/* jobs */
int g_start{0}, ocb_start{0};
const diff_dst_data_t *d_dst
= &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
- diff_weights_data_t *d_bias = &rb->get_local_ptr(ti->ithr,
- (diff_weights_data_t *)ti->diff_bias)[
- b_job_loc * rb->balancer_.job_size_];
+ diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr,
+ ti->diff_bias, reducer_bia_scratchpad)
+ + b_job_loc * rb->balancer().job_size_;
if (img == img_start)
for (int o = 0; o < 16; ++o)
}
}
- rb->reduce(ti->ithr, ti->diff_bias);
+ rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad);
}
template <data_type_t src_type, data_type_t diff_dst_type,
data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::compute_diff_bias_3d(const thread_info_t *ti) {
+ diff_weights_type>::compute_diff_bias_3d(const thread_info_t *ti) const {
const auto &jcp = kernel_->jcp;
* jcp.kw * jcp.kd;
const int bia_size = jcp.ngroups * jcp.oc;
const diff_weights_data_t *diff_bias_ws
- = ws_reduction_ + (size_t)(nthr_mb_ - 1) * wei_size;
+ = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size;
if (nthr_mb_ > 1) mkldnn_thr_barrier();
// One-time per-execution initialization of the bwd-weights scratchpad,
// replacing the buffer setup that the removed constructor code used to do
// on member-owned allocations:
//   - zero the tr_src guard elements after each thread-group region so the
//     4fma/4vnni/vnni kernels never hit NaNs past the transposed-src tail
//     (skipped for the 1st convolution, which has no guard layout);
//   - ctx_init() every simple_barrier used for the src / diff_dst
//     transpose synchronization and for the weights/bias reduction;
//   - initialize the bias reducer's workspace via its scratchpad grantor.
// Must run before the parallel region in execute_backward_weights().
template <data_type_t src_type, data_type_t diff_dst_type,
data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::execute_backward_weights() {
+ diff_weights_type>::prepare_scratchpad_data() const
+{
+    const auto &j = pd()->jcp_;
+    auto scratchpad = this->scratchpad();
+
+    if (utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
+        if (!j.is_1stconv) {
+            // XXX: See the comment about tr_iw and guarding elements in
+            // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
+            const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic;
+            const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw;
+
+            auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
+            /* to avoid NaNs in computations we zero tail num_guard_elems for
+             * each possible thread group */
+
+            for (int ithr = 1; ithr <= max_nthr; ++ithr) {
+                src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr];
+                for (int i = 0; i < j.tr_src_num_guard_elems; ++i)
+                    ts[i] = 0;
+            }
+        }
+
+        if (j.nthr_oc_b > 1) {
+            const int tr_src_bctx_size = j.nthr / j.nthr_oc_b;
+            auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+                    key_conv_tr_src_bctx);
+            for (int i = 0; i < tr_src_bctx_size; ++i)
+                simple_barrier::ctx_init(&tr_src_bctx[i]);
+        }
+
+        if (utils::one_of(j.ver, ver_4vnni, ver_vnni) && j.nthr_ic_b > 1) {
+            const int tr_diff_dst_bctx_size = j.nthr / j.nthr_ic_b;
+            auto tr_diff_dst_bctx =
+                scratchpad.template get<simple_barrier::ctx_t>(
+                        key_conv_tr_diff_dst_bctx);
+            for (int i = 0; i < tr_diff_dst_bctx_size; ++i)
+                simple_barrier::ctx_init(&tr_diff_dst_bctx[i]);
+        }
+    }
+
+    if (nthr_mb_ > 1) {
+        simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
+                    key_conv_wei_bia_reduction_bctx));
+    }
+
+    const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+            prefix_reducer_bia);
+    auto rb = this->reducer_bias_;
+    rb->init(reducer_bia_scratchpad);
+}
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::execute_backward_weights() const {
+ prepare_scratchpad_data();
+
parallel(nthr_, [&](const int ithr, const int nthr) {
assert(nthr_ == nthr);
thread_info_t thread_info(this, ithr);
- if (utils::one_of(conf_.ndims(), 3, 4)) {
+ if (utils::one_of(pd()->ndims(), 3, 4)) {
compute_diff_weights(&thread_info);
if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
- if (conf_.with_bias()) compute_diff_bias(&thread_info);
- } else if (conf_.ndims() == 5) {
+ if (pd()->with_bias()) compute_diff_bias(&thread_info);
+ } else if (pd()->ndims() == 5) {
compute_diff_weights_3d(&thread_info);
if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
- if (conf_.with_bias()) compute_diff_bias_3d(&thread_info);
+ if (pd()->with_bias()) compute_diff_bias_3d(&thread_info);
} else {
assert(false);
}
});
/* TODO: put that into compute_diff_bias() */
- if (conf_.want_padded_bias()) {
+ if (pd()->wants_padded_bias()) {
+ auto diff_bias = scratchpad().template get<const diff_weights_data_t>(
+ key_conv_padded_bias);
auto diff_bias_in
= reinterpret_cast<diff_weights_data_t *>(this->memory(1));
- for (int oc = 0; oc < conf_.jcp_.oc_without_padding; ++oc)
- diff_bias_in[oc] = this->padded_bias_[oc];
- }
-}
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::balance() {
- const int max_threads = mkldnn_get_max_threads();
- const auto &j = conf_.jcp_;
-
- nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
-
- if (max_threads < j.ngroups) {
- /* simplification... fortunately it doesn't hurt much */
- return;
- }
-
- if (!mkldnn_thr_syncable()
- && utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
- // should not happen -- the driver is not ready
- // for TBB-like non-synchronous threading yet
- return;
+ for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
+ diff_bias_in[oc] = diff_bias[oc];
}
-
- if (j.ver == ver_4fma && j.is_1stconv) {
- nthr_g_ = 1;
- nthr_oc_b_ = 1;
- nthr_ic_b_ = nstl::min(j.nb_ic, max_threads);
- nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb);
- nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_;
- return;
- }
-
- nthr_g_ = j.ngroups;
- const int nthr = max_threads / nthr_g_;
-
- auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
- /* calculate per thread memory cost (read/write). high level optimizer
- * tries to minimize memory consumption. few notes:
- * (n1) unclear why, but that essentially helps first convolution...
- * (n2) assuming the reduction over minibatch is always there:
- * - instead of 8 it should be 5 here (write ~= 2 read):
- * kernel: temporal workspace 1 write
- * reduction: 1 read from workspace and 1 write to the diff_wei
- * - but experiments showed 8 works better than 5 or 6... */
-
- const int src_coef = j.ver == ver_4fma || j.ver == ver_vnni ? 4 : 1;
- const int dst_coef = 1;
- const int wei_coef = j.ver == ver_vnni ? 4 : 8;
-
- return 0
- + src_coef
- * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
- * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
- / j.stride_d / j.stride_h / j.stride_w /* (n1) */
- + dst_coef
- * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
- * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
- + wei_coef /* (n2) */
- * div_up(j.ngroups, nthr_g_)
- * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
- * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
- };
-
- int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
-
- /* step 1: find the best thread distribution with lowest memory cost */
- const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
- for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
- const int nthr_par = nthr / nthr_mb;
- const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
- for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
- int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
-
- int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
- if (mem_cost <= best_mem_cost) {
- best_mem_cost = mem_cost;
- nthr_mb_ = nthr_mb;
- nthr_oc_b_ = nthr_oc_b;
- nthr_ic_b_ = nthr_ic_b;
- }
- }
-
- if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
- }
-
- if (j.ver != ver_vnni && !mayiuse(avx512_mic)) {
- auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
- return 1
- * div_up(j.mb, nthr_mb)
- * div_up(j.ngroups, nthr_g_)
- * div_up(j.nb_oc, nthr_oc_b)
- * div_up(j.nb_ic, nthr_ic_b);
- };
-
- /* step 2: search for a thread distribution with lower compute cost.
- * the constrains:
- * - memory cost cannot exceed 110% of the best found in the step 1
- * - unless compute cost is 133% lower than the current best case
- * note: both constants were found empirically */
- int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
- for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
- const int nthr_par = nthr / nthr_mb;
- const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
- for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
- int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
- int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
- int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
-
- const bool opt1 = comp_cost <= best_comp_cost
- && mem_cost < 1.1 * best_mem_cost;
- const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost;
-
- if (opt1 || opt2) {
- best_comp_cost = comp_cost;
- nthr_mb_ = nthr_mb;
- nthr_oc_b_ = nthr_oc_b;
- nthr_ic_b_ = nthr_ic_b;
- }
- }
-
- if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
- }
- }
-
- if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads)
- nthr_mb_ = min(j.mb * j.od, max_threads);
- nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
- assert(nthr_ <= max_threads);
- assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
}
template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>;