using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
for (int jj = 0; jj < ur_w; jj++) {
int o_off;
if (jcp.with_dw_conv)
- o_off = (ii * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
+ o_off = (ii * jcp_dw.kh * ow + jj) * oc_blk;
else
o_off = (ii * oh * ow + jj) * oc_blk;
Label skip_kh_loop;
mov(kj, reg_kh);
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
+ if ((jcp.dilate_h >= jcp.ih)
+ || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
cmp(kj, 0);
je(skip_kh_loop, T_NEAR);
}
int depthwise_inj_idx = 0;
const auto &p = attr_.post_ops_;
- if (p.len_ == 0 && eltwise_injectors.size() == 1) {
- eltwise_injectors[0]->compute_vector_range(1, oc_blocks * ur_w + 1);
- }
-
int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
for (int i = 0; i < end_idx; i++) {
auto& post_op = p.entry_[i];
for (int jj = 0; jj < ur_w; jj++) {
int o_off;
if (jcp.with_dw_conv)
- o_off = (ii * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
+ o_off = (ii * jcp_dw.kh * ow + jj) * oc_blk;
else
o_off = (ii * oh * ow + jj) * oc_blk;
}
}
- L(done);
-
mov(aux_reg_kernel, reg_kernel);
mov(aux_reg_input, reg_input);
add(aux_reg_kernel, sizeof(float) * 4);
void jit_sse42_conv_fwd_kernel_f32::generate()
{
- if (jcp.with_eltwise) {
- eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
- this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
- ));
- }
-
const auto &p = attr_.post_ops_;
int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
for (int i = 0; i < end_idx; i++) {
auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
switch (p.len_) {
- case 0: return true; // no post_ops
- case 1:
- return true // sum OR eltwise OR dw_conv
- && !jcp.with_eltwise && (is_simple(0) || is_sum(0) || is_dw_conv(0));
- case 2:
- return true // sum->eltwise OR dw_conv->eltwise OR eltwise->dw_conv OR dw_conv->sum OR sum->depthwise OR
- // eltwise->depthwise OR depthwise->depthwise
- && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
- (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
- (is_simple(0) && is_simple(1)));
- case 3:
- return true // eltwise->dw_conv->eltwise OR dw_conv->sum->eltwise OR sum->eltwise->depthwise OR
- // sum->depthwise->eltwise OR sum->depthwise->depthwise
- && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
- (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
- (is_sum(0) && is_simple(1) && is_simple(2)));
- case 4: return true // eltwise->dw_conv->sum->eltwise
- && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
+ case 0: return true;
+ case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
+ case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
+ (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
+ (is_simple(0) && is_simple(1));
+ case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
+ (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
+ (is_sum(0) && is_simple(1) && is_simple(2));
+ case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
default: return false;
}
status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+ const primitive_attr_t &attr) // the relu flag/slope parameters are dropped; attr.post_ops_ is what gets inspected below
{
if (!mayiuse(sse42)) return status::unimplemented;
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alg = mkldnn_eltwise_relu;
- jcp.eltwise_alpha = relu_negative_slope;
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
const auto &p = attr.post_ops_;
- jcp.with_dw_conv = false;
- int dw_conv_ind = p.find(primitive_kind::convolution);
- if (dw_conv_ind != -1) {
- jcp.with_dw_conv = true;
- jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
- jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
- jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
- jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
- jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
- jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
- jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
- jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
- }
+ int dw_conv_ind = p.find(primitive_kind::convolution); // position of the convolution post-op in the list; compared against -1 on the next line
+ jcp.with_dw_conv = dw_conv_ind != -1; // flag is set exactly when a convolution post-op was found
if (jcp.with_dw_conv) {
- int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
- if (dw_conv_eltwise_ind != -1) {
- jcp.dw_conv_with_eltwise = true;
- jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
- jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
- jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;
- }
+ jcp.dw_conv_oh = jcp.oh; // save the current oh before it is overwritten two lines below
+ jcp.dw_conv_ow = jcp.ow; // save the current ow before it is overwritten two lines below
+ jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h; // oh now takes the in_h stored in the dw-conv post-op entry
+ jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w; // ow now takes the in_w stored in the dw-conv post-op entry
}
jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
- if (jcp.with_dw_conv) {
- jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
- }
- if (jcp.with_dw_conv) {
- jcp.oh = jcp.dw_conv_in_h;
- jcp.ow = jcp.dw_conv_in_w;
- }
+ jcp.src_dt = cd.src_desc.data_type; // source data type copied from the convolution descriptor
+ jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; // bias data type, or undef when with_bias is false
+ jcp.dst_dt = cd.dst_desc.data_type; // destination data type copied from the convolution descriptor
const bool flat = jcp.ic == 3 || jcp.ic == 1;
const bool mimo = !flat;
return status::success;
}
+void jit_sse42_conv_fwd_kernel_f32::init_scratchpad( // Books the scratchpad entries implied by the two configurations passed in.
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) { // jcp drives every decision; jcp_dw is read only for kh/iw/ch_block when jcp.with_dw_conv is set
+ if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) // only when a bias exists and oc differs from oc_without_padding
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); // jcp.oc floats; key name suggests a padded copy of the bias — confirm at the use site
+
+ if (jcp.with_dw_conv) { // additional entries for the depthwise-convolution post-op case
+ const int nthreads = mkldnn_get_max_threads(); // the buffer below is multiplied by this thread count
+ size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking; // per-thread element count; leading size_t cast makes the whole product evaluate in size_t
+ scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads); // total bytes = sizeof(float) * per-thread elements * nthreads
+
+ if (jcp.oc != jcp.oc_without_padding) // NOTE(review): unlike the first check, no bias flag is tested here — confirm this is intended
+ scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc); // jcp.oc floats under the dw-conv padded-bias key
+ }
+}
+
}
}
}