* limitations under the License.
*******************************************************************************/
-#include "mkldnn_types.h"
-
#include "c_types_map.hpp"
-#include "jit_avx512_common_1x1_convolution.hpp"
-#include "utils.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
+#include "utils.hpp"
#include "jit_generator.hpp"
+#include "jit_avx512_common_1x1_convolution.hpp"
+
namespace mkldnn {
namespace impl {
namespace cpu {
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
#define data_blk_off(f, n, c, h, w) \
? (f).blk_off(n, c, w) \
: (f).blk_off(n, c, h, w))
+
namespace {
template <typename T, typename U>
void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
T nx, T &nx_start, T &nx_end, T nx_divider)
{
- const T grp_size = utils::div_up(nthr, nx_divider);
- const T grp_count = utils::div_up(nthr, grp_size);
-
- T grp = ithr / grp_size;
- T grp_ithr = ithr % grp_size;
- T grp_nthr = grp_size;
- T first_grps = nthr % grp_count;
- if (first_grps > 0 && grp >= first_grps) {
- ithr -= first_grps * grp_size;
- grp_nthr--;
- grp = ithr / grp_nthr + first_grps;
- grp_ithr = ithr % grp_nthr;
+ const int grp_count = nstl::min(nx_divider, nthr);
+ const int grp_size_big = nthr / grp_count + 1;
+ const int grp_size_small = nthr / grp_count;
+ const int n_grp_big = nthr % grp_count;
+ const int threads_in_big_groups = n_grp_big * grp_size_big;
+
+ const int ithr_bound_distance = ithr - threads_in_big_groups;
+ T grp, grp_ithr, grp_nthr;
+ if (ithr_bound_distance < 0) { // ithr in first groups
+ grp = ithr / grp_size_big;
+ grp_ithr = ithr % grp_size_big;
+ grp_nthr = grp_size_big;
+ } else { // ithr in last groups
+ grp = n_grp_big + ithr_bound_distance / grp_size_small;
+ grp_ithr = ithr_bound_distance % grp_size_small;
+ grp_nthr = grp_size_small;
}
+
balance211(nx, grp_count, grp, nx_start, nx_end);
balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
}
}
/* convolution forward */
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
- data_type_t dst_type>
-void _jit_avx512_common_1x1_convolution_fwd_t
- <with_relu, src_type, wei_type, dst_type>::execute_forward()
-{
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
+execute_forward() const {
auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
auto weights =
reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<dst_data_t *>(this->memory());
+ auto scratchpad = this->scratchpad();
+
auto &jcp = kernel_->jcp;
- if (conf_.want_padded_bias()) {
- assert(jcp.ngroups == 1);
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- padded_bias_[oc] = bias[oc];
- bias = padded_bias_;
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad.template get<dst_data_t>(
+ key_conv_padded_bias);
+ utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bias = padded_bias;
}
parallel(0, [&](const int ithr, const int nthr) {
- execute_forward_thr(ithr, nthr, src, weights, bias, dst);
+ execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
});
+
+ if (pd()->wants_zero_pad_dst())
+ output_memory_primitive(0)->zero_pad();
}
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
- data_type_t dst_type>
-void _jit_avx512_common_1x1_convolution_fwd_t
- <with_relu, src_type, wei_type, dst_type>::execute_forward_thr(
- const int ithr, const int nthr,
- const src_data_t *src, const wei_data_t *weights,
- const dst_data_t *bias, dst_data_t *dst)
-{
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
+execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
+ const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+
+ auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
const int ndims = src_d.ndims();
- const int stride_h = (ndims == 3) ? 1 : conf_.cdesc()->strides[0];
- const int stride_w = conf_.cdesc()->strides[ndims - 3];
- const int pad_t = (ndims == 3) ? 0 : conf_.cdesc()->padding[0][0];
- const int pad_l = conf_.cdesc()->padding[0][ndims - 3];
+ const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+ const int stride_w = pd()->desc()->strides[ndims - 3];
+ const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+ const int pad_l = pd()->desc()->padding[0][ndims - 3];
- auto &jcp = kernel_->jcp;
- const int MB = conf_.MB();
+ const auto &jcp = kernel_->jcp;
+ const int MB = pd()->MB();
const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
auto step = [](int default_step, int remaining, int tail_step) {
p.output_data = &dst[dst_off];
p.bias_data = &bias[_ocb * jcp.oc_block];
- p.load_data = &weights[conf_.with_groups()
+ p.load_data = &weights[pd()->with_groups()
? weights_d.blk_off(g, ocb, icb)
: weights_d.blk_off(ocb, icb)];
const int _icb = g * nb_ic + icb;
- if (conf_.rtus_.reduce_src_) {
- rp.ws = scratch_ + ithr * ws_per_thread_
+ if (pd()->rtus_.reduce_src_) {
+ rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
+ _icb * jcp.is * jcp.ic_block;
if (ocb == ocb_start) {
rp.src = src + data_blk_off(src_d, n, _icb, ih, iw);
}
-template struct _jit_avx512_common_1x1_convolution_fwd_t<true, data_type::f32>;
-template struct _jit_avx512_common_1x1_convolution_fwd_t<false, data_type::f32>;
-template struct _jit_avx512_common_1x1_convolution_fwd_t<false, data_type::s16,
- data_type::s16, data_type::s32>;
-template struct _jit_avx512_common_1x1_convolution_fwd_t<true, data_type::s16,
+template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
+template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::s16,
data_type::s16, data_type::s32>;
/* convolution backward wtr data */
template <data_type_t diff_dst_type, data_type_t wei_type,
- data_type_t diff_src_type>
-void _jit_avx512_common_1x1_convolution_bwd_data_t
- <diff_dst_type, wei_type, diff_src_type>::execute_backward_data()
-{
+ data_type_t diff_src_type>
+void jit_avx512_common_1x1_convolution_bwd_data_t<diff_dst_type, wei_type,
+ diff_src_type>::execute_backward_data() const {
auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>
(this->input_memory(1));
auto diff_src = reinterpret_cast<diff_src_data_t *>(this->memory());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+
+ auto rtus_space = scratchpad().template get<diff_src_data_t>(
+ key_conv_rtus_space);
const int ndims = diff_src_d.ndims();
const auto &jcp = kernel_->jcp;
- const int MB = conf_.MB();
+ const int MB = pd()->MB();
// TODO (Roma): remove this restriction
assert(jcp.stride_w == 1 && jcp.stride_h == 1);
- const int stride_h = (ndims == 3) ? 1 : conf_.desc()->strides[0];
- const int stride_w = conf_.desc()->strides[ndims - 3];
- const int pad_t = (ndims == 3) ? 0 : conf_.desc()->padding[0][0];
- const int pad_l = conf_.desc()->padding[0][ndims - 3];
+ const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+ const int stride_w = pd()->desc()->strides[ndims - 3];
+ const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+ const int pad_l = pd()->desc()->padding[0][ndims - 3];
const int nb_ic = jcp.nb_load;
const int nb_oc = jcp.nb_reduce;
const int _icb = g * nb_ic + icb;
rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
- if (conf_.rtus_.reduce_src_) {
- rp.ws = scratch_ + ithr * ws_per_thread_;
+ if (pd()->rtus_.reduce_src_) {
+ rp.ws = rtus_space
+ + ithr * pd()->rtus_.space_per_thread_;
p.output_data = rp.ws;
} else
p.output_data = rp.src;
size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow);
p.bcast_data = &diff_dst[diff_dst_off];
- p.load_data = &weights[conf_.with_groups()
+ p.load_data = &weights[pd()->with_groups()
? weights_d.blk_off(g, ocb, icb)
: weights_d.blk_off(ocb, icb)];
kernel_->jit_ker(&p);
}
- if (conf_.rtus_.reduce_src_)
+ if (pd()->rtus_.reduce_src_)
rtus_driver_->ker_(&rp);
}
}
});
}
-template struct _jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
-template struct _jit_avx512_common_1x1_convolution_bwd_data_t<data_type::s16,
+template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
+template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::s16,
data_type::s16, data_type::s32>;
/* convolution backward wtr weights */
#define wht_blk_off(d, g, ...) \
- (conf_.with_groups() \
+ (pd()->with_groups() \
? (d).blk_off((g), __VA_ARGS__) \
: (d).blk_off(__VA_ARGS__))
jit_avx512_common_1x1_convolution_bwd_weights_t ::
- jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *pd,
+ jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs)
- , conf_(*pd), kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
- , trans_kernel_(nullptr), rtus_driver_(nullptr), ws_per_thread_(0)
- , scratch_(nullptr), padded_bias_(nullptr), bctx_(nullptr)
- , tr_src_(nullptr), ws_reduction_(nullptr)
+ : cpu_primitive_t(apd, inputs, outputs)
+ , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
+ , trans_kernel_(nullptr), rtus_driver_(nullptr)
{
- kernel_ = new jit_avx512_common_1x1_conv_kernel(conf_.jcp_, *conf_.attr());
-
- const auto &jcp = kernel_->jcp;
-
- const int wei_size = jcp.ngroups * jcp.oc * jcp.ic;
- ws_reduction_ =
- (data_t *)malloc((jcp.nthr_mb - 1) * wei_size * sizeof(data_t), 64);
+ kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr());
acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
+ reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
+ init_rtus_driver<avx512_common>(this);
- if (conf_.with_bias()) {
- const size_t max_buffer_size = jcp.nthr * 3 * 5 * 5 * 16 * 16;
- reducer_bias_ = new cpu_reducer_t<data_type::f32>(
- reduce_balancer_t(jcp.nthr, jcp.oc_block,
- jcp.ngroups * jcp.nb_load, jcp.mb, max_buffer_size));
-
- if (conf_.want_padded_bias()) {
- assert(jcp.ngroups == 1);
- padded_bias_ = (data_t *)malloc(sizeof(data_t) * jcp.oc, 64);
- }
- }
+ const auto &jcp = kernel_->jcp;
if (jcp.transpose_src) {
- const ptrdiff_t tr_src_size = (ptrdiff_t)jcp.nthr_mb
- * (ptrdiff_t)jcp.ngroups * (ptrdiff_t)jcp.ic * jcp.tr_is;
- tr_src_ = (data_t *)malloc(tr_src_size * sizeof(data_t), 64);
- parallel_nd(tr_src_size, [&](ptrdiff_t i) { tr_src_[i] = 0; });
auto tp = jit_transpose4x16_src_t();
tp.src_pf0_distance = 4;
tp.tr_src_pf0_distance = 0;
tp.src_pf1 = true;
tp.tr_src_pf1 = false;
trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp);
-
- bctx_ = (simple_barrier::ctx_t *)malloc(
- jcp.nthr * sizeof(simple_barrier::ctx_t), 64);
- for (int i = 0; i < jcp.nthr; ++i)
- simple_barrier::ctx_init(&bctx_[i]);
}
-
- init_rtus_driver<avx512_common>(this);
}
-void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
+void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights() const
{
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));
- data_t *diff_bias = conf_.want_padded_bias() ? padded_bias_ : diff_bias_in;
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
const auto &jcp = kernel_->jcp;
+
+ const auto scratchpad = this->scratchpad();
+
+ auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
+ data_t *diff_bias = pd()->wants_padded_bias()
+ ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
+ auto wei_reduction = scratchpad.get<data_t>(key_conv_wei_reduction);
+
+ /* prepare src transposition barriers */
+ auto tr_src = scratchpad.get<data_t>(key_conv_tr_src);
+ auto tr_src_bctx = scratchpad.get<simple_barrier::ctx_t>(
+ key_conv_tr_src_bctx);
+ if (jcp.transpose_src) {
+ for (int i = 0; i < jcp.nthr; ++i)
+ simple_barrier::ctx_init(&tr_src_bctx[i]);
+ }
+
const int ndims = src_d.ndims();
const int wei_size = jcp.ngroups * jcp.oc * jcp.ic;
simple_barrier::ctx_t reduction_barrier;
simple_barrier::ctx_init(&reduction_barrier);
+ const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+ prefix_reducer_bia);
+ auto rb = this->reducer_bias_;
+ rb->init(reducer_bia_scratchpad);
+
// TODO (Roma): remove this restriction
assert(jcp.stride_w == 1 && jcp.stride_h == 1);
const int sp_nb = jcp.nb_reduce;
const int mb_sp_work = jcp.mb * sp_nb;
- const int stride_h = (ndims == 3) ? 1 : conf_.desc()->strides[0];
- const int stride_w = conf_.desc()->strides[ndims - 3];
- const int pad_t = (ndims == 3) ? 0 : conf_.desc()->padding[0][0];
- const int pad_l = conf_.desc()->padding[0][ndims - 3];
+ const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+ const int stride_w = pd()->desc()->strides[ndims - 3];
+ const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+ const int pad_l = pd()->desc()->padding[0][ndims - 3];
auto step = [](int default_step, int remaining, int tail_step) {
assert(default_step <= tail_step);
const int src1_off = data_blk_off(src_d, img, _ic, ih, iw);
data_t *src1 = (data_t *)&src[src1_off];
- data_t *tr_src1 = &tr_src_[tr_src_off(ithr_mb, ic_b_tr, is)];
+ data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)];
assert(jcp.ic_block == 16);
const int src_stride = jcp.is * jcp.ic_block;
const int oc_b_work = oc_b_end - oc_b_start;
const int ic_b_work = ic_b_end - ic_b_start;
- data_t *diff_wei = ithr_mb == 0 ?
- diff_weights :
- ws_reduction_ + (ithr_mb - 1) * wei_size;
+ data_t *diff_wei = ithr_mb == 0
+ ? diff_weights : wei_reduction + (ithr_mb - 1) * wei_size;
int sp_b_step = 0;
for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end;
if (jcp.transpose_src) {
if (jcp.nthr_oc_b > 1)
simple_barrier::barrier(
- &bctx_[ithr_but_oc], jcp.nthr_oc_b);
+ &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
const int sp_size
= nstl::min(sp_b_step * jcp.reduce_block,
jcp.is - sp_b * jcp.reduce_block);
bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start);
if (jcp.nthr_oc_b > 1)
simple_barrier::barrier(
- &bctx_[ithr_but_oc], jcp.nthr_oc_b);
+ &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
}
for (int oc_b = oc_b_start; oc_b < oc_b_end;
store_to = diff_wei + off;
const data_t *diff_src = jcp.transpose_src ?
- &tr_src_[tr_src_off(ithr_mb, _ic_b_tr, 0)] :
+ &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] :
&src[src_d.blk_off(img, _ic_b)];
int sp_b_end = sp_b + sp_b_step;
int sp = sp_b * jcp.reduce_block;
p.load_data = pdiff_dst + sp * jcp.oc_block;
- if (conf_.rtus_.reduce_src_) {
+ if (pd()->rtus_.reduce_src_) {
const int oh = sp / jcp.ow;
const int ow = sp % jcp.ow;
const int iw = nstl::max(ow * stride_w - pad_l, 0);
rp.iw_start = iw;
- rp.ws = scratch_ + ithr * ws_per_thread_
- + sp * jcp.ic_block;
+ rp.ws = rtus_space
+ + ithr * pd()->rtus_.space_per_thread_
+ + sp * jcp.ic_block;
if (ndims == 3)
rp.src = local_src + iw
}
}
- /* diff_weights[:] += sum(ws_reduction_[thr_mb][:]) */
+ /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */
if (jcp.nthr_mb > 1) {
simple_barrier::barrier(&reduction_barrier, jcp.nthr);
const int work = g_work * oc_b_work * ic_b_work;
const size_t off
= wht_blk_off(diff_weights_d, g, oc_b, ic_b);
data_t *d = diff_weights + off;
- data_t *s = ws_reduction_ + (thr_mb - 1) * wei_size + off;
+ data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off;
acc_ker_->accumulate(d, s, acc_size);
};
auto ker_bias = [&](int ithr, int nthr) {
- auto rb = this->reducer_bias_;
- assert(nthr == rb->balancer_.nthr_);
+ assert(nthr == rb->balancer().nthr_);
- const int b_job_start = rb->balancer_.ithr_job_off(ithr);
- const int b_njobs = rb->balancer_.ithr_njobs(ithr);
+ const int b_job_start = rb->balancer().ithr_job_off(ithr);
+ const int b_njobs = rb->balancer().ithr_njobs(ithr);
if (b_njobs == 0)
return;
/* reduction dimension */
int img_start{ 0 }, img_end{ 0 };
- balance211(jcp.mb, rb->balancer_.nthr_per_group_,
- rb->balancer_.id_in_group(ithr), img_start, img_end);
+ balance211(jcp.mb, rb->balancer().nthr_per_group_,
+ rb->balancer().id_in_group(ithr), img_start, img_end);
/* jobs */
int g_start{ 0 }, ocb_start{ 0 };
const size_t _oc = g * jcp.nb_load + ocb;
const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
- data_t *d_bias = &rb->get_local_ptr(
- ithr, diff_bias)[b_job_loc * rb->balancer_.job_size_];
+ data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
+ reducer_bia_scratchpad)
+ + b_job_loc * rb->balancer().job_size_;
if (img == img_start)
for (int o = 0; o < 16; ++o)
nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load);
}
}
- rb->reduce(ithr, diff_bias);
+ rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
};
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
ker(ithr, jcp.nthr);
- if (conf_.with_bias())
+ if (pd()->with_bias())
ker_bias(ithr, jcp.nthr);
});
/* TODO: put this in ker_bias */
- if (conf_.want_padded_bias()) {
+ if (pd()->wants_padded_bias()) {
assert(jcp.ngroups == 1);
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- diff_bias_in[oc] = diff_bias[oc];
+ utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding);
}
}