* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
+
+#include <assert.h>
#include <float.h>
+
#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
#include "utils.hpp"
+
#include "cpu_memory.hpp"
+#include "cpu_barrier.hpp"
#include "jit_uni_1x1_conv_utils.hpp"
#include "jit_avx512_common_1x1_conv_kernel.hpp"
int depthwise_inj_idx = 0;
const auto &p = attr_.post_ops_;
- if (p.len_ == 0 && eltwise_injectors.size() == 1) {
- eltwise_injectors[0]->compute_vector_range(0, ur * load_loop_blk);
- }
-
for (int i = 0; i < p.len_; i++) {
auto& post_op = p.entry_[i];
if (post_op.is_eltwise()) {
- eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
+ if (jcp.ver == ver_4vnni) {
+ zmm_t zmm_zero = vreg_bcast;
+ vpxord(zmm_zero, zmm_zero, zmm_zero);
+
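+                    /* s32 ReLU with zero negative slope: mask the lanes
+                     * holding negative accumulators and zero them in
+                     * place by multiplying with zero */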
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+ Zmm zmm = vreg_accum(i_load, i_ur);
+ vpcmpd(k1, zmm, zmm_zero, _cmp_lt_os);
+ vpmulld(zmm | k1, zmm, zmm_zero);
+ }
+ }
+ } else {
+ eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
+ }
eltwise_inj_idx++;
} else if (post_op.is_depthwise()) {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
void jit_avx512_common_1x1_conv_kernel::generate()
{
- if (jcp.with_eltwise) {
- eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
- this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
- ));
- }
-
const auto &p = attr_.post_ops_;
for (int i = 0; i < p.len_; i++) {
auto &post_op = p.entry_[i];
mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
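+    /* on forward, stash a pointer to the eltwise alpha (the ReLU
+     * negative slope) in reg_relu_ns */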
+ if (one_of(jcp.prop_kind, forward_training, forward_inference))
+ mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha));
if (jcp.prop_kind == backward_weights)
mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
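+    /* accept an optional leading sum followed by up to two simple
+     * (eltwise or depthwise) post-ops; a lone sum is also allowed */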
switch (p.len_) {
- case 0: return true; // no post_ops
- case 1:
- return true // sum OR eltwise OR depthwise
- && !jcp.with_eltwise && (is_simple(0) || is_sum(0));
- case 2:
- return true // sum->relu
- && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) ||
- (is_simple(0) && is_simple(1)));
- case 3:
- return true // sum->relu
- && !jcp.with_eltwise && (is_sum(0) && is_simple(1) && is_simple(2));
+ case 0: return true;
+ case 1: return is_simple(0) || is_sum(0);
+ case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
+ case 3: return is_sum(0) && is_simple(1) && is_simple(2);
default: return false;
}
return false;
}
-status_t jit_avx512_common_1x1_conv_kernel::init_conf(
- jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope,
- int nthreads, bool reduce_src)
-{
+status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr, int nthreads, bool reduce_src) {
if (!mayiuse(avx512_common)) return status::unimplemented;
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
jcp.stride_w = cd.strides[ndims - 3];
jcp.src_fmt = src_d.format();
- jcp.with_bias = one_of(jcp.prop_kind, forward_training, forward_inference)
- ? cd.bias_desc.format != memory_format::undef : false;
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alg = mkldnn_eltwise_relu;
- jcp.eltwise_alpha = relu_negative_slope;
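+    /* forward reads bias_desc, backward-data has no bias, and
+     * backward-weights produces diff_bias_desc */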
+ jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format,
+ memory_format::undef, cd.diff_bias_desc.format)
+ != memory_format::undef;
jcp.os = jcp.oh * jcp.ow;
jcp.is = jcp.ih * jcp.iw;
const auto &p = attr.post_ops_;
jcp.with_sum = p.find(primitive_kind::sum) != -1;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise) {
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
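+        /* eltwise post-ops are not supported for s32 destinations */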
+ if (dst_d.data_type() == data_type::s32) return status::unimplemented;
+ }
bool args_ok = true
&& jcp.ngroups == 1
} else {
bool is4ops = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni);
-// max_regs = is4ops ? 28 : 30;
- // FIXME (ichuraev): it is a fix for densnet-121
- max_regs = 28;
+ max_regs = is4ops ? 28 : 30;
min_regs = 9;
size_treshold = is4ops ? 28 : 14;
ur_step = is4ops ? 4 : 1;
load_blocking = jcp.load_block;
}
+ if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim
+ && jcp.oh * jcp.ow > 64
+ && IMPLICATION(reduce_src, jcp.load_dim < 1024)) {
+        /* Search for the loading-dimension blocking that gives the best
+         * thread utilization and data read/write efficiency by finding
+         * the optimal 'load_chunk' value.
+         * Example:
+         * for 72 threads and a convolution with mb=1, ih=iw=7, oc=512
+         * the best load_chunk value is 1.
+         * TODO: remove the heuristic constants in the condition above
+         * TODO: check this blocking for other ISAs
+         */
+ float best_eff = -1.f;
+ int best_lgc = 1;
+
+ for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) {
+ int lgc = div_up(nb_load, load_chunk);
+ if (lgc > nthreads)
+ continue;
+ int thr_per_grp = div_up(nthreads, lgc);
+ int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp)
+ * jcp.bcast_block;
+ int load_per_thr = load_chunk * simd_w;
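+            /* data_eff = (geometric mean / arithmetic mean)^2 of the
+             * per-thread bcast and load volumes; it peaks at 1 when
+             * the two are balanced */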
+ float data_norm = (bcast_per_thr + load_per_thr) / 2.f;
+ float data_eff = (bcast_per_thr * load_per_thr)
+ / (data_norm * data_norm);
+ float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc)
+ / div_up(nthreads, lgc);
+ float thr_eff_in_grp = ((float)jcp.mb * nb_bcast)
+ / rnd_up(jcp.mb * nb_bcast, thr_per_grp);
+ float thr_eff = thr_eff_over_grp * thr_eff_in_grp;
+ float load_eff = (float)nb_load / rnd_up(nb_load, lgc);
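+            /* thr_eff and load_eff penalize threads and load blocks
+             * that are left idle by an uneven division of work */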
+ float overall_eff = data_eff + thr_eff + load_eff;
+ if (overall_eff > best_eff) {
+ best_eff = overall_eff;
+ best_lgc = lgc;
+ }
+ }
+ jcp.load_grp_count = best_lgc;
+ load_blocking
+ = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
+ }
bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
div_up(nthreads, jcp.load_grp_count))
* jcp.bcast_block;
return status::success;
}
+void jit_avx512_common_1x1_conv_kernel::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp) {
+ using namespace mkldnn::impl::memory_tracking::names;
+
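+    /* zero-padded output channels require the bias to be copied into
+     * a padded buffer as well */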
+ if (jcp.prop_kind != backward_data && jcp.with_bias
+ && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
+
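+    /* each extra reduction thread writes a private copy of the diff
+     * weights that is reduced into the user buffer afterwards */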
+ if (jcp.prop_kind == backward_weights) {
+ const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic;
+ scratchpad.book(key_conv_wei_reduction,
+ jcp.typesize_out * wei_size * (jcp.nthr_mb - 1));
+ }
+
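+    /* a transposed copy of the source plus per-thread barrier contexts
+     * to synchronize the transposition across threads */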
+ if (jcp.transpose_src) {
+ const size_t tr_src_size =
+ (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is;
+ scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size);
+ scratchpad.book(key_conv_tr_src_bctx,
+ sizeof(simple_barrier::ctx_t) * jcp.nthr);
+ }
+}
+
void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp,
int nthreads)
{