* limitations under the License.
*******************************************************************************/
+#include <assert.h>
+
#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
+
#include "cpu_memory.hpp"
#include "jit_avx2_1x1_conv_kernel_f32.hpp"
default:
if (jcp.with_dw_conv) {
return ptr[aux_reg_output_data +
- (i * jcp.dw_conv_ker_h * jcp.ow + j) * jcp.oc_block * sizeof(float)];
+ (i * jcp_dw.kh * jcp.ow + j) * jcp.oc_block * sizeof(float)];
} else {
return ptr[aux_reg_output_data +
(i * jcp.os + j) * jcp.oc_block * sizeof(float)];
};
auto store = [=]() {
- Label store_done, store_noadd;
+ Label store_noadd;
if (!jcp.with_sum) {
test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
int eltwise_inj_idx = 0;
int depthwise_inj_idx = 0;
const auto &p = attr_.post_ops_;
- if (p.len_ == 0 && eltwise_injectors.size() == 1) {
- eltwise_injectors[0]->compute_vector_range(0, ur * load_loop_blk);
- }
int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
for (int i = 0; i < end_idx; i++) {
for (int i = 0; i < load_loop_blk; ++i) {
vmovups(output_ptr(i, j), vreg_accum(i, j));
}
-
- L(store_done);
};
auto fma_block = [=](bool last_block) {
if (mayiuse(avx2))
vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast);
else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
- auto tmp = vmask;
- vmulps(tmp, vreg_bcast, vreg_load(i));
- vaddps(vreg_accum(i, j), vreg_accum(i, j), tmp);
+ vmulps(vtmp, vreg_bcast, vreg_load(i));
+ vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp);
}
if (j == ur - 1 && !(last_block
&& u == jcp.reduce_loop_unroll - 1))
void jit_avx2_1x1_conv_kernel_f32::generate()
{
- if (jcp.with_eltwise) {
- eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx2>(
- this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
- ));
- }
-
const auto &p = attr_.post_ops_;
int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
for (int i = 0; i < end_idx; i++) {
auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
switch (p.len_) {
- case 0: return true; // no post_ops
- case 1:
- return true // sum OR eltwise OR dw_conv
- && !jcp.with_eltwise && (is_simple(0) || is_sum(0) || is_dw_conv(0));
- case 2:
- return true // sum->eltwise OR dw_conv->eltwise OR eltwise->dw_conv OR dw_conv->sum OR sum->depthwise OR
- // eltwise->depthwise OR depthwise->depthwise
- && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
- (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
- (is_simple(0) && is_simple(1)));
- case 3:
- return true // eltwise->dw_conv->eltwise OR dw_conv->sum->eltwise OR sum->eltwise->depthwise OR
- // sum->depthwise->eltwise OR sum->depthwise->depthwise
- && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
- (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
- (is_sum(0) && is_simple(1) && is_simple(2)));
- case 4: return true // eltwise->dw_conv->sum->eltwise
- && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
+ case 0: return true;
+ case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
+ case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
+ (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
+ (is_simple(0) && is_simple(1));
+ case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
+ (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
+ (is_sum(0) && is_simple(1) && is_simple(2));
+ case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
default: return false;
}
status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+ const primitive_attr_t &attr)
{
if (!mayiuse(avx)) return status::unimplemented;
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alg = mkldnn_eltwise_relu;
- jcp.eltwise_alpha = relu_negative_slope;
-
- if (!post_ops_ok(jcp, attr)) {
+ if (!post_ops_ok(jcp, attr))
return status::unimplemented;
- }
const auto &p = attr.post_ops_;
- jcp.with_dw_conv = false;
+
int dw_conv_ind = p.find(primitive_kind::convolution);
- if (dw_conv_ind != -1) {
- jcp.with_dw_conv = true;
- jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
- jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
- jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
- jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
- jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
- jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
- jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
- jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
+ jcp.with_dw_conv = dw_conv_ind != -1;
+ jcp.with_dw_conv = dw_conv_ind != -1;
+ if (jcp.with_dw_conv) {
+ jcp.dw_conv_oh = jcp.oh;
+ jcp.dw_conv_ow = jcp.ow;
+ jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
+ jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
}
if (jcp.with_dw_conv && !mayiuse(avx2))
return status::unimplemented;
- if (jcp.with_dw_conv) {
- int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
- if (dw_conv_eltwise_ind != -1) {
- jcp.dw_conv_with_eltwise = true;
- jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
- jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
- jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;
+ if (!mayiuse(avx2)) {
+ for (int i = 0; i < p.len_; i++) {
+ auto &post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ if (post_op.eltwise.alg != alg_kind::eltwise_relu)
+ return status::unimplemented;
+ } else if (post_op.is_depthwise()) {
+ return status::unimplemented;
+ }
}
}
jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
- if (jcp.with_dw_conv) {
- jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
- }
- if (jcp.with_dw_conv) {
- jcp.oh = jcp.dw_conv_in_h;
- jcp.ow = jcp.dw_conv_in_w;
- }
+ jcp.src_dt = cd.src_desc.data_type;
+ jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
+ jcp.dst_dt = cd.dst_desc.data_type;
jcp.os = jcp.oh * jcp.ow;
jcp.is = jcp.ih * jcp.iw;
return status::success;
}
+/*
+ * Books the scratchpad entries this kernel needs for the configuration
+ * `jcp` (and `jcp_dw` when a dw_conv post-op is present):
+ *  - key_conv_padded_bias when oc differs from oc_without_padding and the
+ *    propagation kind is not backward_data;
+ *  - key_dw_conv_buffer, sized per thread, when jcp.with_dw_conv is set;
+ *  - key_dw_conv_padded_bias when jcp.with_dw_conv is set and oc differs
+ *    from oc_without_padding.
+ */
+void jit_avx2_1x1_conv_kernel_f32::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad,
+        const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
+    using namespace mkldnn::impl::memory_tracking::names;
+
+    const bool oc_is_padded = jcp.oc != jcp.oc_without_padding;
+    const size_t padded_bias_bytes = sizeof(float) * jcp.oc;
+
+    if (oc_is_padded && jcp.prop_kind != backward_data)
+        scratchpad.book(key_conv_padded_bias, padded_bias_bytes);
+
+    // Everything below is only needed for the dw_conv post-op.
+    if (!jcp.with_dw_conv)
+        return;
+
+    const int nthreads = mkldnn_get_max_threads();
+    const auto oc_blocks = jcp.oc / jcp.oc_block;
+    // Widen before multiplying so the element count is computed in size_t.
+    const size_t elems_per_thread = static_cast<size_t>(jcp_dw.kh)
+            * jcp_dw.iw * jcp_dw.ch_block * oc_blocks;
+    scratchpad.book(key_dw_conv_buffer,
+            sizeof(float) * elems_per_thread * nthreads);
+
+    if (oc_is_padded)
+        scratchpad.book(key_dw_conv_padded_bias, padded_bias_bytes);
+}
+
}
}
}