* limitations under the License.
*******************************************************************************/
-#include <common/primitive_attr.hpp>
#include "mkldnn_types.h"
#include "c_types_map.hpp"
#include "utils.hpp"
#include "type_helpers.hpp"
#include "mkldnn_thread.hpp"
-
#include "ref_eltwise.hpp"
namespace mkldnn {
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
-template <bool with_relu>
-void _gemm_convolution_fwd_t<with_relu>::execute_forward() {
+void gemm_convolution_fwd_t::execute_forward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<data_t*>(this->memory());
- jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
- const int MB = conf_.MB();
+ auto col = scratchpad().get<data_t>(key_conv_gemm_col);
+
+ const auto &jcp = this->pd()->jcp_;
+ const int MB = pd()->MB();
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
const int M = jcp.os * jcp.od;
const size_t src_step = (src_d.blk_off(1) - src_d.off_l(0)) / jcp.ngroups;
src += src_d.off_l(0);
dst += dst_d.off_l(0);
+ assert(IMPLICATION(
+ jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
+ assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
+
const int K = jcp.ic * jcp.ks;
const int N = jcp.oc;
- const int m = jcp.os;
- const int LDA = jcp.im2col_sz ? m : M;
-
- const data_t one = 1.0;
-
- data_t *col = (jcp.im2col_sz)
- ? (data_t *)this->scratchpad_->get()
- : nullptr;
- parallel_nd(jcp.im2col_sz * jcp.nthr,
- [&](ptrdiff_t i) { col[i] = (data_t)0; });
+ if (jcp.im2col_sz && jcp.id != 1)
+ parallel_nd(jcp.im2col_sz * jcp.nthr,
+ [&](ptrdiff_t i) { col[i] = (data_t)0; });
- const size_t work_amount = jcp.ngroups * MB * jcp.od;
+ const int nb_oh = div_up(jcp.oh, jcp.oh_block);
+ const int nb_ow = div_up(jcp.ow, jcp.ow_block);
+ const size_t work_amount = jcp.ngroups * MB * jcp.od * nb_oh * nb_ow;
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
- int g{0}, n{0}, od{0};
+ int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 };
size_t start = 0, end = 0;
balance211(work_amount, nthr, ithr, start, end);
- nd_iterator_init(start, g, jcp.ngroups, n, MB, od, jcp.od);
-
+ nd_iterator_init(start, g, jcp.ngroups, n, MB, od, jcp.od, ohb,
+ nb_oh, owb, nb_ow);
for (size_t iwork = start; iwork < end; ++iwork) {
+ int oh = ohb * jcp.oh_block;
+ int ow = owb * jcp.ow_block;
const data_t *_src = src + (n * jcp.ngroups + g) * src_step;
const data_t *_weights = weights + g * weights_g_size;
- data_t *_dst = dst + (n * jcp.ngroups + g) * dst_step;
-
+ data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step;
+ const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
+ const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
if (jcp.im2col_sz) {
if (jcp.id == 1)
- jit_gemm_convolution_utils::im2col(jcp, _src, _col);
+ jit_gemm_convolution_utils::im2col(
+ jcp, _src, _col, oh, h_step, ow, w_step);
else
jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od);
}
const data_t one = 1.0;
+
+ const int m = h_step * w_step;
+ const int LDA = jcp.im2col_sz ? m : M;
+ data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow;
+
extended_sgemm("N", "N", &m, &N, &K, &one,
jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K,
- &this->beta_, _dst + od * m, &M);
+ &this->beta_, _dst, &M);
- const auto &p = conf_.attr()->post_ops_;
+ data_t *d = _dst;
+ const auto &p = pd()->attr()->post_ops_;
bool need_bias = jcp.with_bias;
if (use_fast_relu) {
- data_t *d = _dst + od * m;
-
- for (int oc = 0; oc < jcp.oc; ++oc) {
+ parallel_nd(jcp.oc, [&](const int oc) {
data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
+ data_t *d_ = d + oc * M;
+ PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
- d[oS] += b;
- if (d[oS] < 0) d[oS] *= fast_relu_ns;
+ d_[oS] += b;
+ if (d_[oS] < 0) d_[oS] *= fast_relu_ns;
}
- d += M;
- }
+ });
need_bias = false;
} else if (p.len_ > 0) {
int depthwise_inj_idx = 0;
for (int i = 0; i < p.len_; i++) {
- data_t *d = _dst + od * m;
auto& post_op = p.entry_[i];
if (post_op.is_eltwise()) {
- for (int oc = 0; oc < jcp.oc; ++oc) {
+ parallel_nd(jcp.oc, [&](const int oc) {
data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
+ data_t *d_ = d + oc * M;
+ PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
- d[oS] += b;
- d[oS] = eltwise_injectors[eltwise_inj_idx]->compute_scalar(d[oS]);
+ d_[oS] += b;
+ d_[oS] = eltwise_injectors[eltwise_inj_idx]->compute_scalar(d_[oS]);
}
- d += M;
- }
+ });
eltwise_inj_idx++;
need_bias = false;
auto depthwise_weights = post_op.depthwise.weights_data;
auto depthwise_bias = post_op.depthwise.biases_data;
- for (int oc = 0; oc < jcp.oc; ++oc) {
+ parallel_nd(jcp.oc, [&](const int oc) {
data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
+ data_t *d_ = d + oc * M;
+ PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
- d[oS] += b;
- d[oS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d[oS],
+ d_[oS] += b;
+ d_[oS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[oS],
depthwise_weights + g * jcp.oc + oc,
depthwise_bias + g * jcp.oc + oc);
}
- d += M;
- }
+ });
depthwise_inj_idx++;
need_bias = false;
}
if (need_bias) {
- data_t *d = _dst + od * m;
-
- for (int oc = 0; oc < jcp.oc; ++oc) {
+ parallel_nd(jcp.oc, [&](const int oc) {
data_t b = bias[g * jcp.oc + oc];
+ data_t *d_ = d + oc * M;
+ PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
- d[oS] += b;
+ d_[oS] += b;
}
- d += M;
- }
+ });
}
- nd_iterator_step(g, jcp.ngroups, n, MB, od, jcp.od);
+ nd_iterator_step(g, jcp.ngroups, n, MB, od, jcp.od, ohb, nb_oh,
+ owb, nb_ow);
}
});
}
-void gemm_convolution_bwd_data_t::execute_backward_data() {
+void gemm_convolution_bwd_data_t::execute_backward_data() const {
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<data_t*>(this->memory());
- jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
- const int MB = conf_.MB();
+ auto col = scratchpad().get<data_t>(key_conv_gemm_col);
+
+ const auto &jcp = this->pd()->jcp_;
+ const int MB = pd()->MB();
const int M = jcp.os * jcp.od;
- const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
- const size_t dst_step = jcp.oc * M;
+ const size_t src_step_to_clean = jcp.ic * jcp.ih * jcp.iw * jcp.id;
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const size_t src_step = diff_src_d.blk_off(1) / jcp.ngroups;
+ const size_t dst_step = diff_dst_d.blk_off(1) / jcp.ngroups;
const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
const int m = jcp.os;
const int K = jcp.oc;
const int N = jcp.ic * jcp.ks;
const int LDC = jcp.im2col_sz ? m : M;
- data_t *col = jcp.im2col_sz ? (data_t *)this->scratchpad_->get() : nullptr;
const size_t work_amount = (size_t)jcp.ngroups * MB;
if (jcp.id > 1) {
- const ptrdiff_t diff_src_sz = (ptrdiff_t)(work_amount * src_step);
- parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t)0; });
+ for (size_t j = 0; j < work_amount; j++) {
+ int j_step = src_step * j;
+ const ptrdiff_t diff_src_sz = (ptrdiff_t)(src_step_to_clean);
+ parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[j_step + i] = (data_t)0; });
+ }
}
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
nd_iterator_init(start, g, jcp.ngroups, n, MB);
for (size_t iwork = start; iwork < end; ++iwork) {
- data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step;
+ data_t *_diff_src = diff_src + (n * jcp.ngroups + g) * src_step;
const data_t *_weights = weights + g * weights_g_size;
for (int od = 0; od < jcp.od; ++od) {
const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g)
});
}
-void gemm_convolution_bwd_weights_t::execute_backward_weights() {
+void gemm_convolution_bwd_weights_t::execute_backward_weights() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_weights = reinterpret_cast<data_t*>(this->memory(0));
auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
- jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
+ auto col = scratchpad().get<data_t>(key_conv_gemm_col);
+ auto wei_reduction = scratchpad().get<data_t>(key_conv_wei_reduction);
+
+ const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
const int K = jcp.os * jcp.od;
const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
const size_t dst_step = jcp.oc * K;
const int M = jcp.ic * jcp.ks;
const int LDA = jcp.im2col_sz ? k : K;
- data_t *col = nullptr, *wei_reduction = nullptr;
- ptrdiff_t wei_offset = 0;
- if (jcp.im2col_sz) {
- col = (data_t *)this->scratchpad_->get();
- wei_offset = jcp.im2col_sz * jcp.nthr;
- }
- if (jcp.need_wei_reduction)
- wei_reduction = (data_t *)this->scratchpad_->get() + wei_offset;
-
parallel_nd(jcp.im2col_sz * jcp.nthr,
[&](ptrdiff_t i) { col[i] = (data_t)0; });
if (jcp.im2col_sz) {
if (jcp.id == 1)
- jit_gemm_convolution_utils::im2col(jcp, _src, _col);
+ jit_gemm_convolution_utils::im2col(
+ jcp, _src, _col, 0, jcp.oh, 0, jcp.ow);
else
jit_gemm_convolution_utils::im2col_3d(jcp, _src,
_col, od);
}
}
diff_bias[g*jcp.oc+oc] = db;
- nd_iterator_step(g, jcp.ngroups, oc, jcp.oc);
});
}
}
-template struct _gemm_convolution_fwd_t<true>;
-template struct _gemm_convolution_fwd_t<false>;
}
}
}