updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_def_conv_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <common/memory_tracking.hpp>
18 #include "c_types_map.hpp"
19 #include "nstl.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22 #include "cpu_memory.hpp"
23
24 #include "jit_uni_def_conv_kernel_f32.hpp"
25
26 #define GET_OFF(field) offsetof(jit_def_conv_call_s, field)
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace mkldnn::impl::prop_kind;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::memory_tracking::names;
35 using namespace mkldnn::impl::utils;
36
37 using namespace Xbyak;
38
39 template <cpu_isa_t isa>
40 void jit_uni_def_conv_fwd_kernel_f32<isa>::apply_filter(int ow_step, int oc_blocks_step, int oc_step, int ic_step) {
41     int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;
42
43     for (int kh = 0; kh < jcp.kh; kh++) {
44         for (int kw = 0; kw < jcp.kw; kw++) {
45             for (int ic = 0; ic < ic_step; ic++) {
46                 for (int ow = 0; ow < ow_step; ow++) {
47                     Vmm vmm_src = get_vmm_src(ow);
48                     size_t inp_off = (size_t) ow * jcp.kh * jcp.kw * jcp.ic + kh * jcp.kw * jcp.ic + kw * jcp.ic + ic;
49
50                     uni_vbroadcastss(vmm_src, ptr[aux2_reg_input_buffer + inp_off * jcp.typesize_in]);
51                 }
52
53                 for (int r = 0; r < repeats; r++) {
54                     for (int ocb = 0; ocb < oc_blocks_step; ocb++) {
55                         Vmm vmm_ker = get_vmm_ker(0);
56                         size_t ker_off = (size_t) ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block +
57                                          kh * jcp.kw * jcp.ic_block * jcp.oc_block +
58                                          kw * jcp.ic_block * jcp.oc_block +
59                                          ic * jcp.oc_block + r * jcp.oc_block / 2;
60
61                         uni_vmovups(vmm_ker, ptr[aux2_reg_kernel + ker_off * jcp.typesize_in]);
62                         for (int ow = 0; ow < ow_step; ow++) {
63                             Vmm vmm_src = get_vmm_src(ow);
64                             Vmm vmm_acc = get_vmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ocb * ow_step + ow);
65
66                             if (isa == sse42 && ow > 0) {
67                                 uni_vmovups(vmm_ker, ptr[aux2_reg_kernel + ker_off * jcp.typesize_in]);
68                             }
69
70                             uni_vfmadd231ps(vmm_acc, vmm_ker, vmm_src);
71                         }
72                     }
73                 }
74             }
75         }
76     }
77 }
78
79 template <cpu_isa_t isa>
80 void jit_uni_def_conv_fwd_kernel_f32<isa>::init_accums(int ow_step, int oc_blocks_step, int oc_step) {
81     int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;
82     for (int r = 0; r < repeats; r++) {
83         for (int ocb = 0; ocb < oc_blocks_step; ocb++) {
84             for (int ow = 0; ow < ow_step; ow++) {
85                 Vmm vmm_acc = get_vmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ocb * ow_step + ow);
86
87                 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
88             }
89         }
90     }
91 }
92
93 template <cpu_isa_t isa>
94 void jit_uni_def_conv_fwd_kernel_f32<isa>::ic_loop(int ow_step, int oc_blocks_step, int oc_step) {
95     Label ic_main_loop;
96     Label ic_tail;
97     Label exit;
98
99     push(reg_oc_work);
100     push(aux_reg_bias);
101
102     mov(aux2_reg_kernel, aux_reg_kernel);
103     mov(aux2_reg_input_buffer, reg_input_buffer);
104
105     mov(reg_ic_iter, jcp.ic);
106
107     init_accums(ow_step, oc_blocks_step, oc_step);
108
109     L(ic_main_loop); {
110         cmp(reg_ic_iter, jcp.ic_block);
111         jl(ic_tail, T_NEAR);
112
113         apply_filter(ow_step, oc_blocks_step, oc_step, jcp.ic_block);
114
115         add(aux2_reg_input_buffer, jcp.ic_block * jcp.typesize_in);
116         add(aux2_reg_kernel, jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.typesize_in);
117         sub(reg_ic_iter, jcp.ic_block);
118         jmp(ic_main_loop, T_NEAR);
119     }
120
121     L(ic_tail); {
122         if (jcp.ic % jcp.ic_block != 0) {
123             apply_filter(ow_step, oc_blocks_step, oc_step, jcp.ic % jcp.ic_block);
124         }
125     }
126
127     pop(aux_reg_bias);
128     pop(reg_oc_work);
129 }
130
131 template <cpu_isa_t isa>
132 void jit_uni_def_conv_fwd_kernel_f32<isa>::interpolate_input(int ow_step) {
133     Label dg_loop;
134     Label dg_loop_end;
135
136     mov(reg_table, l_table);
137     mov(aux_reg_def_off, reg_def_off);
138     mov(aux_reg_input, reg_input);
139     mov(aux2_reg_input_buffer, aux_reg_input_buffer);
140     xor_(reg_dg_iter, reg_dg_iter);
141
142     const int ic_per_def_group = jcp.ic / jcp.dg;
143     L(dg_loop); {
144         cmp(reg_dg_iter, jcp.dg);
145         jge(dg_loop_end, T_NEAR);
146
147         for (int ow = 0; ow < ow_step; ow++) {
148             for (int kh = 0; kh < jcp.kh; kh++) {
149                 for (int kw = 0; kw < jcp.kw; kw++) {
150                     Label init_with_zeros;
151                     Label ic_loop_main;
152                     Label ic_loop_tail;
153                     Label ic_loop_zeros;
154                     Label loop_end;
155                     Label h_sec_opt;
156                     Label h_sec_opt_exit;
157                     Label w_sec_opt;
158                     Label w_sec_opt_exit;
159
160                     mov(aux2_reg_input, aux_reg_input);
161                     add(aux2_reg_input, (ow * jcp.stride_w * jcp.ic) * jcp.typesize_in);
162
163                     mov(aux3_reg_input_buffer, aux2_reg_input_buffer);
164                     add(aux3_reg_input_buffer, (ow * jcp.kh * jcp.kw * jcp.ic) * jcp.typesize_in);
165
166                     Xmm xmm_tmp = Xmm(0);
167
168                     Xmm xmm_map_h = Xmm(2);
169                     Xmm xmm_ih_in = Xmm(4);
170                     Xmm xmm_ih_im = Xmm(1);
171                     Xmm xmm_cur_height = xmm_ih_im;
172                     Xmm xmm_h_low = xmm_ih_in;
173                     Xmm xmm_h_high = xmm_cur_height;
174                     Xmm xmm_lh = xmm_map_h;
175                     Xmm xmm_hh = Xmm(3);
176
177                     Xmm xmm_map_w = Xmm(6);
178                     Xmm xmm_iw_in = Xmm(8);
179                     Xmm xmm_iw_im = Xmm(5);
180                     Xmm xmm_cur_width = xmm_iw_im;
181                     Xmm xmm_w_low = xmm_iw_in;
182                     Xmm xmm_w_high = xmm_cur_width;
183                     Xmm xmm_lw = xmm_map_w;
184                     Xmm xmm_hw = Xmm(7);
185
186                     Xmm xmm_v1_off = Xmm(9);
187                     Xmm xmm_v2_off = Xmm(10);
188                     Xmm xmm_v3_off = Xmm(11);
189                     Xmm xmm_v4_off = Xmm(12);
190
191                     Xmm xmm_w1 = xmm_h_low;
192                     Xmm xmm_w2 = xmm_h_high;
193                     Xmm xmm_w3 = xmm_w_low;
194                     Xmm xmm_w4 = xmm_w_high;
195
196                     Xmm xmm_v1 = xmm_lh;
197                     Xmm xmm_v2 = xmm_hh;
198                     Xmm xmm_v3 = xmm_lw;
199                     Xmm xmm_v4 = xmm_hw;
200
201                     Vmm vmm_w1 = Vmm(xmm_h_low.getIdx());
202                     Vmm vmm_w2 = Vmm(xmm_h_high.getIdx());
203                     Vmm vmm_w3 = Vmm(xmm_w_low.getIdx());
204                     Vmm vmm_w4 = Vmm(xmm_w_high.getIdx());
205
206                     Vmm vmm_v1 = Vmm(xmm_lh.getIdx());
207                     Vmm vmm_v2 = Vmm(xmm_hh.getIdx());
208                     Vmm vmm_v3 = Vmm(xmm_lw.getIdx());
209                     Vmm vmm_v4 = Vmm(xmm_hw.getIdx());
210
211                     size_t def_off_h = ((2 * (kh * jcp.kw + kw) + 0) * jcp.oh * jcp.ow) + ow;
212                     mov(reg_tmp_32, ptr[aux_reg_def_off + def_off_h * jcp.typesize_off]);
213                     movq(xmm_tmp, reg_tmp_64);
214                     mov(reg_tmp_32, float2int((float) (kh * (jcp.dilate_h + 1))));
215                     movq(xmm_map_h, reg_tmp_64);
216                     addss(xmm_map_h, xmm_tmp);
217
218                     mov(reg_tmp_32, jcp.stride_h);
219                     imul(reg_tmp_32, reg_oh_pos);
220                     sub(reg_tmp_32, jcp.t_pad);
221                     movq(xmm_ih_in, reg_tmp_64);
222
223                     cvtsi2ss(xmm_ih_im, reg_tmp_32);
224                     addss(xmm_ih_im, xmm_map_h);
225
226                     movss(xmm_tmp, xmm_ih_im);
227                     cmpss(xmm_tmp, table_val(0), 1);
228                     movq(reg_tmp_64, xmm_tmp);
229                     cmp(reg_tmp_32, 0);
230                     jne(init_with_zeros, T_NEAR);
231
232                     cmpss(xmm_ih_im, table_val(1), 1);
233                     movq(reg_tmp_64, xmm_ih_im);
234                     cmp(reg_tmp_32, 0);
235                     je(init_with_zeros, T_NEAR);
236
237
238                     size_t def_off_w = ((2 * (kh * jcp.kw + kw) + 1) * jcp.oh * jcp.ow) + ow;
239                     mov(reg_tmp_32, ptr[aux_reg_def_off + def_off_w * jcp.typesize_off]);
240                     movq(xmm_tmp, reg_tmp_64);
241                     mov(reg_tmp_32, float2int((float) (kw * (jcp.dilate_w + 1))));
242                     movq(xmm_map_w, reg_tmp_64);
243                     addss(xmm_map_w, xmm_tmp);
244
245                     mov(reg_tmp_32, jcp.stride_w);
246                     imul(reg_tmp_32, reg_ow_pos);
247                     sub(reg_tmp_32, jcp.l_pad - ow * jcp.stride_w);
248                     movq(xmm_iw_in, reg_tmp_64);
249
250                     cvtsi2ss(xmm_iw_im, reg_tmp_32);
251                     addss(xmm_iw_im, xmm_map_w);
252
253                     movss(xmm_tmp, xmm_iw_im);
254                     cmpss(xmm_tmp, table_val(0), 1);
255                     movq(reg_tmp_64, xmm_tmp);
256                     cmp(reg_tmp_32, 0);
257                     jne(init_with_zeros, T_NEAR);
258
259                     cmpss(xmm_iw_im, table_val(2), 1);
260                     movq(reg_tmp_64, xmm_iw_im);
261                     cmp(reg_tmp_32, 0);
262                     je(init_with_zeros, T_NEAR);
263
264
265                     movd(xmm_cur_height, table_val(3));
266                     psubd(xmm_cur_height, xmm_ih_in);
267
268                     roundps(xmm_h_low, xmm_map_h, 1);
269                     cvtps2dq(xmm_h_low, xmm_h_low);
270
271                     movups(xmm_tmp, xmm_cur_height);
272                     pcmpgtd(xmm_tmp, xmm_h_low);
273
274                     movq(reg_tmp_64, xmm_tmp);
275                     cmp(reg_tmp_32, 0);
276                     jne(h_sec_opt, T_NEAR);
277
278                     movups(xmm_h_low, xmm_cur_height);
279                     movups(xmm_h_high, xmm_h_low);
280                     jmp(h_sec_opt_exit);
281
282                     L(h_sec_opt);
283
284                     movups(xmm_h_high, xmm_h_low);
285                     paddd(xmm_h_high, table_val(5));
286
287                     L(h_sec_opt_exit);
288
289                     cvtdq2ps(xmm_tmp, xmm_h_low);
290                     subss(xmm_lh, xmm_tmp);
291                     movss(xmm_hh, table_val(5));
292                     cvtdq2ps(xmm_hh, xmm_hh);
293                     subss(xmm_hh, xmm_lh);
294
295
296                     movd(xmm_cur_width, table_val(4));
297                     psubd(xmm_cur_width, xmm_iw_in);
298
299                     roundps(xmm_w_low, xmm_map_w, 1);
300                     cvtps2dq(xmm_w_low, xmm_w_low);
301
302                     movups(xmm_tmp, xmm_cur_width);
303                     pcmpgtd(xmm_tmp, xmm_w_low);
304
305                     movq(reg_tmp_64, xmm_tmp);
306                     cmp(reg_tmp_32, 0);
307                     jne(w_sec_opt, T_NEAR);
308
309                     movups(xmm_w_low, xmm_cur_width);
310                     movups(xmm_w_high, xmm_w_low);
311                     jmp(w_sec_opt_exit);
312
313                     L(w_sec_opt);
314
315                     movups(xmm_w_high, xmm_w_low);
316                     paddd(xmm_w_high, table_val(5));
317
318                     L(w_sec_opt_exit);
319
320                     cvtdq2ps(xmm_tmp, xmm_w_low);
321                     subss(xmm_lw, xmm_tmp);
322                     movss(xmm_hw, table_val(5));
323                     cvtdq2ps(xmm_hw, xmm_hw);
324                     subss(xmm_hw, xmm_lw);
325
326
327                     movups(xmm_v1_off, table_val(2));
328                     cvtps2dq(xmm_v1_off, xmm_v1_off);
329                     movups(xmm_v3_off, xmm_v1_off);
330
331                     pmulld(xmm_v1_off, xmm_h_low);
332                     movups(xmm_v2_off, xmm_v1_off);
333                     paddd(xmm_v1_off, xmm_w_low);
334                     paddd(xmm_v2_off, xmm_w_high);
335
336                     pmulld(xmm_v3_off, xmm_h_high);
337                     movups(xmm_v4_off, xmm_v3_off);
338                     paddd(xmm_v3_off, xmm_w_low);
339                     paddd(xmm_v4_off, xmm_w_high);
340
341
342                     movss(xmm_w1, xmm_hh);
343                     mulss(xmm_w1, xmm_hw);
344                     uni_vbroadcastss(vmm_w1, xmm_w1);
345
346                     movss(xmm_w2, xmm_hh);
347                     mulss(xmm_w2, xmm_lw);
348                     uni_vbroadcastss(vmm_w2, xmm_w2);
349
350                     movss(xmm_w3, xmm_lh);
351                     mulss(xmm_w3, xmm_hw);
352                     uni_vbroadcastss(vmm_w3, xmm_w3);
353
354                     movss(xmm_w4, xmm_lh);
355                     mulss(xmm_w4, xmm_lw);
356                     uni_vbroadcastss(vmm_w4, xmm_w4);
357
358                     int simd_w = vlen / jcp.typesize_in;
359                     mov(reg_ic_iter, ic_per_def_group);
360                     L(ic_loop_main);
361                     {
362                         cmp(reg_ic_iter, simd_w);
363                         jl(ic_loop_tail, T_NEAR);
364
365                         size_t input_buffer_off = (size_t) kh * jcp.kw * jcp.ic + kw * jcp.ic;
366
367                         pmovsxdq(xmm_v1_off, xmm_v1_off);
368                         movq(reg_tmp_64, xmm_v1_off);
369                         imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
370                         add(reg_tmp_64, aux2_reg_input);
371                         uni_vmovups(vmm_v1, ptr[reg_tmp_64]);
372                         uni_vmulps(vmm_v1, vmm_v1, vmm_w1);
373
374                         pmovsxdq(xmm_v2_off, xmm_v2_off);
375                         movq(reg_tmp_64, xmm_v2_off);
376                         imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
377                         add(reg_tmp_64, aux2_reg_input);
378                         uni_vmovups(vmm_v2, ptr[reg_tmp_64]);
379                         uni_vmulps(vmm_v2, vmm_v2, vmm_w2);
380
381                         pmovsxdq(xmm_v3_off, xmm_v3_off);
382                         movq(reg_tmp_64, xmm_v3_off);
383                         imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
384                         add(reg_tmp_64, aux2_reg_input);
385                         uni_vmovups(vmm_v3, ptr[reg_tmp_64]);
386                         uni_vmulps(vmm_v3, vmm_v3, vmm_w3);
387
388                         pmovsxdq(xmm_v4_off, xmm_v4_off);
389                         movq(reg_tmp_64, xmm_v4_off);
390                         imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
391                         add(reg_tmp_64, aux2_reg_input);
392                         uni_vmovups(vmm_v4, ptr[reg_tmp_64]);
393                         uni_vmulps(vmm_v4, vmm_v4, vmm_w4);
394
395                         uni_vaddps(vmm_v1, vmm_v1, vmm_v2);
396                         uni_vaddps(vmm_v1, vmm_v1, vmm_v3);
397                         uni_vaddps(vmm_v1, vmm_v1, vmm_v4);
398                         uni_vmovups(ptr[aux3_reg_input_buffer + input_buffer_off * jcp.typesize_in], vmm_v1);
399
400                         add(aux2_reg_input, simd_w * jcp.typesize_in);
401                         add(aux3_reg_input_buffer, simd_w * jcp.typesize_in);
402                         sub(reg_ic_iter, simd_w);
403                         jmp(ic_loop_main, T_NEAR);
404                     };
405
406                     L(ic_loop_tail);
407                     {
408                         cmp(reg_ic_iter, 1);
409                         jl(loop_end, T_NEAR);
410
411                         size_t input_buffer_off = (size_t) kh * jcp.kw * jcp.ic + kw * jcp.ic;
412
413                         pmovsxdq(xmm_v1_off, xmm_v1_off);
414                         movq(reg_tmp_64, xmm_v1_off);
415                         imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
416                         add(reg_tmp_64, aux2_reg_input);
417                         movss(xmm_v1, ptr[reg_tmp_64]);
418                         mulss(xmm_v1, xmm_w1);
419
420                         pmovsxdq(xmm_v2_off, xmm_v2_off);
421                         movq(reg_tmp_64, xmm_v2_off);
422                         imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
423                         add(reg_tmp_64, aux2_reg_input);
424                         movss(xmm_v2, ptr[reg_tmp_64]);
425                         mulss(xmm_v2, xmm_w2);
426
427                         pmovsxdq(xmm_v3_off, xmm_v3_off);
428                         movq(reg_tmp_64, xmm_v3_off);
429                         imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
430                         add(reg_tmp_64, aux2_reg_input);
431                         movss(xmm_v3, ptr[reg_tmp_64]);
432                         mulss(xmm_v3, xmm_w3);
433
434                         pmovsxdq(xmm_v4_off, xmm_v4_off);
435                         movq(reg_tmp_64, xmm_v4_off);
436                         imul(reg_tmp_64, reg_tmp_64, jcp.ic * jcp.typesize_in);
437                         add(reg_tmp_64, aux2_reg_input);
438                         movss(xmm_v4, ptr[reg_tmp_64]);
439                         mulss(xmm_v4, xmm_w4);
440
441                         addss(xmm_v1, xmm_v2);
442                         addss(xmm_v1, xmm_v3);
443                         addss(xmm_v1, xmm_v4);
444                         movss(ptr[aux3_reg_input_buffer + input_buffer_off * jcp.typesize_in], xmm_v1);
445
446                         add(aux2_reg_input, jcp.typesize_in);
447                         add(aux3_reg_input_buffer, jcp.typesize_in);
448                         sub(reg_ic_iter, 1);
449                         jmp(ic_loop_tail, T_NEAR);
450                     };
451
452                     jmp(loop_end, T_NEAR);
453
454                     L(init_with_zeros);
455
456                     mov(reg_ic_iter, 0);
457                     L(ic_loop_zeros);
458                     {
459                         cmp(reg_ic_iter, ic_per_def_group);
460                         je(loop_end, T_NEAR);
461
462                         size_t input_buffer_off = (size_t) kh * jcp.kw * jcp.ic + kw * jcp.ic;
463
464                         pxor(xmm_tmp, xmm_tmp);
465                         movss(ptr[aux3_reg_input_buffer + input_buffer_off * jcp.typesize_in], xmm_tmp);
466                         add(aux3_reg_input_buffer, jcp.typesize_in);
467                         inc(reg_ic_iter);
468                         jmp(ic_loop_zeros, T_NEAR);
469                     }
470
471                     L(loop_end);
472                 }
473             }
474         }
475
476         add(aux_reg_def_off, 2 * jcp.kh * jcp.kw * jcp.oh * jcp.ow * jcp.typesize_off);
477         add(aux_reg_input, ic_per_def_group * jcp.typesize_in);
478         add(aux2_reg_input_buffer, ic_per_def_group * jcp.typesize_in);
479         inc(reg_dg_iter);
480         jmp(dg_loop, T_NEAR);
481     }
482
483     L(dg_loop_end);
484 }
485
486 template <cpu_isa_t isa>
487 void jit_uni_def_conv_fwd_kernel_f32<isa>::store_output(int ow_step, int oc_blocks_step, int oc_step) {
488     int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;
489
490     if (jcp.with_bias) {
491         for (int r = 0; r < repeats; r++) {
492             for (int ocb = 0; ocb < oc_blocks_step; ocb++) {
493                 size_t bias_off = (size_t) ocb * jcp.oc_block + r * jcp.oc_block / 2;
494                 uni_vmovups(Vmm(0), ptr[aux_reg_bias + bias_off * jcp.typesize_bia]);
495
496                 for (int ow = 0; ow < ow_step; ow++) {
497                     Vmm vmm_acc = get_vmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ocb * ow_step + ow);
498
499                     uni_vaddps(vmm_acc, vmm_acc, Vmm(0));
500                 }
501             }
502         }
503     }
504
505     if (isa == avx512_common && oc_step != jcp.oc_block) {
506         int mask = (1 << oc_step) - 1;
507         mov(reg_tmp_32, mask);
508         kmovw(ktail_mask, reg_tmp_32);
509     }
510
511     for (int r = 0; r < repeats; r++) {
512         int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
513         bool is_scalar_store = isa == sse42 ? tail_size < jcp.oc_block / 2 : tail_size < jcp.oc_block;
514         if (is_scalar_store) {
515             for (int ow = 0; ow < ow_step; ow++) {
516                 Vmm vmm_dst = get_vmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ow);
517                 Xmm xmm_dst = get_xmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ow);
518
519                 if (isa == avx512_common) {
520                     size_t out_off = (size_t) ow * jcp.oc;
521
522                     uni_vmovups(ptr[aux_reg_output + out_off * jcp.typesize_out], vmm_dst | ktail_mask);
523                 } else {
524                     for (int oc = 0; oc < tail_size; oc++) {
525                         size_t out_off = (size_t) ow * jcp.oc + oc + r * (jcp.oc_block / 2);
526
527                         movq(reg_tmp_64, xmm_dst);
528                         mov(ptr[aux_reg_output + out_off * jcp.typesize_out], reg_tmp_32);
529
530                         if (isa == sse42) {
531                             psrldq(vmm_dst, jcp.typesize_out);
532                         } else {
533                             Ymm ymm_dst = get_ymm_acc(ow);
534                             Vmm vmm_tmp = Vmm(0);
535                             Ymm ymm_tmp = Ymm(0);
536
537                             vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
538                             vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
539                         }
540                     }
541                 }
542             }
543         } else {
544             for (int ocb = 0; ocb < oc_blocks_step; ocb++) {
545                 for (int ow = 0; ow < ow_step; ow++) {
546                     Vmm vmm_acc = get_vmm_acc(r * jcp.ur_w * jcp.nb_oc_blocking + ocb * ow_step + ow);
547                     size_t out_off = (size_t) ow * jcp.oc + ocb * jcp.oc_block + r * (jcp.oc_block / 2);
548
549                     uni_vmovups(ptr[aux_reg_output + out_off * jcp.typesize_out], vmm_acc);
550                 }
551             }
552         }
553     }
554 }
555
556 template <cpu_isa_t isa>
557 void jit_uni_def_conv_fwd_kernel_f32<isa>::oc_loop(int ow_step) {
558     Label oc_unrolled_loop;
559     Label oc_main_loop;
560     Label oc_tail;
561
562     mov(aux_reg_input_buffer, reg_input_buffer);
563
564     push(reg_output);
565     push(reg_bias);
566     push(reg_input);
567     push(reg_kernel);
568
569     interpolate_input(ow_step);
570
571     pop(reg_kernel);
572     pop(reg_input);
573     pop(reg_bias);
574     pop(reg_output);
575
576     push(reg_ow_pos);
577
578     mov(aux_reg_kernel, reg_kernel);
579     mov(aux_reg_output, reg_output);
580     mov(aux_reg_bias, reg_bias);
581
582     mov(reg_oc_work, jcp.oc);
583
584     L(oc_unrolled_loop); {
585         cmp(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
586         jl(oc_main_loop, T_NEAR);
587
588         ic_loop(ow_step, jcp.nb_oc_blocking, jcp.oc_block);
589         store_output(ow_step, jcp.nb_oc_blocking, jcp.oc_block);
590
591         add(aux_reg_kernel, jcp.nb_oc_blocking * jcp.nb_ic * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.typesize_in);
592         add(aux_reg_output, jcp.nb_oc_blocking * jcp.oc_block * jcp.typesize_out);
593         add(aux_reg_bias, jcp.nb_oc_blocking * jcp.oc_block * jcp.typesize_bia);
594         sub(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
595
596         jmp(oc_unrolled_loop, T_NEAR);
597     }
598
599     L(oc_main_loop); {
600         cmp(reg_oc_work, jcp.oc_block);
601         jl(oc_tail, T_NEAR);
602
603         ic_loop(ow_step, 1, jcp.oc_block);
604         store_output(ow_step, 1, jcp.oc_block);
605
606         add(aux_reg_kernel, jcp.nb_ic * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.typesize_in);
607         add(aux_reg_output, jcp.oc_block * jcp.typesize_out);
608         add(aux_reg_bias, jcp.oc_block * jcp.typesize_bia);
609         sub(reg_oc_work, jcp.oc_block);
610
611         jmp(oc_main_loop, T_NEAR);
612     }
613
614     L(oc_tail); {
615         if (jcp.oc % jcp.oc_block != 0) {
616             ic_loop(ow_step, 1, jcp.oc % jcp.oc_block);
617             store_output(ow_step, 1, jcp.oc % jcp.oc_block);
618         }
619     }
620
621     pop(reg_ow_pos);
622 }
623
624 template <cpu_isa_t isa>
625 void jit_uni_def_conv_fwd_kernel_f32<isa>::ow_loop() {
626     Label ow_loop_main;
627     Label ow_tail;
628
629     mov(reg_ow_pos, 0);
630
631     L(ow_loop_main); {
632         cmp(reg_ow_pos, jcp.ow - jcp.ur_w);
633         jg(ow_tail, T_NEAR);
634
635         oc_loop(jcp.ur_w);
636
637         add(reg_input, jcp.ur_w * jcp.stride_w * jcp.ic * jcp.typesize_in);
638         add(reg_def_off, jcp.ur_w * jcp.typesize_off);
639         add(reg_output, jcp.ur_w * jcp.oc * jcp.typesize_out);
640
641         add(reg_ow_pos, jcp.ur_w);
642         jmp(ow_loop_main, T_NEAR);
643     }
644
645     L(ow_tail); {
646         if (jcp.ow % jcp.ur_w != 0)
647             oc_loop(jcp.ow % jcp.ur_w);
648     }
649 }
650
651 template <cpu_isa_t isa>
652 void jit_uni_def_conv_fwd_kernel_f32<isa>::generate()
653 {
654     this->preamble();
655
656     mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
657     mov(reg_def_off, ptr[this->param1 + GET_OFF(off)]);
658     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
659     if (jcp.with_bias)
660         mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
661     mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
662     mov(reg_input_buffer, ptr[this->param1 + GET_OFF(buf)]);
663     mov(reg_oh_pos, ptr[param1 + GET_OFF(oh_pos)]);
664
665     ow_loop();
666
667     this->postamble();
668
669     prepare_table();
670 }
671
672 template <cpu_isa_t isa>
673 void jit_uni_def_conv_fwd_kernel_f32<isa>::prepare_table() {
674     align(64);
675     L(l_table);
676     for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
677         dd(0);
678     }
679
680     for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
681         dd(float2int((float)jcp.ih));
682     }
683
684     for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
685         dd(float2int((float)jcp.iw));
686     }
687
688     for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
689         dd(jcp.ih - 1);
690     }
691
692     for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
693         dd(jcp.iw - 1);
694     }
695
696     for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
697         dd(1);
698     }
699 }
700
701 template <cpu_isa_t isa>
702 bool jit_uni_def_conv_fwd_kernel_f32<isa>::post_ops_ok(jit_def_conv_conf_t &jcp, const primitive_attr_t &attr) {
703     const auto &p = attr.post_ops_;
704
705     return p.len_ == 0;
706 }
707
708 template <cpu_isa_t isa>
709 status_t jit_uni_def_conv_fwd_kernel_f32<isa>::init_conf(jit_def_conv_conf_t &jcp,
710         const deformable_convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
711         cpu_memory_t::pd_t &offsets_pd, cpu_memory_t::pd_t &weights_pd,
712         cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
713         const primitive_attr_t &attr)
714 {
715     if (!mayiuse(isa)) return status::unimplemented;
716
717     const memory_desc_wrapper src_d(&src_pd);
718     const memory_desc_wrapper offsets_d(&offsets_pd);
719     const memory_desc_wrapper weights_d(&weights_pd);
720     const memory_desc_wrapper dst_d(&dst_pd);
721     const memory_desc_wrapper bias_d(&bias_pd);
722
723     jcp.prop_kind = cd.prop_kind;
724
725     jcp.dg = cd.deformable_group;
726
727     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
728     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
729     jcp.mb = src_d.dims()[0];
730
731     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
732     jcp.ic = src_d.dims()[1] / jcp.ngroups;
733
734     jcp.ih = src_d.dims()[2];
735     jcp.iw = src_d.dims()[3];
736     jcp.oh = dst_d.dims()[2];
737     jcp.ow = dst_d.dims()[3];
738
739     jcp.kh = weights_d.dims()[with_groups + 2];
740     jcp.kw = weights_d.dims()[with_groups + 3];
741
742     jcp.t_pad = cd.padding[0][0];
743     jcp.l_pad = cd.padding[0][1];
744
745     jcp.stride_h = cd.strides[0];
746     jcp.stride_w = cd.strides[1];
747
748     jcp.dilate_h = cd.dilates[0];
749     jcp.dilate_w = cd.dilates[1];
750
751     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
752
753     const int simd_w = isa == avx512_common ? 16 : 8;
754     jcp.ic_block = simd_w;
755     jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
756
757     jcp.oc_block = simd_w;
758     jcp.oc_padded = rnd_up(jcp.oc, jcp.oc_block);
759     jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
760
761     if (jcp.ngroups != 1)
762         return status::unimplemented;
763
764     if (jcp.ic % jcp.dg != 0)
765         return status::unimplemented;
766
767     if (!post_ops_ok(jcp, attr))
768         return status::unimplemented;
769
770     auto desired_act_fmt = nhwc;
771     auto desired_off_fmt = nchw;
772     auto desired_wei_fmt = with_groups ? isa == avx512_common ? gOIhw16i16o : gOIhw8i8o
773                                        : isa == avx512_common ? OIhw16i16o : OIhw8i8o;
774
775     if (src_d.format() == any)
776         CHECK(src_pd.set_format(desired_act_fmt));
777     if (src_d.format() != desired_act_fmt)
778         return status::unimplemented;
779
780     if (offsets_d.format() == any)
781         CHECK(offsets_pd.set_format(desired_off_fmt));
782     if (offsets_d.format() != desired_off_fmt)
783         return status::unimplemented;
784
785     if (weights_d.format() == any)
786         CHECK(weights_pd.set_format(desired_wei_fmt));
787     if (weights_d.format() != desired_wei_fmt)
788         return status::unimplemented;
789
790     if (jcp.with_bias) {
791         if (bias_d.format() == any)
792             CHECK(bias_pd.set_format(x));
793         if (bias_d.format() != x)
794             return status::unimplemented;
795     }
796
797     if (dst_d.format() == any)
798         CHECK(dst_pd.set_format(desired_act_fmt));
799     if (dst_d.format() != desired_act_fmt)
800         return status::unimplemented;
801
802     jcp.src_dt = cd.src_descs[0].data_type;
803     jcp.off_dt = cd.src_descs[1].data_type;
804     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
805     jcp.dst_dt = cd.dst_desc.data_type;
806
807     jcp.typesize_in = (int)types::data_type_size(jcp.src_dt);
808     jcp.typesize_off = (int)types::data_type_size(jcp.off_dt);
809     jcp.typesize_out = (int)types::data_type_size(jcp.dst_dt);
810     jcp.typesize_bia = (int)(jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0);
811
812     jcp.ur_w = isa == avx512_common ? 6 : 3;
813     jcp.nb_oc_blocking = isa == sse42 ? 2 : 4;
814
815     jcp.nthr = mkldnn_get_max_threads();
816
817     return status::success;
818 }
819
820 template <cpu_isa_t isa>
821 void jit_uni_def_conv_fwd_kernel_f32<isa>::init_scratchpad(
822         memory_tracking::registrar_t &scratchpad, const jit_def_conv_conf_t &jcp, const primitive_attr_t &attr) {
823
824     scratchpad.book(key_def_conv_buffer, (size_t)jcp.nthr * jcp.ur_w * jcp.kh * jcp.kw * jcp.ic * jcp.typesize_in);
825     if (jcp.oc != jcp.oc_padded) {
826         scratchpad.book(key_conv_padded_bias, (size_t)jcp.typesize_bia * jcp.oc_padded);
827     }
828 }
829
// Explicit instantiations of the kernel for the avx512_common, avx2 and sse42 ISAs.
830 template struct jit_uni_def_conv_fwd_kernel_f32<avx512_common>;
831 template struct jit_uni_def_conv_fwd_kernel_f32<avx2>;
832 template struct jit_uni_def_conv_fwd_kernel_f32<sse42>;
833
834 }
835 }
836 }