* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
-#include <float.h>
+
+#include <assert.h>
+
#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
#include "utils.hpp"
+
#include "cpu_memory.hpp"
#include "jit_uni_1x1_conv_utils.hpp"
using namespace Xbyak;
-bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_relu(int position)
-{
- using namespace primitive_kind;
- const auto &p = attr_.post_ops_;
-
- if (position == 0) {
- /* relu before sum */
- return false
- || jcp.with_eltwise
- || p.contain(eltwise, 0)
- || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
- } else if (position == 1) {
- /* relu after sum */
- const int sum_idx = p.contain(sum, 0)
- ? 0 : (p.contain(sum, 1) ? 1 : -1);
- if (sum_idx == -1)
- return false;
-
- return false
- || p.contain(eltwise, sum_idx + 1)
- || jcp.dst_dt == data_type::u8;
- }
-
- return false;
-}
-
void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk)
{
mov(aux1_reg_bcast_data, reg_bcast_data);
};
auto vreg_accum = [=](int i_load, int i_ur) {
- return Zmm(i_ur * load_loop_blk + i_load);
+ return Zmm(i_ur + i_load * ur);
};
auto zmm_bias_alpha = [=]() {
zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r;
vmulps(mask_zmm, r, scale_ptr(i_load));
- if (maybe_relu(0)) {
- vpxord(zmm_zero, zmm_zero, zmm_zero);
- vmaxps(r, zmm_zero, r);
- }
- if (p_sum_scale) { // post_op: sum
- vpxord(zmm_zero, zmm_zero, zmm_zero);
- auto zmm_prev_dst = zmm_zero;
+ }
+ }
+
+ int eltwise_inj_idx = 0;
+ int depthwise_inj_idx = 0;
+ for (int i = 0; i < p.len_; i++) {
+ auto& post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
+
+ eltwise_inj_idx++;
+ } else if (post_op.is_depthwise()) {
+ mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
+ mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+ add(reg_d_weights, reg_oc_off);
+ add(reg_d_bias, reg_oc_off);
- cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur),
- mask_flag);
+ for (int k = 0; k < load_loop_blk; k++) {
+ depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
+ k * ur, k * ur + ur, reg_d_weights, reg_d_bias);
- if (*p_sum_scale == 1.f)
- vaddps(r, zmm_prev_dst);
- else
- vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
+ add(reg_d_weights, jcp.oc_block * sizeof(float));
+ add(reg_d_bias, jcp.oc_block * sizeof(float));
}
- if (maybe_relu(1)) {
+
+ depthwise_inj_idx++;
+ } else if (post_op.is_sum(false)) {
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+ const bool mask_flag = mask_flag_in &&
+ i_load == load_loop_blk - 1;
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ vpxord(zmm_zero, zmm_zero, zmm_zero);
+ auto zmm_prev_dst = zmm_zero;
+
+ auto r = vreg_accum(i_load, i_ur);
+ cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur),
+ mask_flag);
+
+ if (*p_sum_scale == 1.f)
+ vaddps(r, zmm_prev_dst);
+ else
+ vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
+ }
+ }
+ }
+ }
+
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+ const bool mask_flag = mask_flag_in &&
+ i_load == load_loop_blk - 1;
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ auto r = vreg_accum(i_load, i_ur);
+ if (jcp.dst_dt == data_type::u8) {
vpxord(zmm_zero, zmm_zero, zmm_zero);
vmaxps(r, zmm_zero, r);
}
for (int i_ur = 0; i_ur < ur; ++i_ur) {
auto r = vreg_accum(i_load, i_ur);
zmm_t r_zmm = mask_flag ? r | ktail_mask : r;
+
switch (jcp.dst_dt) {
case data_type::f32:
case data_type::s32:
Label reduce_loop;
Label reduce_loop_tail;
+ push(reg_oc_off);
+
mov(aux_reg_load_data, reg_load_data);
mov(aux_reg_bcast_data, aux1_reg_bcast_data);
fma_block(false);
}
+ pop(reg_oc_off);
+
if (jcp.oc_without_padding != jcp.oc) {
Label end_store, common_store;
mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate()
{
+ const auto &p = attr_.post_ops_;
+ for (int i = 0; i < p.len_; i++) {
+ auto &post_op = p.entry_[i];
+ if (post_op.is_eltwise()) {
+ eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
+ this,
+ post_op.eltwise.alg,
+ post_op.eltwise.alpha,
+ post_op.eltwise.beta
+ ));
+ } else if (post_op.is_depthwise()) {
+ depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
+ this,
+ post_op.depthwise.alg
+ ));
+ }
+ }
+
preamble();
xor_(reg_scratch, reg_scratch);
mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work);
mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
-
+ mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
auto load_loop_body = [=](int load_loop_blk) {
bcast_loop(load_loop_blk);
add(reg_output_data,
load_loop_blk * jcp.load_block * jcp.typesize_out);
sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
+ add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float));
};
const int simd_w = 16;
cmp(reg_load_loop_work, 0);
je(load_loop_blk[num_ur_cases], T_NEAR);
}
+
+ for (int _i = 1; _i <= label_idx + 1; _i++) {
+ prefetcht0(ptr [ reg_load_data + _i * jcp.ic * jcp.oc_block ]);
+ prefetcht1(ptr [ reg_output_data + _i * jcp.oc_block ]);
+ }
+
load_loop_body(label_idx + 1);
if (label_idx - 1 > 0) {
cmp(reg_load_loop_work, 2 * label_idx * simd_w);
add(rsp, stack_space_needed);
postamble();
+
+ for (auto& inj : eltwise_injectors)
+ inj->prepare_table();
}
bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok(
using namespace primitive_kind;
const auto &p = attr.post_ops_;
- auto is_relu = [&](int idx) {
- return p.entry_[idx].kind == eltwise
- && p.entry_[idx].eltwise.scale == 1.
- && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
- && p.entry_[idx].eltwise.alpha == 0.;
- };
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
+ auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
switch (p.len_) {
- case 0: return true;
- case 1: return true
- && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0))
- && IMPLICATION(!jcp.with_eltwise, is_relu(0) || p.contain(sum, 0));
- case 2: return true
- && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0) && is_relu(1))
- && IMPLICATION(!jcp.with_eltwise, false
- || (p.contain(sum, 0) && is_relu(1))
- || (p.contain(sum, 1) && is_relu(0)));
- case 3: return true
- && jcp.with_eltwise == false
- && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
- default: return false;
+ case 0: return true;
+ case 1: return is_simple(0) || is_sum(0);
+ case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_sum(1)) ||
+ (is_simple(0) && is_simple(1));
+ case 3: return (is_simple(0) && is_sum(1) && is_simple(2));
+ default: return false;
}
return false;
jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d,
- const primitive_attr_t &attr, bool with_relu, float relu_negative_slope,
- int nthreads, bool reduce_src)
-{
+ const primitive_attr_t &attr, int nthreads, bool reduce_src) {
if (!mayiuse(avx512_core)) return status::unimplemented;
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
jcp.stride_w = cd.strides[1];
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alpha = relu_negative_slope;
- if (!IMPLICATION(with_relu, relu_negative_slope == 0.))
- return status::unimplemented;
jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
max_regs = 8;
jcp.expl_bcast = true;
- const int spatial = jcp.oh;
- jcp.ur = 1;
- for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
- if ((spatial >= size_treshold && spatial % ur_w == 0)
- || (spatial < size_treshold && jcp.os % ur_w == 0)) {
- jcp.ur = ur_w;
- break;
- }
- }
- if (jcp.ur == 1) {
+ if (jcp.mb == 1 && jcp.ic > 128
+ && (jcp.oh <= size_treshold && jcp.ow <= size_treshold)) {
jcp.ur = nstl::min(max_regs, jcp.os);
- int os_tail = jcp.os % max_regs;
- for (int i = max_regs; i >= min_regs; i--) {
- int i_tail = jcp.os % i;
- if (i_tail > os_tail || i_tail == 0) {
- jcp.ur = i;
- os_tail = i_tail;
- if (i_tail == 0)
- break;
+ } else {
+ const int spatial = jcp.oh;
+ jcp.ur = 1;
+ for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
+ if ((spatial >= size_treshold && spatial % ur_w == 0)
+ || (spatial < size_treshold && jcp.os % ur_w == 0)) {
+ jcp.ur = ur_w;
+ break;
+ }
+ }
+ if (jcp.ur == 1) {
+ jcp.ur = nstl::min(max_regs, jcp.os);
+ int os_tail = jcp.os % max_regs;
+ for (int i = max_regs; i >= min_regs; i--) {
+ int i_tail = jcp.os % i;
+ if (i_tail > os_tail || i_tail == 0) {
+ jcp.ur = i;
+ os_tail = i_tail;
+ if (i_tail == 0)
+ break;
+ }
}
}
}
return status::success;
}
+// Books the scratchpad entry used for adjusted output scales.
+// Nothing is booked unless the input is signed and the kernel version is
+// something other than ver_vnni.
+void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad,
+        const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
+    using namespace mkldnn::impl::memory_tracking::names;
+
+    const bool needs_adjusted_scales = jcp.signed_input && jcp.ver != ver_vnni;
+    if (!needs_adjusted_scales) return;
+
+    // Reserve room for at least 16 floats even when fewer output scales are
+    // supplied (presumably one full 16-lane vector -- TODO confirm).
+    const size_t n_scales = nstl::max(attr.output_scales_.count_, 16);
+    scratchpad.book(key_conv_adjusted_scales, sizeof(float) * n_scales);
+}
+
}
}
}