* limitations under the License.
*******************************************************************************/
-#include <cstring>
-#include "mkldnn_types.h"
-
#include "c_types_map.hpp"
-#include "jit_avx2_convolution.hpp"
-#include "utils.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
+#include "utils.hpp"
+#include <cstring>
+
+#include "jit_avx2_convolution.hpp"
namespace mkldnn {
namespace impl {
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
-
#define src_blk_off(f, n, c, d, h, w) \
- (conf_.ndims() == 3) \
+ (pd()->ndims() == 3) \
? (f).blk_off(n, c, w) \
- : (conf_.ndims() == 4) \
+ : (pd()->ndims() == 4) \
? (f).blk_off(n, c, h, w) \
: (f).blk_off(n, c, d, h, w)
#define wht_blk_off_(f, g, ...) \
- conf_.with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
+ pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
- (conf_.ndims() == 3) \
+ (pd()->ndims() == 3) \
? wht_blk_off_(f, g, oc, ic, kw) \
- : (conf_.ndims() == 4) \
+ : (pd()->ndims() == 4) \
? wht_blk_off_(f, g, oc, ic, kh, kw) \
: wht_blk_off_(f, g, oc, ic, kd, kh, kw)
-template <bool with_relu>
-void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward() {
+void jit_avx2_convolution_fwd_t::execute_forward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<data_t *>(this->memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper bias_d(pd()->weights_pd(1));
const auto &jcp = kernel_->jcp;
- const int MB = conf_.MB();
+ const int MB = pd()->MB();
int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.od
int ocb_num = jcp.nb_oc_blocking;
for (int icb = icbb; icb < icbb + icb_step; ++icb) {
- jit_conv_call_s par_conv = {};
+ auto par_conv = jit_conv_call_s();
const int ij = oh * jcp.stride_h;
const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
+ (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id;
const size_t _oc = g * jcp.nb_oc + ocb;
- const size_t _ic = g * jcp.nb_ic + icb;
+ const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb;
const int ih = nstl::max(ij - jcp.t_pad
+ div_up(i_t_overflow,
}
};
- if (conf_.want_padded_bias()) {
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- padded_bias_[oc] = bias[oc];
- bias = padded_bias_;
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
+ utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bias = padded_bias;
}
parallel(0, ker);
+
+ if (pd()->wants_zero_pad_dst())
+ output_memory_primitive(0)->zero_pad();
}
-template <bool with_relu>
-void _jit_avx2_convolution_fwd_t<with_relu>::execute_forward_fusing() {
+void jit_avx2_convolution_fwd_t::execute_forward_with_dw_conv() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<data_t *>(this->memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper bias_d(pd()->weights_pd(1));
const auto &jcp = kernel_->jcp;
const auto &jcp_dw = kernel_dw_->jcp;
- const int MB = conf_.MB();
+ const int MB = pd()->MB();
- auto dw_bias = jcp.dw_conv_biases;
+ auto dw_bias = jcp_dw.conv_biases;
int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
for (int h = 0; h < num_rows; h++) {
if ((oh + h) < 0 || (oh + h) >= jcp.oh) {
for (int chb = ocb; chb < ocb + ocb_num; chb++) {
- memset(ws_p + (((oh + h) + 1) % jcp.dw_conv_ker_h) * jcp.ow * jcp.oc_block +
- (chb - ocb) * jcp.dw_conv_ker_h * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
+ memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
+ (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
}
} else {
for (int icb = 0; icb < jcp.nb_ic; ++icb) {
par_conv.src = &src[src_d.blk_off(n,
jcp.ic == 3 ? 0 : _ic, ih, 0)];
- par_conv.dst = &ws_p[(((oh + h) + 1) % jcp.dw_conv_ker_h) * jcp.ow *
+ par_conv.dst = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow *
jcp.oc_block];
const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
- par_conv.filt = &weights[conf_.with_groups()
+ par_conv.filt = &weights[pd()->with_groups()
? weights_d.blk_off(g, ocb,
jcp.ic == 3 ? 0 : icb, wh, 0)
: weights_d.blk_off(ocb,
dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];
par_conv_dw.kh_padding = jcp_dw.kh;
- par_conv_dw.filt = &jcp.dw_conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
+ par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
+ par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
+ par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
kernel_dw_->jit_ker(&par_conv_dw);
}
size_t start{0}, end{0};
balance211(work_amount, nthr, ithr, start, end);
- auto pbuf = dw_conv_buffer_ + ithr * dw_conv_buffer_size_;
+ auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
+ size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
+ auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
size_t n{0}, g{0}, ocbb{0}, oh{0};
nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work,
}
};
- if (conf_.want_padded_bias()) {
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- padded_bias_[oc] = bias[oc];
- bias = padded_bias_;
-
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- dw_padded_bias_[oc] = dw_bias[oc];
- dw_bias = dw_padded_bias_;
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
+ utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bias = padded_bias;
+
+ auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
+ utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
+ utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ dw_bias = dw_padded_bias;
}
parallel(0, ker);
-}
-template void _jit_avx2_convolution_fwd_t<true>::execute_forward();
-template void _jit_avx2_convolution_fwd_t<false>::execute_forward();
-template void _jit_avx2_convolution_fwd_t<true>::execute_forward_fusing();
-template void _jit_avx2_convolution_fwd_t<false>::execute_forward_fusing();
+ if (pd()->wants_zero_pad_dst())
+ output_memory_primitive(0)->zero_pad();
+}
-void jit_avx2_convolution_bwd_data_t::execute_backward_data() {
+void jit_avx2_convolution_bwd_data_t::execute_backward_data() const {
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<data_t *>(this->memory());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
const auto &jcp = kernel_->jcp;
- const int MB = conf_.MB();
+ const int MB = pd()->MB();
int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
- const size_t work_amount = MB * jcp.ngroups * icb_work * jcp.ih;
+ int ih_block_size = jcp.ih;
+ int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
+ size_t work_amount = MB * jcp.ngroups * icb_work * num_ih_blocks;
+ if (work_amount < (size_t)2 * mkldnn_get_max_threads()) {
+ ih_block_size = 1;
+ num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
+ work_amount *= num_ih_blocks;
+ }
auto ker = [&](const int ithr, const int nthr) {
size_t start{0}, end{0};
balance211(work_amount, nthr, ithr, start, end);
- size_t n{0}, g{0}, icbb{0}, ih{0};
- nd_iterator_init(start, n, MB, g, jcp.ngroups, icbb, icb_work, ih, jcp.ih);
+ size_t n{0}, g{0}, icbb{0}, ihb{0};
+ nd_iterator_init(start, n, MB, g, jcp.ngroups, icbb, icb_work,
+ ihb, num_ih_blocks);
+
for (size_t iwork = start; iwork < end; ++iwork) {
- for (int oc = 0; oc < jcp.nb_oc; ++oc)
+ for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
for (int id = 0; id < jcp.id; ++id) {
auto par_conv = jit_conv_call_s();
const int idp = jcp.id + 2 * jcp.f_pad;
const int d_t_overflow = nstl::max(0,
- jcp.kd - 1 - id - jcp.f_pad);
+ jcp.kd - 1 - id - jcp.f_pad);
const int back_pad = idp - jcp.id - jcp.f_pad;
const int d_b_overflow = nstl::max(0,
- jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
+ jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
const int od = id + jcp.f_pad - d_b_overflow;
- const int simd_w = 8;
-
- const int i_t_overflow = nstl::max(0,
- jcp.kh - 1 - (int)ih - jcp.t_pad);
- const int b_pad = jcp.ihp - jcp.ih - jcp.t_pad;
- const int i_b_overflow = nstl::max(0,
- jcp.kh - 1 - (jcp.ih - 1 - (int)ih) - b_pad);
- int oh = ih + jcp.t_pad - i_b_overflow;
-
- int stride_off_h = oh % jcp.stride_h;
- oh /= jcp.stride_h;
-
- par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
- /*jcp.ic == 3 ? 0 :*/
- g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
- par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
- n, g * jcp.nb_oc + oc, od, oh, 0)];
- par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
- jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
- d_b_overflow, i_b_overflow + stride_off_h, 0)];
-
- par_conv.src_prf = nullptr;
- par_conv.dst_prf = nullptr;
- par_conv.filt_prf = nullptr;
- // TODO: move initialization into the kernel
- if (oc == 0) {
- for (int iw = 0; iw < jcp.iw; iw++) {
- for (int b = 0; b < jcp.nb_ic_blocking; b++) {
- int current_ic =
- (jcp.ic == 3 ? 0 : g * jcp.nb_ic)
- + jcp.nb_ic_blocking * icbb + b;
- int current_idx =
- src_blk_off(diff_src_d, n, current_ic,
- id, ih, iw);
- for (int v = 0; v < simd_w; v++)
- diff_src[current_idx + v] = 0.0;
- }
- }
- }
+ int ih_start = ihb * ih_block_size;
+ int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+
+ const int i_t_overflow = nstl::max(0, (jcp.kh - 1
+ - ih - jcp.t_pad) / jcp.stride_h);
+ const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih
+ + ih - jcp.b_pad) / jcp.stride_h);
+ int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
+ + jcp.b_pad - ih) % jcp.stride_h);
+ int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;
+
+ par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
+ par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
+ / jcp.stride_h + 1 - i_t_overflow - i_b_overflow;
+ par_conv.kw_padding = 0;
- par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
- par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow - stride_off_h);
- par_conv.kw_padding = 0;
+ const int k_lo = overflow_kh_lo
+ + i_b_overflow * jcp.stride_h;
+ const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;
+
+ par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
+ /*jcp.ic == 3 ? 0 :*/
+ g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
+ par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
+ n, g * jcp.nb_oc + oc, od, oh, 0)];
+ par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
+ jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
+ d_b_overflow, k_lo, 0)];
+
+ par_conv.src_prf = nullptr;
+ par_conv.dst_prf = nullptr;
+ par_conv.filt_prf = nullptr;
+ par_conv.channel = oc;
+ par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc,
+ jcp.nb_oc_blocking);
- if (par_conv.kh_padding > 0)
kernel_->jit_ker(&par_conv);
+ }
}
- nd_iterator_step(n, MB, g, jcp.ngroups, icbb, icb_work, ih, jcp.ih);
+ nd_iterator_step(n, MB, g, jcp.ngroups, icbb, icb_work, ihb,
+ num_ih_blocks);
}
};
parallel(0, ker);
}
-void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() {
+void jit_avx2_convolution_bwd_weights_t::execute_backward_weights() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
auto diff_bias_in = reinterpret_cast<data_t *>(this->memory(1));
- data_t *diff_bias = conf_.want_padded_bias() ? padded_bias_ : diff_bias_in;
- const memory_desc_wrapper src_d(conf_.src_pd(0));
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
+ auto scratchpad = this->scratchpad();
+
+ data_t *diff_bias = pd()->wants_padded_bias()
+ ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
+
+ const memory_desc_wrapper src_d(pd()->src_pd(0));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
const auto &jcp = kernel_->jcp;
+ auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+ prefix_reducer_bia);
+ auto rb = this->reducer_bias_;
+ rb->init(reducer_bia_scratchpad);
+
+ auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
+ prefix_reducer_wei);
+ auto rw = this->reducer_weights_;
+ rw->init(reducer_wei_scratchpad);
+
auto ker = [&](int ithr, int nthr) {
- auto rw = this->reducer_weights_;
- assert(nthr == rw->balancer_.nthr_);
+ assert(nthr == rw->balancer().nthr_);
- const int w_job_start = rw->balancer_.ithr_job_off(ithr);
- const int w_njobs = rw->balancer_.ithr_njobs(ithr);
+ const int w_job_start = rw->balancer().ithr_job_off(ithr);
+ const int w_njobs = rw->balancer().ithr_njobs(ithr);
if (w_njobs == 0) return;
/* reduction dimension */
int img_od_start{0}, img_od_end{0}, img{0}, od_s{0};
- balance211(jcp.mb * jcp.od, rw->balancer_.nthr_per_group_,
- rw->balancer_.id_in_group(ithr), img_od_start, img_od_end);
+ balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
+ rw->balancer().id_in_group(ithr), img_od_start, img_od_end);
int img_start = img_od_start, img_end = img_od_end;
nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
/* TODO: put dw <-- 0 in kernel */
if (img == img_first)
- array_set((data_t *)&rw->get_local_ptr(ithr, diff_weights)[
- w_job_loc * rw->balancer_.job_size_], 0,
- rw->balancer_.job_size_);
+ array_set(rw->get_local_ptr(ithr, diff_weights,
+ reducer_wei_scratchpad) +
+ w_job_loc * rw->balancer().job_size_, 0,
+ rw->balancer().job_size_);
for (int od = od_s; od < od_e; ++od) {
const int id = od * jcp.stride_d;
par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
par_conv.dst =
&diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)];
- par_conv.filt = &rw->get_local_ptr(ithr, diff_weights)[
- w_job_loc * rw->balancer_.job_size_];
+ par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
+ reducer_wei_scratchpad) +
+ w_job_loc * rw->balancer().job_size_;
kernel_->jit_ker(&par_conv);
}
}
nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
}
- rw->reduce(ithr, diff_weights);
+ rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
};
auto ker_bias = [&](int ithr, int nthr) {
- auto rb = this->reducer_bias_;
- assert(nthr == rb->balancer_.nthr_);
+ assert(nthr == rb->balancer().nthr_);
- const int b_job_start = rb->balancer_.ithr_job_off(ithr);
- const int b_njobs = rb->balancer_.ithr_njobs(ithr);
+ const int b_job_start = rb->balancer().ithr_job_off(ithr);
+ const int b_njobs = rb->balancer().ithr_njobs(ithr);
if (b_njobs == 0) return;
/* reduction dimension */
int img_start{0}, img_end{0};
- balance211(jcp.mb, rb->balancer_.nthr_per_group_,
- rb->balancer_.id_in_group(ithr), img_start, img_end);
+ balance211(jcp.mb, rb->balancer().nthr_per_group_,
+ rb->balancer().id_in_group(ithr), img_start, img_end);
/* jobs */
int g_start{0}, ocb_start{0};
const size_t _oc = g * jcp.nb_oc + ocb;
const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
- data_t *d_bias = &rb->get_local_ptr(ithr, diff_bias)[
- b_job_loc * rb->balancer_.job_size_];
+ data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
+ reducer_bia_scratchpad) +
+ b_job_loc * rb->balancer().job_size_;
if (img == img_start)
for (int o = 0; o < 8; ++o)
nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
}
}
- rb->reduce(ithr, diff_bias);
+ rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
};
-
parallel(0, [&](const int ithr, const int nthr) {
ker(ithr, nthr);
- if (conf_.with_bias())
+ if (pd()->with_bias())
ker_bias(ithr, nthr);
});
/* TODO: put this in ker_bias */
- if (conf_.want_padded_bias()) {
+ if (pd()->wants_padded_bias()) {
assert(jcp.ngroups == 1);
for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
diff_bias_in[oc] = diff_bias[oc];