/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <common/memory_tracking.hpp>

#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "cpu_memory.hpp"

#include "jit_uni_def_conv_kernel_f32.hpp"

#define GET_OFF(field) offsetof(jit_def_conv_call_s, field)

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

using namespace Xbyak;
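
// Accumulates the products of the interpolated input values (already gathered
// into the input buffer) with the weights, for every kernel position and every
// unrolled output pixel. On sse42 an oc block spans two registers, hence the
// extra 'repeats' pass over the half-blocks.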
template <cpu_isa_t isa>
void jit_uni_def_conv_fwd_kernel_f32<isa>::apply_filter(int ow_step, int oc_blocks_step, int oc_step, int ic_step) {
    int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;

    for (int kh = 0; kh < jcp.kh; kh++) {
        for (int kw = 0; kw < jcp.kw; kw++) {
            for (int ic = 0; ic < ic_step; ic++) {
                // Broadcast the interpolated input value for each unrolled ow.
                for (int ow = 0; ow < ow_step; ow++) {
                    Vmm vmm_src = get_vmm_src(ow);
                    size_t inp_off = (size_t) ow * jcp.kh * jcp.kw * jcp.ic + kh * jcp.kw * jcp.ic + kw * jcp.ic + ic;

                    uni_vbroadcastss(vmm_src, ptr[aux2_reg_input_buffer + inp_off * jcp.typesize_in]);
                }

                for (int r = 0; r < repeats; r++) {
                    for (int ocb = 0; ocb < oc_blocks_step; ocb++) {
                        Vmm vmm_ker = get_vmm_ker(0);
                        size_t ker_off = (size_t) ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block +
                                         kh * jcp.kw * jcp.ic_block * jcp.oc_block +
                                         kw * jcp.ic_block * jcp.oc_block +
                                         ic * jcp.oc_block + r * jcp.oc_block / 2;

                        uni_vmovups(vmm_ker, ptr[aux2_reg_kernel + ker_off * jcp.typesize_in]);
                        for (int ow = 0; ow < ow_step; ow++) {
                            Vmm vmm_src = get_vmm_src(ow);
                            Vmm vmm_acc = get_vmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ocb * ow_step + ow);

                            // The sse42 fma emulation clobbers vmm_ker, so the
                            // weights are reloaded for every ow after the first.
                            if (isa == sse42 && ow > 0) {
                                uni_vmovups(vmm_ker, ptr[aux2_reg_kernel + ker_off * jcp.typesize_in]);
                            }

                            uni_vfmadd231ps(vmm_acc, vmm_ker, vmm_src);
                        }
                    }
                }
            }
        }
    }
}
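
// Zeroes the accumulator registers that apply_filter() updates and
// store_output() later writes back.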
template <cpu_isa_t isa>
void jit_uni_def_conv_fwd_kernel_f32<isa>::init_accums(int ow_step, int oc_blocks_step, int oc_step) {
    int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;
    for (int r = 0; r < repeats; r++) {
        for (int ocb = 0; ocb < oc_blocks_step; ocb++) {
            for (int ow = 0; ow < ow_step; ow++) {
                Vmm vmm_acc = get_vmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ocb * ow_step + ow);

                uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
            }
        }
    }
}
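
// Walks the input channels in ic_block-sized steps, applying the filter to
// each full block and finishing with a tail pass for the ic remainder.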
template <cpu_isa_t isa>
void jit_uni_def_conv_fwd_kernel_f32<isa>::ic_loop(int ow_step, int oc_blocks_step, int oc_step) {
    Label ic_main_loop;
    Label ic_tail;

    mov(aux2_reg_kernel, aux_reg_kernel);
    mov(aux2_reg_input_buffer, reg_input_buffer);

    mov(reg_ic_iter, jcp.ic);

    init_accums(ow_step, oc_blocks_step, oc_step);

    L(ic_main_loop); {
        cmp(reg_ic_iter, jcp.ic_block);
        jl(ic_tail, T_NEAR);

        apply_filter(ow_step, oc_blocks_step, oc_step, jcp.ic_block);

        add(aux2_reg_input_buffer, jcp.ic_block * jcp.typesize_in);
        add(aux2_reg_kernel, jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.typesize_in);
        sub(reg_ic_iter, jcp.ic_block);
        jmp(ic_main_loop, T_NEAR);
    }

    L(ic_tail);
    if (jcp.ic % jcp.ic_block != 0) {
        apply_filter(ow_step, oc_blocks_step, oc_step, jcp.ic % jcp.ic_block);
    }
}
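
// Core of the deformable convolution: for each kernel position the learned
// offsets are added to the regular sampling grid, the four neighbouring input
// pixels are fetched, and their bilinear blend is written to the per-thread
// input buffer that apply_filter() later consumes. Samples that fall outside
// the image are written as zeros.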
template <cpu_isa_t isa>
void jit_uni_def_conv_fwd_kernel_f32<isa>::interpolate_input(int ow_step) {
    Label dg_loop;
    Label dg_loop_end;

    mov(reg_table, l_table);
    mov(aux_reg_def_off, reg_def_off);
    mov(aux_reg_input, reg_input);
    mov(aux2_reg_input_buffer, aux_reg_input_buffer);
    xor_(reg_dg_iter, reg_dg_iter);

    const int ic_per_def_group = jcp.ic / jcp.dg;
    L(dg_loop); {
        cmp(reg_dg_iter, jcp.dg);
        jge(dg_loop_end, T_NEAR);

        for (int ow = 0; ow < ow_step; ow++) {
            for (int kh = 0; kh < jcp.kh; kh++) {
                for (int kw = 0; kw < jcp.kw; kw++) {
                    Label init_with_zeros;
                    Label ic_loop_main;
                    Label ic_loop_tail;
                    Label ic_loop_zeros;
                    Label loop_end;
                    Label h_sec_opt;
                    Label h_sec_opt_exit;
                    Label w_sec_opt;
                    Label w_sec_opt_exit;

                    mov(aux2_reg_input, aux_reg_input);
                    add(aux2_reg_input, (ow * jcp.stride_w * jcp.ic) * jcp.typesize_in);

                    mov(aux3_reg_input_buffer, aux2_reg_input_buffer);
                    add(aux3_reg_input_buffer, (ow * jcp.kh * jcp.kw * jcp.ic) * jcp.typesize_in);

                    Xmm xmm_tmp = Xmm(0);

                    Xmm xmm_map_h = Xmm(2);
                    Xmm xmm_ih_in = Xmm(4);
                    Xmm xmm_ih_im = Xmm(1);
                    Xmm xmm_cur_height = xmm_ih_im;
                    Xmm xmm_h_low = xmm_ih_in;
                    Xmm xmm_h_high = xmm_cur_height;
                    Xmm xmm_lh = xmm_map_h;
                    Xmm xmm_hh = Xmm(3);

                    Xmm xmm_map_w = Xmm(6);
                    Xmm xmm_iw_in = Xmm(8);
                    Xmm xmm_iw_im = Xmm(5);
                    Xmm xmm_cur_width = xmm_iw_im;
                    Xmm xmm_w_low = xmm_iw_in;
                    Xmm xmm_w_high = xmm_cur_width;
                    Xmm xmm_lw = xmm_map_w;
                    Xmm xmm_hw = Xmm(7);

                    Xmm xmm_v1_off = Xmm(9);
                    Xmm xmm_v2_off = Xmm(10);
                    Xmm xmm_v3_off = Xmm(11);
                    Xmm xmm_v4_off = Xmm(12);

                    Xmm xmm_w1 = xmm_h_low;
                    Xmm xmm_w2 = xmm_h_high;
                    Xmm xmm_w3 = xmm_w_low;
                    Xmm xmm_w4 = xmm_w_high;

                    Xmm xmm_v1 = xmm_lh;
                    Xmm xmm_v2 = xmm_hh;
                    Xmm xmm_v3 = xmm_lw;
                    Xmm xmm_v4 = xmm_hw;

                    Vmm vmm_w1 = Vmm(xmm_h_low.getIdx());
                    Vmm vmm_w2 = Vmm(xmm_h_high.getIdx());
                    Vmm vmm_w3 = Vmm(xmm_w_low.getIdx());
                    Vmm vmm_w4 = Vmm(xmm_w_high.getIdx());

                    Vmm vmm_v1 = Vmm(xmm_lh.getIdx());
                    Vmm vmm_v2 = Vmm(xmm_hh.getIdx());
                    Vmm vmm_v3 = Vmm(xmm_lw.getIdx());
                    Vmm vmm_v4 = Vmm(xmm_hw.getIdx());

                    // Sampling coordinate along h: learned offset plus the
                    // dilated kernel position.
                    size_t def_off_h = ((2 * (kh * jcp.kw + kw) + 0) * jcp.oh * jcp.ow) + ow;
                    mov(reg_tmp_32, ptr[aux_reg_def_off + def_off_h * jcp.typesize_off]);
                    movq(xmm_tmp, reg_tmp_64);
                    mov(reg_tmp_32, float2int((float) (kh * (jcp.dilate_h + 1))));
                    movq(xmm_map_h, reg_tmp_64);
                    addss(xmm_map_h, xmm_tmp);

                    mov(reg_tmp_32, jcp.stride_h);
                    imul(reg_tmp_32, reg_oh_pos);
                    sub(reg_tmp_32, jcp.t_pad);
                    movq(xmm_ih_in, reg_tmp_64);

                    cvtsi2ss(xmm_ih_im, reg_tmp_32);
                    addss(xmm_ih_im, xmm_map_h);

                    // Out-of-image samples along h contribute zeros.
                    movss(xmm_tmp, xmm_ih_im);
                    cmpss(xmm_tmp, table_val(0), 1);
                    movq(reg_tmp_64, xmm_tmp);
                    cmp(reg_tmp_64, 0);
                    jne(init_with_zeros, T_NEAR);

                    cmpss(xmm_ih_im, table_val(1), 1);
                    movq(reg_tmp_64, xmm_ih_im);
                    cmp(reg_tmp_64, 0);
                    je(init_with_zeros, T_NEAR);

                    // Sampling coordinate along w.
                    size_t def_off_w = ((2 * (kh * jcp.kw + kw) + 1) * jcp.oh * jcp.ow) + ow;
                    mov(reg_tmp_32, ptr[aux_reg_def_off + def_off_w * jcp.typesize_off]);
                    movq(xmm_tmp, reg_tmp_64);
                    mov(reg_tmp_32, float2int((float) (kw * (jcp.dilate_w + 1))));
                    movq(xmm_map_w, reg_tmp_64);
                    addss(xmm_map_w, xmm_tmp);

                    mov(reg_tmp_32, jcp.stride_w);
                    imul(reg_tmp_32, reg_ow_pos);
                    sub(reg_tmp_32, jcp.l_pad - ow * jcp.stride_w);
                    movq(xmm_iw_in, reg_tmp_64);

                    cvtsi2ss(xmm_iw_im, reg_tmp_32);
                    addss(xmm_iw_im, xmm_map_w);

                    // Out-of-image samples along w contribute zeros.
                    movss(xmm_tmp, xmm_iw_im);
                    cmpss(xmm_tmp, table_val(0), 1);
                    movq(reg_tmp_64, xmm_tmp);
                    cmp(reg_tmp_64, 0);
                    jne(init_with_zeros, T_NEAR);

                    cmpss(xmm_iw_im, table_val(2), 1);
                    movq(reg_tmp_64, xmm_iw_im);
                    cmp(reg_tmp_64, 0);
                    je(init_with_zeros, T_NEAR);

                    // h_low = floor(map_h); clamp h_low/h_high to the last row.
                    movd(xmm_cur_height, table_val(3));
                    psubd(xmm_cur_height, xmm_ih_in);

                    roundps(xmm_h_low, xmm_map_h, 1);
                    cvtps2dq(xmm_h_low, xmm_h_low);

                    movups(xmm_tmp, xmm_cur_height);
                    pcmpgtd(xmm_tmp, xmm_h_low);

                    movq(reg_tmp_64, xmm_tmp);
                    cmp(reg_tmp_64, 0);
                    jne(h_sec_opt, T_NEAR);

                    movups(xmm_h_low, xmm_cur_height);
                    movups(xmm_h_high, xmm_h_low);
                    jmp(h_sec_opt_exit, T_NEAR);

                    L(h_sec_opt);
                    movups(xmm_h_high, xmm_h_low);
                    paddd(xmm_h_high, table_val(5));

                    L(h_sec_opt_exit);

                    // Bilinear weights along h: lh = map_h - h_low, hh = 1 - lh.
                    cvtdq2ps(xmm_tmp, xmm_h_low);
                    subss(xmm_lh, xmm_tmp);
                    movss(xmm_hh, table_val(5));
                    cvtdq2ps(xmm_hh, xmm_hh);
                    subss(xmm_hh, xmm_lh);

                    // w_low = floor(map_w); clamp w_low/w_high to the last column.
                    movd(xmm_cur_width, table_val(4));
                    psubd(xmm_cur_width, xmm_iw_in);

                    roundps(xmm_w_low, xmm_map_w, 1);
                    cvtps2dq(xmm_w_low, xmm_w_low);

                    movups(xmm_tmp, xmm_cur_width);
                    pcmpgtd(xmm_tmp, xmm_w_low);

                    movq(reg_tmp_64, xmm_tmp);
                    cmp(reg_tmp_64, 0);
                    jne(w_sec_opt, T_NEAR);

                    movups(xmm_w_low, xmm_cur_width);
                    movups(xmm_w_high, xmm_w_low);
                    jmp(w_sec_opt_exit, T_NEAR);

                    L(w_sec_opt);
                    movups(xmm_w_high, xmm_w_low);
                    paddd(xmm_w_high, table_val(5));

                    L(w_sec_opt_exit);

                    // Bilinear weights along w: lw = map_w - w_low, hw = 1 - lw.
                    cvtdq2ps(xmm_tmp, xmm_w_low);
                    subss(xmm_lw, xmm_tmp);
                    movss(xmm_hw, table_val(5));
                    cvtdq2ps(xmm_hw, xmm_hw);
                    subss(xmm_hw, xmm_lw);

                    // Pixel offsets of the four neighbours: v = h * iw + w.
                    movups(xmm_v1_off, table_val(2));
                    cvtps2dq(xmm_v1_off, xmm_v1_off);
                    movups(xmm_v3_off, xmm_v1_off);

                    pmulld(xmm_v1_off, xmm_h_low);
                    movups(xmm_v2_off, xmm_v1_off);
                    paddd(xmm_v1_off, xmm_w_low);
                    paddd(xmm_v2_off, xmm_w_high);

                    pmulld(xmm_v3_off, xmm_h_high);
                    movups(xmm_v4_off, xmm_v3_off);
                    paddd(xmm_v3_off, xmm_w_low);
                    paddd(xmm_v4_off, xmm_w_high);

                    // Bilinear weights: w1 = hh*hw, w2 = hh*lw, w3 = lh*hw, w4 = lh*lw.
                    movss(xmm_w1, xmm_hh);
                    mulss(xmm_w1, xmm_hw);
                    uni_vbroadcastss(vmm_w1, xmm_w1);

                    movss(xmm_w2, xmm_hh);
                    mulss(xmm_w2, xmm_lw);
                    uni_vbroadcastss(vmm_w2, xmm_w2);

                    movss(xmm_w3, xmm_lh);
                    mulss(xmm_w3, xmm_hw);
                    uni_vbroadcastss(vmm_w3, xmm_w3);

                    movss(xmm_w4, xmm_lh);
                    mulss(xmm_w4, xmm_lw);
                    uni_vbroadcastss(vmm_w4, xmm_w4);

                    int simd_w = vlen / jcp.typesize_in;
                    mov(reg_ic_iter, ic_per_def_group);

                    // Vectorized blend over the channels of the deformable group.
                    L(ic_loop_main); {
                        cmp(reg_ic_iter, simd_w);
                        jl(ic_loop_tail, T_NEAR);

                        size_t input_buffer_off = (size_t) kh * jcp.kw * jcp.ic + kw * jcp.ic;

                        pmovsxdq(xmm_v1_off, xmm_v1_off);
                        movq(reg_tmp_64, xmm_v1_off);
                        imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
                        add(reg_tmp_64, aux2_reg_input);
                        uni_vmovups(vmm_v1, ptr[reg_tmp_64]);
                        uni_vmulps(vmm_v1, vmm_v1, vmm_w1);

                        pmovsxdq(xmm_v2_off, xmm_v2_off);
                        movq(reg_tmp_64, xmm_v2_off);
                        imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
                        add(reg_tmp_64, aux2_reg_input);
                        uni_vmovups(vmm_v2, ptr[reg_tmp_64]);
                        uni_vmulps(vmm_v2, vmm_v2, vmm_w2);

                        pmovsxdq(xmm_v3_off, xmm_v3_off);
                        movq(reg_tmp_64, xmm_v3_off);
                        imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
                        add(reg_tmp_64, aux2_reg_input);
                        uni_vmovups(vmm_v3, ptr[reg_tmp_64]);
                        uni_vmulps(vmm_v3, vmm_v3, vmm_w3);

                        pmovsxdq(xmm_v4_off, xmm_v4_off);
                        movq(reg_tmp_64, xmm_v4_off);
                        imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
                        add(reg_tmp_64, aux2_reg_input);
                        uni_vmovups(vmm_v4, ptr[reg_tmp_64]);
                        uni_vmulps(vmm_v4, vmm_v4, vmm_w4);

                        uni_vaddps(vmm_v1, vmm_v1, vmm_v2);
                        uni_vaddps(vmm_v1, vmm_v1, vmm_v3);
                        uni_vaddps(vmm_v1, vmm_v1, vmm_v4);
                        uni_vmovups(ptr[aux3_reg_input_buffer + input_buffer_off * jcp.typesize_in], vmm_v1);

                        add(aux2_reg_input, simd_w * jcp.typesize_in);
                        add(aux3_reg_input_buffer, simd_w * jcp.typesize_in);
                        sub(reg_ic_iter, simd_w);
                        jmp(ic_loop_main, T_NEAR);
                    }

                    // Scalar blend for the channel remainder.
                    L(ic_loop_tail); {
                        cmp(reg_ic_iter, 1);
                        jl(loop_end, T_NEAR);

                        size_t input_buffer_off = (size_t) kh * jcp.kw * jcp.ic + kw * jcp.ic;

                        pmovsxdq(xmm_v1_off, xmm_v1_off);
                        movq(reg_tmp_64, xmm_v1_off);
                        imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
                        add(reg_tmp_64, aux2_reg_input);
                        movss(xmm_v1, ptr[reg_tmp_64]);
                        mulss(xmm_v1, xmm_w1);

                        pmovsxdq(xmm_v2_off, xmm_v2_off);
                        movq(reg_tmp_64, xmm_v2_off);
                        imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
                        add(reg_tmp_64, aux2_reg_input);
                        movss(xmm_v2, ptr[reg_tmp_64]);
                        mulss(xmm_v2, xmm_w2);

                        pmovsxdq(xmm_v3_off, xmm_v3_off);
                        movq(reg_tmp_64, xmm_v3_off);
                        imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
                        add(reg_tmp_64, aux2_reg_input);
                        movss(xmm_v3, ptr[reg_tmp_64]);
                        mulss(xmm_v3, xmm_w3);

                        pmovsxdq(xmm_v4_off, xmm_v4_off);
                        movq(reg_tmp_64, xmm_v4_off);
                        imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
                        add(reg_tmp_64, aux2_reg_input);
                        movss(xmm_v4, ptr[reg_tmp_64]);
                        mulss(xmm_v4, xmm_w4);

                        addss(xmm_v1, xmm_v2);
                        addss(xmm_v1, xmm_v3);
                        addss(xmm_v1, xmm_v4);
                        movss(ptr[aux3_reg_input_buffer + input_buffer_off * jcp.typesize_in], xmm_v1);

                        add(aux2_reg_input, jcp.typesize_in);
                        add(aux3_reg_input_buffer, jcp.typesize_in);
                        sub(reg_ic_iter, 1);
                        jmp(ic_loop_tail, T_NEAR);
                    }

                    jmp(loop_end, T_NEAR);

                    L(init_with_zeros);
                    mov(reg_ic_iter, 0);
                    L(ic_loop_zeros); {
                        cmp(reg_ic_iter, ic_per_def_group);
                        je(loop_end, T_NEAR);

                        size_t input_buffer_off = (size_t) kh * jcp.kw * jcp.ic + kw * jcp.ic;

                        pxor(xmm_tmp, xmm_tmp);
                        movss(ptr[aux3_reg_input_buffer + input_buffer_off * jcp.typesize_in], xmm_tmp);
                        add(aux3_reg_input_buffer, jcp.typesize_in);
                        add(reg_ic_iter, 1);
                        jmp(ic_loop_zeros, T_NEAR);
                    }

                    L(loop_end);
                }
            }
        }

        add(aux_reg_def_off, 2 * jcp.kh * jcp.kw * jcp.oh * jcp.ow * jcp.typesize_off);
        add(aux_reg_input, ic_per_def_group * jcp.typesize_in);
        add(aux2_reg_input_buffer, ic_per_def_group * jcp.typesize_in);
        add(reg_dg_iter, 1);
        jmp(dg_loop, T_NEAR);
    }

    L(dg_loop_end);
}
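
// Adds the bias (when present) to the accumulators and writes them to the
// destination: a masked store on avx512, scalar stores for ragged tails on
// sse42/avx2, and plain vector stores otherwise.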
template <cpu_isa_t isa>
void jit_uni_def_conv_fwd_kernel_f32<isa>::store_output(int ow_step, int oc_blocks_step, int oc_step) {
    int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;

    if (jcp.with_bias) {
        for (int r = 0; r < repeats; r++) {
            for (int ocb = 0; ocb < oc_blocks_step; ocb++) {
                size_t bias_off = (size_t) ocb * jcp.oc_block + r * jcp.oc_block / 2;
                uni_vmovups(Vmm(0), ptr[aux_reg_bias + bias_off * jcp.typesize_bia]);

                for (int ow = 0; ow < ow_step; ow++) {
                    Vmm vmm_acc = get_vmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ocb * ow_step + ow);

                    uni_vaddps(vmm_acc, vmm_acc, Vmm(0));
                }
            }
        }
    }

    if (isa == avx512_common && oc_step != jcp.oc_block) {
        int mask = (1 << oc_step) - 1;
        mov(reg_tmp_32, mask);
        kmovw(ktail_mask, reg_tmp_32);
    }

    for (int r = 0; r < repeats; r++) {
        int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
        bool is_scalar_store = isa == sse42 ? tail_size < jcp.oc_block / 2 : tail_size < jcp.oc_block;
        if (is_scalar_store) {
            for (int ow = 0; ow < ow_step; ow++) {
                Vmm vmm_dst = get_vmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ow);
                Xmm xmm_dst = get_xmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ow);

                if (isa == avx512_common) {
                    size_t out_off = (size_t) ow * jcp.oc;

                    uni_vmovups(ptr[aux_reg_output + out_off * jcp.typesize_out], vmm_dst | ktail_mask);
                } else {
                    // Store one element at a time, rotating the register so the
                    // next element lands in the low lane.
                    for (int oc = 0; oc < tail_size; oc++) {
                        size_t out_off = (size_t) ow * jcp.oc + oc + r * (jcp.oc_block / 2);

                        movq(reg_tmp_64, xmm_dst);
                        mov(ptr[aux_reg_output + out_off * jcp.typesize_out], reg_tmp_32);

                        if (isa == sse42) {
                            psrldq(vmm_dst, jcp.typesize_out);
                        } else {
                            Ymm ymm_dst = get_ymm_acc(ow);
                            Vmm vmm_tmp = Vmm(0);
                            Ymm ymm_tmp = Ymm(0);

                            vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
                            vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
                        }
                    }
                }
            }
        } else {
            for (int ocb = 0; ocb < oc_blocks_step; ocb++) {
                for (int ow = 0; ow < ow_step; ow++) {
                    Vmm vmm_acc = get_vmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ocb * ow_step + ow);
                    size_t out_off = (size_t) ow * jcp.oc + ocb * jcp.oc_block + r * (jcp.oc_block / 2);

                    uni_vmovups(ptr[aux_reg_output + out_off * jcp.typesize_out], vmm_acc);
                }
            }
        }
    }
}
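
// Drives one group of ow_step output pixels: interpolates the input once,
// then covers the output channels with an unrolled nb_oc_blocking loop, a
// single-block loop, and an oc remainder tail.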
template <cpu_isa_t isa>
void jit_uni_def_conv_fwd_kernel_f32<isa>::oc_loop(int ow_step) {
    Label oc_unrolled_loop;
    Label oc_main_loop;
    Label oc_tail;

    mov(aux_reg_input_buffer, reg_input_buffer);

    interpolate_input(ow_step);

    mov(aux_reg_kernel, reg_kernel);
    mov(aux_reg_output, reg_output);
    mov(aux_reg_bias, reg_bias);

    mov(reg_oc_work, jcp.oc);

    L(oc_unrolled_loop); {
        cmp(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
        jl(oc_main_loop, T_NEAR);

        ic_loop(ow_step, jcp.nb_oc_blocking, jcp.oc_block);
        store_output(ow_step, jcp.nb_oc_blocking, jcp.oc_block);

        add(aux_reg_kernel, jcp.nb_oc_blocking * jcp.nb_ic * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.typesize_in);
        add(aux_reg_output, jcp.nb_oc_blocking * jcp.oc_block * jcp.typesize_out);
        add(aux_reg_bias, jcp.nb_oc_blocking * jcp.oc_block * jcp.typesize_bia);
        sub(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);

        jmp(oc_unrolled_loop, T_NEAR);
    }

    L(oc_main_loop); {
        cmp(reg_oc_work, jcp.oc_block);
        jl(oc_tail, T_NEAR);

        ic_loop(ow_step, 1, jcp.oc_block);
        store_output(ow_step, 1, jcp.oc_block);

        add(aux_reg_kernel, jcp.nb_ic * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.typesize_in);
        add(aux_reg_output, jcp.oc_block * jcp.typesize_out);
        add(aux_reg_bias, jcp.oc_block * jcp.typesize_bia);
        sub(reg_oc_work, jcp.oc_block);

        jmp(oc_main_loop, T_NEAR);
    }

    L(oc_tail);
    if (jcp.oc % jcp.oc_block != 0) {
        ic_loop(ow_step, 1, jcp.oc % jcp.oc_block);
        store_output(ow_step, 1, jcp.oc % jcp.oc_block);
    }
}
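
// Splits the output row into ur_w-wide chunks handled by oc_loop(), plus an
// ow remainder tail.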
template <cpu_isa_t isa>
void jit_uni_def_conv_fwd_kernel_f32<isa>::ow_loop() {
    Label ow_loop_main;
    Label ow_tail;

    xor_(reg_ow_pos, reg_ow_pos);

    L(ow_loop_main); {
        cmp(reg_ow_pos, jcp.ow - jcp.ur_w);
        jg(ow_tail, T_NEAR);

        oc_loop(jcp.ur_w);

        add(reg_input, jcp.ur_w * jcp.stride_w * jcp.ic * jcp.typesize_in);
        add(reg_def_off, jcp.ur_w * jcp.typesize_off);
        add(reg_output, jcp.ur_w * jcp.oc * jcp.typesize_out);

        add(reg_ow_pos, jcp.ur_w);
        jmp(ow_loop_main, T_NEAR);
    }

    L(ow_tail);
    if (jcp.ow % jcp.ur_w != 0)
        oc_loop(jcp.ow % jcp.ur_w);
}
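
// Kernel entry point: loads the call arguments into registers, emits the
// output-width loop, and appends the constants table after the code.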
template <cpu_isa_t isa>
void jit_uni_def_conv_fwd_kernel_f32<isa>::generate() {
    this->preamble();

    mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_def_off, ptr[this->param1 + GET_OFF(off)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_input_buffer, ptr[this->param1 + GET_OFF(buf)]);
    mov(reg_oh_pos, ptr[param1 + GET_OFF(oh_pos)]);

    ow_loop();

    this->postamble();

    prepare_table();
}
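
// Broadcast constants used by interpolate_input(): 0.0f, (float)ih and
// (float)iw for the bounds checks, ih - 1 and iw - 1 for clamping, and the
// integer 1 used for h_high/w_high and the 1.0f weight.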
template <cpu_isa_t isa>
void jit_uni_def_conv_fwd_kernel_f32<isa>::prepare_table() {
    align(64);
    L(l_table);
    for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
        dd(float2int(0.f));
    }

    for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
        dd(float2int((float)jcp.ih));
    }

    for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
        dd(float2int((float)jcp.iw));
    }

    for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
        dd(jcp.ih - 1);
    }

    for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
        dd(jcp.iw - 1);
    }

    for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
        dd(1);
    }
}
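
// The kernel fuses no post-ops; any attached post-op rejects this
// implementation.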
template <cpu_isa_t isa>
bool jit_uni_def_conv_fwd_kernel_f32<isa>::post_ops_ok(jit_def_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    return p.len_ == 0;
}
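
// Fills jcp from the operation descriptors, fixes the expected memory formats
// (nhwc activations, nchw offsets, blocked weights), and rejects cases the
// kernel does not handle (groups, ic not divisible by the deformable group
// count, unsupported post-ops).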
708 template <cpu_isa_t isa>
709 status_t jit_uni_def_conv_fwd_kernel_f32<isa>::init_conf(jit_def_conv_conf_t &jcp,
710 const deformable_convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
711 cpu_memory_t::pd_t &offsets_pd, cpu_memory_t::pd_t &weights_pd,
712 cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
713 const primitive_attr_t &attr)
715 if (!mayiuse(isa)) return status::unimplemented;
717 const memory_desc_wrapper src_d(&src_pd);
718 const memory_desc_wrapper offsets_d(&offsets_pd);
719 const memory_desc_wrapper weights_d(&weights_pd);
720 const memory_desc_wrapper dst_d(&dst_pd);
721 const memory_desc_wrapper bias_d(&bias_pd);
723 jcp.prop_kind = cd.prop_kind;
725 jcp.dg = cd.deformable_group;
727 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
728 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
729 jcp.mb = src_d.dims()[0];
731 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
732 jcp.ic = src_d.dims()[1] / jcp.ngroups;
734 jcp.ih = src_d.dims()[2];
735 jcp.iw = src_d.dims()[3];
736 jcp.oh = dst_d.dims()[2];
737 jcp.ow = dst_d.dims()[3];
739 jcp.kh = weights_d.dims()[with_groups + 2];
740 jcp.kw = weights_d.dims()[with_groups + 3];
742 jcp.t_pad = cd.padding[0][0];
743 jcp.l_pad = cd.padding[0][1];
745 jcp.stride_h = cd.strides[0];
746 jcp.stride_w = cd.strides[1];
748 jcp.dilate_h = cd.dilates[0];
749 jcp.dilate_w = cd.dilates[1];
751 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
753 const int simd_w = isa == avx512_common ? 16 : 8;
754 jcp.ic_block = simd_w;
755 jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
757 jcp.oc_block = simd_w;
758 jcp.oc_padded = rnd_up(jcp.oc, jcp.oc_block);
759 jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
761 if (jcp.ngroups != 1)
762 return status::unimplemented;
764 if (jcp.ic % jcp.dg != 0)
765 return status::unimplemented;
767 if (!post_ops_ok(jcp, attr))
768 return status::unimplemented;
770 auto desired_act_fmt = nhwc;
771 auto desired_off_fmt = nchw;
772 auto desired_wei_fmt = with_groups ? isa == avx512_common ? gOIhw16i16o : gOIhw8i8o
773 : isa == avx512_common ? OIhw16i16o : OIhw8i8o;
775 if (src_d.format() == any)
776 CHECK(src_pd.set_format(desired_act_fmt));
777 if (src_d.format() != desired_act_fmt)
778 return status::unimplemented;
780 if (offsets_d.format() == any)
781 CHECK(offsets_pd.set_format(desired_off_fmt));
782 if (offsets_d.format() != desired_off_fmt)
783 return status::unimplemented;
785 if (weights_d.format() == any)
786 CHECK(weights_pd.set_format(desired_wei_fmt));
787 if (weights_d.format() != desired_wei_fmt)
788 return status::unimplemented;
791 if (bias_d.format() == any)
792 CHECK(bias_pd.set_format(x));
793 if (bias_d.format() != x)
794 return status::unimplemented;
797 if (dst_d.format() == any)
798 CHECK(dst_pd.set_format(desired_act_fmt));
799 if (dst_d.format() != desired_act_fmt)
800 return status::unimplemented;
802 jcp.src_dt = cd.src_descs[0].data_type;
803 jcp.off_dt = cd.src_descs[1].data_type;
804 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
805 jcp.dst_dt = cd.dst_desc.data_type;
807 jcp.typesize_in = (int)types::data_type_size(jcp.src_dt);
808 jcp.typesize_off = (int)types::data_type_size(jcp.off_dt);
809 jcp.typesize_out = (int)types::data_type_size(jcp.dst_dt);
810 jcp.typesize_bia = (int)(jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0);
812 jcp.ur_w = isa == avx512_common ? 6 : 3;
813 jcp.nb_oc_blocking = isa == sse42 ? 2 : 4;
815 jcp.nthr = mkldnn_get_max_threads();
817 return status::success;
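
// One interpolation buffer per thread, plus a padded bias buffer when oc is
// not a multiple of oc_block.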
template <cpu_isa_t isa>
void jit_uni_def_conv_fwd_kernel_f32<isa>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_def_conv_conf_t &jcp, const primitive_attr_t &attr) {
    scratchpad.book(key_def_conv_buffer, (size_t)jcp.nthr * jcp.ur_w * jcp.kh * jcp.kw * jcp.ic * jcp.typesize_in);

    if (jcp.oc != jcp.oc_padded) {
        scratchpad.book(key_conv_padded_bias, (size_t)jcp.typesize_bia * jcp.oc_padded);
    }
}
template struct jit_uni_def_conv_fwd_kernel_f32<avx512_common>;
template struct jit_uni_def_conv_fwd_kernel_f32<avx2>;
template struct jit_uni_def_conv_fwd_kernel_f32<sse42>;

}
}
}