1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
19 #include "type_helpers.hpp"
21 #include "cpu_memory.hpp"
23 #include "jit_uni_x8s8s32x_1x1_conv_kernel.hpp"
// GET_OFF(field): byte offset of `field` inside the runtime kernel-argument
// struct jit_1x1_conv_call_s. The generated code uses it as
// `ptr[param1 + GET_OFF(field)]` to load individual kernel arguments.
25 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
33 using namespace mkldnn::impl::prop_kind;
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::utils;
36 using namespace mkldnn::impl::types;
38 using namespace Xbyak;
// cvt2ps: load the vector at `op` into `vmm_in` and convert it to packed fp32.
// Visible per-type loads: s32 -> vmovups (raw copy), s8 -> vpmovsxbd
// (sign-extend bytes to dwords), u8 -> vpmovzxbd (zero-extend bytes to
// dwords); the default branch asserts. Every non-f32 input is then converted
// int32 -> fp32 with vcvtdq2ps.
// NOTE(review): the `switch` header, any f32 case label and the closing
// braces are missing from this listing. The `!= data_type::f32` test below
// implies f32 is an accepted input — confirm how it is loaded.
40 template <cpu_isa_t isa>
41 void jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in,
42 Vmm vmm_in, const Xbyak::Operand &op) {
45 case data_type::s32: vmovups(vmm_in, op); break;
46 case data_type::s8: vpmovsxbd(vmm_in, op); break;
47 case data_type::u8: vpmovzxbd(vmm_in, op); break;
48 default: assert(!"unsupported data type");
// Integer inputs are now int32 lanes; convert them to fp32.
50 if (type_in != data_type::f32)
51 vcvtdq2ps(vmm_in, vmm_in);
// loop_os: emit the output-spatial loop for one group of `oc_loop_blk` output
// channel blocks.
//   reg_loop_os_iter - remaining output pixels; it is decremented by os_block
//                      (== ur) per full block and by ow_tail per tail block.
//   reg_ow_loop_work - pixels left in the current output row; it starts at
//                      jcp.ow.
// Full blocks of jcp.ur pixels go through ic_loop(oc_loop_blk, ur). When the
// remaining row work equals jcp.ow_tail, control branches to the tail path.
// The tail path processes jcp.ow_tail pixels and resets the row counter.
// NOTE(review): the Label declarations, the L(...) bindings and the
// loop-closing conditional jumps are missing from this listing.
// NOTE(review): reg_ow_loop_work is reset to jcp.ow only inside the
// `jcp.ow_tail > 0` path. When ow_tail == 0 the counter is never reset, and
// once it reaches 0 the `je loop_ow_tail` branch is taken. Confirm what
// loop_ow_tail targets in that configuration.
54 template <cpu_isa_t isa>
55 void jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::loop_os(int oc_loop_blk)
57 mov(aux_reg_dst_data, reg_dst_data);
62 mov(reg_ow_loop_work, jcp.ow);
// Main path: one full ur-pixel block, unless only the row tail remains.
65 assert(jcp.os_block == jcp.ur);
66 cmp(reg_ow_loop_work, jcp.ow_tail);
67 je(loop_ow_tail, T_NEAR);
69 ic_loop(oc_loop_blk, jcp.ur);
71 sub(reg_ow_loop_work, jcp.ur);
// Advance src/dst by one ur-pixel step (byte strides precomputed in init_conf).
73 add(reg_src_data, jcp.os_loop_src_step);
74 add(aux_reg_dst_data, jcp.os_loop_dst_step);
76 sub(reg_loop_os_iter, jcp.os_block);
77 cmp(reg_loop_os_iter, jcp.os_block);
// Row-tail path: process the remaining ow_tail pixels, step to the next row
// and reset the per-row counter.
81 if (jcp.ow_tail > 0) {
82 ic_loop(oc_loop_blk, jcp.ow_tail);
85 add(reg_src_data, jcp.os_loop_src_tail_step);
86 add(aux_reg_dst_data, jcp.os_loop_dst_tail_step);
88 mov(reg_ow_loop_work, jcp.ow);
90 sub(reg_loop_os_iter, jcp.ow_tail);
91 cmp(reg_loop_os_iter, 0);
// ic_loop: emit code for one tile of `ur` output pixels x `oc_loop_blk`
// output-channel blocks. The generated code:
//   1. zeroes the int32 accumulators;
//   2. accumulates unsigned-byte (src) x signed-byte (weights) products over
//      all jcp.nb_ic input-channel blocks;
//   3. post-processes (bias, output scales, optional sum / ReLU, rounding,
//      down-conversion) and stores to the destination.
// NOTE(review): many structural lines are missing from this listing. These
// include lambda closers, case labels, guards and loop labels.
97 template <cpu_isa_t isa>
98 void jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::ic_loop(int oc_loop_blk, int ur)
// Register map: the accumulator for (oc block i, pixel j) is
// Vmm(j * oc_loop_blk + i), i.e. indices [0, ur * oc_loop_blk). The weight
// register for oc block i follows at Vmm(ur * oc_loop_blk + i).
100 auto vreg_wei = [=](int i) {
101 return Vmm(ur * oc_loop_blk + i);
104 auto vreg_accum_vmm = [=](int i, int j) {
105 return Vmm(j * oc_loop_blk + i);
108 auto vreg_accum_xmm = [=](int i, int j) {
109 return Xmm(j * oc_loop_blk + i);
// Source operand for pixel j and ic sub-block u. The element offset is
// j * ic * stride_w + u * ic_block, scaled by typesize_in. This matches the
// nhwc src layout required by init_conf.
112 auto src_ptr = [=](int u, int j) {
113 size_t offt = j * jcp.ic * jcp.stride_w + u*jcp.ic_block;
114 return ptr[aux_reg_src_data + jcp.typesize_in * offt];
// Weights for oc block i and ic sub-block u. The element offset is
// i * nb_ic * oc_block * ic_block + u * ic_block * oc_block.
117 auto wei_ptr = [=](int u, int i) {
118 size_t offt = i*jcp.nb_ic*jcp.oc_block*jcp.ic_block + u*jcp.ic_block * jcp.oc_block;
119 return ptr[aux_reg_weight_data + offt * jcp.typesize_in];
// Destination element (i * oc_block + j * oc). The byte-size factor continues
// on a line missing from this listing.
122 auto output_ptr = [=](int i, int j) {
123 return ptr[aux_reg_dst_data + (i * jcp.oc_block + j * jcp.oc) *
// Zero every accumulator of the tile.
128 for (int i = 0; i < oc_loop_blk; ++i) {
129 for (int j = 0; j < ur; ++j) {
130 auto vmm_acc = vreg_accum_vmm(i, j);
131 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
// Prime the pipeline: load ic sub-block 0 weights for every oc block, and
// broadcast the first dword (4 source bytes) of pixel 0.
135 for (int i = 0; i < oc_loop_blk; ++i)
136 uni_vmovdqu(vreg_wei(i), wei_ptr(0, i));
138 uni_vpbroadcastd(vreg_src, src_ptr(0, 0));
// ---- Post-processing and store of the accumulated tile ----
// NOTE(review): the enclosing construct of this section (lambda header or
// call site) is not visible in this listing.
// Load the output-scales pointer from the kernel arguments, and zero a vector
// that serves as the ReLU lower bound.
142 mov(reg_scales, ptr[this->param1 + GET_OFF(scales)]);
143 uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
145 for (int j = 0; j < ur; ++j)
146 for (int i = 0; i < oc_loop_blk; ++i) {
// Bias for oc block i starts b_off elements into the bias array. It is loaded
// according to jcp.bia_dt and converted to fp32 unless already f32.
// NOTE(review): no f32 case is visible in this switch. Also, the assert text
// says "dst" although the switch is over the bias type.
147 int b_off = i*jcp.oc_block;
150 switch (jcp.bia_dt) {
152 case data_type::s32: vmovups(vmm_bias, ptr[reg_bias_data + b_off*jcp.typesize_bia]); break;
153 case data_type::s8: vpmovsxbd(vmm_bias, ptr[reg_bias_data + b_off*jcp.typesize_bia]); break;
154 case data_type::u8: vpmovzxbd(vmm_bias, ptr[reg_bias_data + b_off*jcp.typesize_bia]); break;
155 default: assert(!"unsupported dst data type");
158 if (jcp.bia_dt != data_type::f32)
159 vcvtdq2ps(vmm_bias, vmm_bias);
// Accumulator (i, j), both as a full vector and as its low Xmm (the Xmm view
// is used for the narrow s8/u8 stores).
161 Vmm vmm_dst = vreg_accum_vmm(i, j);
162 Xmm xmm_dst = vreg_accum_xmm(i, j);
// Convert int32 -> fp32, then add the bias. The guard on the add is not
// visible; presumably it is jcp.with_bias.
164 vcvtdq2ps(vmm_dst, vmm_dst);
167 vaddps(vmm_dst, vmm_dst, vmm_bias);
// Output scaling. With jcp.is_oc_scale, the per-oc scales start at
// reg_scales + i * oc_block floats; otherwise the offset is 0.
// NOTE(review): the multiply uses a full-vector memory operand, so in the
// common-scale case it reads simd_w floats starting at scales[0]. Confirm the
// scales buffer is replicated or padded for that case.
169 int s_off = jcp.is_oc_scale * (sizeof(float) * (i*jcp.oc_block));
170 vmulps(vmm_dst, vmm_dst, ptr[reg_scales + s_off]);
// Sum post-op: add the previous destination contents, converted to fp32.
// NOTE(review): Ymm(12) is hard-coded. The accumulators occupy
// Vmm(0 .. ur*oc_loop_blk-1), so if ur * oc_loop_blk > 12 this scratch
// register aliases an accumulator of a tile element that may not be stored
// yet. Also confirm that Ymm is valid for the sse42 instantiation.
173 Ymm vmm_prev_dst = Ymm(12);
174 cvt2ps(jcp.dst_dt, vmm_prev_dst, output_ptr(i, j));
175 vaddps(vmm_dst, vmm_prev_dst);
// ReLU as max(x, 0). The guards for these two sites are not visible.
// NOTE(review): maybe_relu() distinguishes position 0 (before sum) from
// position 1 (after sum), yet both max operations appear after the sum here.
// Confirm the ordering matches the intended semantics.
179 vmaxps(vmm_dst, vmm_zero, vmm_dst);
182 vmaxps(vmm_dst, vmm_zero, vmm_dst);
// Non-f32 destinations: round fp32 -> int32.
//   nearest: avx512 uses embedded round-to-nearest (T_rn_sae); other ISAs use
//            vcvtps2dq, which follows the current MXCSR rounding mode.
//   down:    avx512 uses embedded round-down (T_rd_sae); other ISAs floor
//            first with vroundps (imm 1 = round toward -inf).
//   other:   assert.
// NOTE(review): the brace structure between these nested `if`s is partly
// missing from this listing.
184 if (jcp.dst_dt != data_type::f32) {
185 if (attr_.round_mode_ == round_mode::nearest)
186 if (isa == avx512_common) {
187 vcvtps2dq(vmm_dst | T_rn_sae, vmm_dst);
189 vcvtps2dq(vmm_dst, vmm_dst);
191 else if (attr_.round_mode_ == round_mode::down) {
192 if (isa == avx512_common) {
193 vcvtps2dq(vmm_dst | T_rd_sae, vmm_dst);
195 vroundps(vmm_dst, vmm_dst, 1);
196 vcvtps2dq(vmm_dst, vmm_dst);
199 assert(!"unimplemented");
// Store according to the destination type.
//   s32: full vector.
//   s8/u8 (case labels not visible; the type is inferred from the signed vs.
//   unsigned instructions):
//     avx512: narrow with saturating vpmovsdb/vpmovusdb, then store the Xmm.
//     avx2:   pack dword -> word -> byte with signed/unsigned saturation.
//             Between the two packs, vpermq imm 0x08 gathers qwords 0 and 2
//             into the low 128 bits. Then store 8 bytes with vmovq.
// NOTE(review): only avx512_common and avx2 are handled for s8/u8, so the
// sse42 instantiation emits no store for those destination types.
202 switch (jcp.dst_dt) {
204 case data_type::s32: vmovups(output_ptr(i, j), vmm_dst); break;
206 if (isa == avx512_common) {
207 vpmovsdb(xmm_dst, vmm_dst);
208 vmovups(output_ptr(i, j), xmm_dst);
209 } else if (isa == avx2) {
210 Ymm ymm_dst = Ymm(vmm_dst.getIdx());
212 vpackssdw(ymm_dst, ymm_dst, ymm_dst);
213 vpermq(ymm_dst, ymm_dst, 0x08);
214 vpacksswb(xmm_dst, xmm_dst, xmm_dst);
215 vmovq(output_ptr(i, j), xmm_dst);
219 if (isa == avx512_common) {
220 vpmovusdb(xmm_dst, vmm_dst);
221 vmovups(output_ptr(i, j), xmm_dst);
222 } else if (isa == avx2) {
223 Ymm ymm_dst = Ymm(vmm_dst.getIdx());
225 vpackusdw(ymm_dst, ymm_dst, ymm_dst);
226 vpermq(ymm_dst, ymm_dst, 0x08);
227 vpackuswb(xmm_dst, xmm_dst, xmm_dst);
228 vmovq(output_ptr(i, j), xmm_dst);
231 default: assert(!"unknown dst_dt");
// fma_block: for each pixel j and oc block i:
//   - vpmaddubsw multiplies the broadcast source bytes (unsigned) by the
//     weight bytes (signed) and sums adjacent pairs into saturating int16;
//   - vpmaddwd with vmm_one (16-bit ones, set in generate()) sums int16 pairs
//     into int32;
//   - the result is added into the accumulator.
// NOTE(review): vpmaddubsw saturates its int16 intermediate, so large u8*s8
// pair sums can clip.
// The conditions around the u = 1 weight prefetch and the next-source
// broadcasts are not visible in this listing.
236 auto fma_block = [=]() {
237 for (int j = 0; j < ur; ++j) {
238 for (int i = 0; i < oc_loop_blk; i++) {
239 vpmaddubsw(vreg_sum_0, vreg_src, vreg_wei(i));
240 vpmaddwd(vreg_sum_0, vreg_sum_0, vmm_one);
241 vpaddd(vreg_accum_vmm(i, j), vreg_accum_vmm(i, j), vreg_sum_0);
244 uni_vmovdqu(vreg_wei(i), wei_ptr(1, i));
249 uni_vpbroadcastd(vreg_src, src_ptr(0, j + 1));
252 uni_vpbroadcastd(vreg_src, src_ptr(1, 0));
// Input-channel loop.
//   - Reset the aux pointers to the tile base.
//   - Iterate reg_loop_ic_iter from 0 to jcp.nb_ic.
//   - Per iteration, advance src by ic_block elements and weights by
//     ic_block * oc_block elements. These are the same strides used by the
//     u = 1 offsets of src_ptr / wei_ptr.
// The loop labels and the fma_block call are not visible in this listing.
255 mov(aux_reg_weight_data, reg_weight_data);
256 mov(aux_reg_src_data, reg_src_data);
263 xor_(reg_loop_ic_iter, reg_loop_ic_iter);
265 cmp(reg_loop_ic_iter, jcp.nb_ic);
270 add(aux_reg_src_data, jcp.ic_block * jcp.typesize_in);
271 add(aux_reg_weight_data, jcp.ic_block * jcp.oc_block * jcp.typesize_in);
272 inc(reg_loop_ic_iter);
273 jmp(ic_loop, T_NEAR);
// generate: emit the kernel body.
// First, vmm_one is filled with 16-bit ones; vpmaddwd in ic_loop uses it to
// sum int16 pairs into int32. Then the kernel arguments are loaded from the
// jit_1x1_conv_call_s struct at param1:
//   oc_data     - weights
//   output_data - destination
//   bias_data   - bias
//   oc_dim      - oc work for this call; compared against nb_oc_blocking
//                 below, so presumably a count of oc blocks
//   is_data     - source
//   os_dim      - output-pixel count, consumed by loop_os
// NOTE(review): the preamble/postamble, the exit_label declaration and
// binding, and any guard around the bias load are missing from this listing.
// NOTE(review): vpbroadcastw is an AVX2 instruction. Confirm this code is
// never generated for the sse42 instantiation on CPUs without AVX2.
281 template <cpu_isa_t isa>
282 void jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::generate()
286 mov(reg_scratch, 0x1);
287 movq(xmm_one, reg_scratch);
288 vpbroadcastw(vmm_one, xmm_one);
290 mov(reg_weight_data, ptr[param1 + GET_OFF(oc_data)]);
291 mov(reg_dst_data, ptr[param1 + GET_OFF(output_data)]);
293 mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
296 mov(reg_oc_loop_work, ptr[param1 + GET_OFF(oc_dim)]);
297 mov(reg_src_data, ptr[param1 + GET_OFF(is_data)]);
298 mov(reg_loop_os_iter, ptr[param1 + GET_OFF(os_dim)]);
// Dispatch on the oc work assigned to this call:
//   - exactly nb_oc_blocking blocks: run the main loop;
//   - otherwise, if nb_oc is not a multiple of nb_oc_blocking and the work is
//     exactly oc_blocks_tail blocks: run the tail loop;
//   - any other value: jump straight to exit_label without computing.
// NOTE(review): a mismatched oc_dim is silently skipped rather than flagged.
// Confirm the driver never passes other values.
300 Label oc_blocks_tail_label;
303 int oc_blocks_tail = jcp.nb_oc % jcp.nb_oc_blocking;
305 cmp(reg_oc_loop_work, jcp.nb_oc_blocking);
306 jne(oc_blocks_tail ? oc_blocks_tail_label : exit_label, T_NEAR);
308 loop_os(jcp.nb_oc_blocking); // channel main loop
309 jmp(exit_label, T_NEAR);
311 if (oc_blocks_tail) {
312 L(oc_blocks_tail_label);
314 cmp(reg_oc_loop_work, oc_blocks_tail);
315 jne(exit_label, T_NEAR);
317 loop_os(oc_blocks_tail); // channel tail loop
// post_ops_ok: report whether this kernel supports the attribute's post-op
// chain. Visible cases, by chain length:
//   0       - supported;
//   1       - a single relu or sum, supported only when jcp.with_eltwise is
//             false;
//   2       - sum followed by relu, same condition;
//   longer  - not supported.
// NOTE(review): the `switch` header (presumably on the post-op count) and the
// closing lines are missing from this listing.
325 template <cpu_isa_t isa>
326 bool jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::post_ops_ok(
327 jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
328 const auto &p = attr.post_ops_;
330 auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
331 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
334 case 0: return true; // no post_ops
335 case 1: return !jcp.with_eltwise && (is_relu(0) || is_sum(0)); // sum OR relu
336 case 2: return !jcp.with_eltwise && (is_sum(0) && is_relu(1)); // sum->relu
337 default: return false;
// maybe_relu: decide whether a ReLU (max with 0) is applied at `position`.
// Position 0 is before the sum post-op; position 1 is after it.
// NOTE(review): several lines are missing from this listing: the opening
// `if (position == 0)`, the leading operand(s) of each `||` chain, any
// handling of sum_idx == -1, and the final return.
343 template <cpu_isa_t isa>
344 bool jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::maybe_relu(int position) {
345 using namespace primitive_kind;
346 const auto &p = attr_.post_ops_;
// Position 0, visible terms: an eltwise as the first post-op, or a u8
// destination whose first post-op is not a sum. The u8 term presumably clamps
// negatives to 0 before the u8 conversion.
349 /* relu before sum */
352 || p.contain(eltwise, 0)
353 || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
354 } else if (position == 1) {
// Position 1: locate the sum post-op (index 0 or 1, or -1 if absent).
// Visible terms: an eltwise directly after the sum, or a u8 destination.
356 const int sum_idx = p.contain(sum, 0)
357 ? 0 : (p.contain(sum, 1) ? 1 : -1);
362 || p.contain(eltwise, sum_idx + 1)
363 || jcp.dst_dt == data_type::u8;
// init_conf: validate the problem for this kernel and fill `jcp`.
// Returns status::unimplemented when the ISA is unavailable or the
// shape/formats are unsupported; otherwise returns status::success.
// Visible requirements (args_ok):
//   - nhwc src and dst;
//   - OhIw8o4i / gOhIw8o4i weights;
//   - bias format undef, any or x;
//   - oc and ic multiples of simd_w;
//   - no top/left padding, a 1x1 kernel and unit strides.
// NOTE(review): several jcp fields read below are not assigned in the visible
// lines: jcp.ur, jcp.ic_block, jcp.ic_loop_unroll, jcp.is_dim and
// jcp.oc_dim. Confirm they are set before use on lines missing from this
// listing. The start of the args_ok expression and any call to post_ops_ok()
// are also not visible.
369 template <cpu_isa_t isa>
370 status_t jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::init_conf(jit_1x1_conv_conf_t &jcp,
371 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
372 const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
373 const memory_desc_wrapper &bias_pd, const primitive_attr_t &attr,
374 bool with_relu, float relu_negative_slope)
376 if (!mayiuse(isa)) return status::unimplemented;
// Grouped weights carry one extra leading (groups) dimension.
378 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
380 jcp.prop_kind = cd.prop_kind;
382 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
// Dimension indices assume 4D tensors: [0]=mb, [1]=channels, [2]=h, [3]=w.
383 jcp.mb = src_d.dims()[0];
385 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
386 jcp.ic = src_d.dims()[1] / jcp.ngroups;
388 jcp.ih = src_d.dims()[2];
389 jcp.iw = src_d.dims()[3];
390 jcp.oh = dst_d.dims()[2];
391 jcp.ow = dst_d.dims()[3];
393 jcp.kh = weights_d.dims()[with_groups + 2];
394 jcp.kw = weights_d.dims()[with_groups + 3];
// Padding and strides: index 0 is height, index 1 is width.
396 jcp.t_pad = cd.padding[0][0];
397 jcp.l_pad = cd.padding[0][1];
399 jcp.stride_h = cd.strides[0];
400 jcp.stride_w = cd.strides[1];
// Bias is present iff the bias descriptor format is not undef.
402 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
403 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
404 jcp.dst_dt = cd.dst_desc.data_type;
406 jcp.src_fmt = src_d.format();
// A ReLU fused through the with_relu argument (as opposed to a post-op).
407 jcp.with_eltwise = with_relu;
408 jcp.eltwise_alpha = relu_negative_slope;
410 jcp.os = jcp.oh * jcp.ow;
411 jcp.is = jcp.ih * jcp.iw;
413 auto desired_wei_fmt = OhIw8o4i;
414 auto desired_gr_wei_fmt = gOhIw8o4i;
// simd_w is 16 lanes for avx512_common and 8 otherwise.
// NOTE(review): the sse42 instantiation also gets 8, although an SSE register
// holds only 4 int32/fp32 lanes. Confirm this is intended.
416 int simd_w = isa == avx512_common ? 16 : 8;
420 && src_d.format() == nhwc
421 && one_of(weights_d.format(), desired_wei_fmt, desired_gr_wei_fmt)
422 && one_of(cd.bias_desc.format, memory_format::undef, any, x)
423 && dst_d.format() == nhwc
424 && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
425 && jcp.t_pad == 0 && jcp.l_pad == 0
426 && jcp.kh == 1 && jcp.kw == 1
427 && jcp.stride_h == 1 && jcp.stride_w == 1;
429 if (!args_ok) return status::unimplemented;
// One output-channel block fills one vector register.
432 jcp.oc_block = simd_w;
435 jcp.ow_tail = jcp.ow % jcp.ur;
437 int oc_blocking{ 0 };
438 int oc_blocking_max{ 0 };
439 int os_blocking{ 0 };
440 int os_blocking_max{ 0 };
441 int ic_blocking{ 0 };
446 jcp.os_block = jcp.ur;
// Element sizes in bytes. The kernel also uses typesize_in for the weights.
448 jcp.typesize_in = types::data_type_size(src_d.data_type());
449 jcp.typesize_out = types::data_type_size(dst_d.data_type());
450 jcp.typesize_acc = sizeof(int32_t);
// Bias element size. The no-bias branch of this ternary is not visible.
451 jcp.typesize_bia = jcp.with_bias
452 ? types::data_type_size(bias_pd.data_type())
// Output scales: per-channel when the mask selects only dim 1 (channels).
// Otherwise the mask must be 0, i.e. a single common scale (asserted below).
455 const auto &oscales = attr.output_scales_;
456 jcp.is_oc_scale = oscales.mask_ == 1 << 1;
458 const auto &p = attr.post_ops_;
459 jcp.with_sum = p.find(primitive_kind::sum) != -1;
461 assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
// Byte strides for the ic and os loops of the generated code.
463 jcp.ic_loop_src_step = jcp.ic_block * jcp.ic_loop_unroll * jcp.typesize_in;
464 jcp.ic_loop_wei_step = jcp.ic_block * jcp.ic_loop_unroll * jcp.oc_block * jcp.typesize_in;
466 jcp.os_loop_dst_step = jcp.ur * jcp.oc * jcp.typesize_out;
467 jcp.os_loop_acc_step = jcp.ur * jcp.oc_block * jcp.typesize_acc;
468 jcp.os_loop_src_step = jcp.stride_w * jcp.ur * jcp.ic * jcp.typesize_in;
469 jcp.os_loop_dst_tail_step = jcp.ow_tail * jcp.oc * jcp.typesize_out;
470 jcp.os_loop_acc_tail_step = jcp.ow_tail * jcp.oc_block * jcp.typesize_acc;
// NOTE(review): the (stride_h - 1) term is always 0, because args_ok requires
// stride_h == 1.
471 jcp.os_loop_src_tail_step = jcp.stride_w * jcp.ow_tail * jcp.ic * jcp.typesize_in
472 + ((jcp.stride_h-1)*jcp.iw*jcp.ic*jcp.typesize_in);
// Blocking heuristics. oc_blocking is 4 oc blocks, so nb_oc_blocking below
// evaluates to 4.
474 oc_blocking = 4 * jcp.oc_block;
475 oc_blocking_max = 4 * jcp.oc_block;
476 os_blocking = 48; // affects oc balancing across threads
477 os_blocking_max = 320;
478 ic_blocking = 4*128; // affects L1$ utilization
481 assert(oc_blocking_max);
483 assert(os_blocking_max);
486 assert(jcp.os_block % jcp.ur == 0);
487 jcp.ur_tail = jcp.is_dim % jcp.ur;
// Block counts. The oh blocking is clamped to at least one row.
489 jcp.nb_oh_blocking = nstl::max(1, os_blocking / jcp.ow);
490 jcp.nb_oh_blocking_max = nstl::max(1, os_blocking_max / jcp.ow);
491 jcp.nb_oc_blocking = oc_blocking / jcp.oc_block;
492 jcp.nb_oc_blocking_max = oc_blocking_max / jcp.oc_block;
493 jcp.nb_ic_blocking = ic_blocking / jcp.ic_block;
495 jcp.nb_oc = div_up(jcp.oc_dim, jcp.oc_block);
497 jcp.nb_ic = jcp.ic / jcp.ic_block;
499 return status::success;
// Explicit instantiations. Only the avx2 and sse42 kernels are instantiated
// here, so these instantiations never exercise the `isa == avx512_common`
// branches in ic_loop() or the 16-lane simd_w in init_conf().
// NOTE(review): the generated code uses VEX/AVX2 instructions
// unconditionally, e.g. vpbroadcastw in generate() and 3-operand vpmaddubsw
// in ic_loop(). init_conf() only checks mayiuse(isa). For the sse42
// instantiation on a CPU without AVX2 this would likely fault. Confirm the
// sse42 kernel is never selected on such CPUs.
502 template struct jit_uni_x8s8s32x_1x1_conv_fwd_kernel<avx2>;
503 template struct jit_uni_x8s8s32x_1x1_conv_fwd_kernel<sse42>;