// inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_depthwise.cpp
/*******************************************************************************
* Copyright 2018-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "mkldnn_types.h"
#include "mkldnn_thread.hpp"
#include "nstl.hpp"
#include "utils.hpp"
#include "jit_generator.hpp"

#include "jit_uni_depthwise.hpp"

#define GET_OFF(field) offsetof(jit_args, field)

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace Xbyak;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;

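/* Argument structure passed to every generated depthwise kernel:
 * 'from'/'to' point at the source and destination rows, 'weights' and
 * 'bias' at the per-channel scale/shift (or prelu slope) data, and
 * 'work_amount' is the number of elements the kernel must process. */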
struct jit_args {
    const float *from;
    const float *to;
    const float *weights;
    const float *bias;
    size_t work_amount;
};

struct jit_uni_depthwise_kernel_f32 : public c_compatible {
    const depthwise_desc_t &desc_;
    void (*ker_)(const jit_args *);
    bool with_bias_;

    void operator()(const jit_args *args) { assert(ker_); ker_(args); }

    jit_uni_depthwise_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
        : desc_(desc), ker_(nullptr), with_bias_(with_bias) {}
    virtual ~jit_uni_depthwise_kernel_f32() {}
};

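/* Number of auxiliary vector registers the injector needs: scale_shift
 * needs a single scratch register on SSE4.2 only (no three-operand
 * forms), while prelu always needs two (a mask and a temporary). */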
template <cpu_isa_t isa>
int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg) {
    switch (depthwise_alg) {
        case alg_kind::depthwise_scale_shift: return isa == sse42 ? 1 : 0;
        case alg_kind::depthwise_prelu: return 2;
        default: assert(!"unsupported depthwise algorithm");
    }

    return 0;
}

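/* Reserves the auxiliary registers, preferring indices outside the
 * [start_idx, end_idx) range being processed; registers that have to be
 * borrowed from the head of that range are spilled to the stack, and the
 * borrowed prefix is remembered in start_idx_tail so that
 * compute_vector_range() can handle it in a second pass. */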
template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::injector_preamble(size_t start_idx, size_t end_idx) {
    preserved_vecs_count = 0;
    vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(depthwise_alg);

    for (size_t i = 0; i < vecs_count; i++) {
        if (preserved_vecs_count >= vecs_to_preserve)
            break;

        if (i < start_idx || i >= end_idx) {
            preserved_vec_idxs[preserved_vecs_count] = i;
            preserved_vecs_count++;
        }
    }

    start_idx_tail = start_idx;
    size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
    for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
        preserved_vec_idxs[preserved_vecs_count] = start_idx + i;
        preserved_vecs_count++;
        start_idx_tail = start_idx + i + 1;
    }

    h->sub(h->rsp, preserved_vecs_count * vlen);
    for (size_t i = 0; i < preserved_vecs_count; ++i)
        h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i]));

    assign_regs();
}

template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx, size_t end_idx) {
    size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
    int idx_off = (vecs_to_preserve - tail_vecs_to_preserve);

    if (tail_vecs_to_preserve > 0) {
        h->add(h->rsp, idx_off * vlen);
        for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
            h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), h->ptr[h->rsp + i * vlen]);

        for (size_t i = 0; i < tail_vecs_to_preserve; ++i) {
            preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
        }

        for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
            h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[idx_off + i]));
        h->sub(h->rsp, idx_off * vlen);

        assign_regs();
    }
}

template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::injector_postamble() {
    for (size_t i = 0; i < preserved_vecs_count; ++i)
        h->uni_vmovups(Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]);
    h->add(h->rsp, preserved_vecs_count * vlen);
}

template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::assign_regs() {
    vmm_mask = Vmm(preserved_vec_idxs[0]);
    vmm_aux0 = Vmm(preserved_vec_idxs[1]);
}

template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::scale_shift_compute_vector(const Vmm &vmm_src,
        const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
    if (isa == sse42) {
        h->movups(vmm_mask, h->ptr[p_weights]);
        h->mulps(vmm_src, vmm_mask);
        h->movups(vmm_mask, h->ptr[p_bias]);
        h->addps(vmm_src, vmm_mask);
    } else {
        h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights]);
        h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias]);
    }
}

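/* PReLU: out = x > 0 ? x : w * x. Each ISA builds a sign mask from the
 * source and blends the scaled values into the negative lanes (AVX512
 * does the multiply directly under an opmask instead of blending). */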
template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::prelu_compute_vector(const Vmm &vmm_src,
        const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
    const unsigned char _cmp_gt_os = 6;
    const unsigned char _cmp_lt_os = 1;

    if (isa == sse42) {
        h->pxor(vmm_mask, vmm_mask);
        h->cmpps(vmm_mask, vmm_src, _cmp_gt_os);
        h->movups(vmm_aux0, h->ptr[p_weights]);
        h->mulps(vmm_aux0, vmm_src);
        h->blendvps(vmm_src, vmm_aux0);
    } else if (isa == avx2) {
        h->vxorps(vmm_mask, vmm_mask, vmm_mask);
        h->vcmpgtps(vmm_mask, vmm_src, vmm_mask);
        h->vmulps(vmm_aux0, vmm_src, h->ptr[p_weights]);
        h->vblendvps(vmm_src, vmm_aux0, vmm_src, vmm_mask);
    } else if (isa == avx512_common) {
        h->vxorpd(vmm_mask, vmm_mask, vmm_mask);
        h->vmovups(vmm_aux0, vmm_src);
        h->vcmpps(k_mask, vmm_src, vmm_mask, _cmp_lt_os);
        h->vmulps(vmm_src | k_mask, vmm_aux0, h->ptr[p_weights]);
    }
}

template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::compute_body(size_t start_idx, size_t end_idx,
        const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
    for (size_t idx = start_idx; idx < end_idx; idx++) {
        switch (depthwise_alg) {
            case alg_kind::depthwise_scale_shift:
                scale_shift_compute_vector(Vmm(idx), p_weights, p_bias); break;
            case alg_kind::depthwise_prelu:
                prelu_compute_vector(Vmm(idx), p_weights, p_bias); break;
            default: assert(!"unsupported depthwise algorithm");
        }
    }
}

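/* Public entry point: applies the depthwise op to the Vmm registers in
 * [start_idx, end_idx). The range is processed in two passes so that any
 * auxiliary registers borrowed from its head can be relocated between
 * the passes before their contents are finally restored. */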
template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::compute_vector_range(int start_idx, int end_idx,
        const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
    injector_preamble(start_idx, end_idx);
    compute_body(start_idx_tail, end_idx, p_weights, p_bias);
    injector_preamble_tail(start_idx, end_idx);
    compute_body(start_idx, start_idx_tail, p_weights, p_bias);
    injector_postamble();
}

template struct jit_uni_depthwise_injector_f32<avx512_common>;
template struct jit_uni_depthwise_injector_f32<avx2>;
template struct jit_uni_depthwise_injector_f32<sse42>;

/* jit kernels */
namespace {

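/* Generates dst[i] = src[i] * scale[c] + shift[c] over one row. Blocked
 * layouts (nChw8c/nChw16c) load a full vector of per-channel scales and
 * shifts, while planar (nchw/ncdhw) layouts broadcast a single scalar;
 * SSE4.2 emulates the wider registers with two 4-float iterations. */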
template <cpu_isa_t isa>
struct jit_uni_scale_shift_kernel_f32 : public jit_uni_depthwise_kernel_f32,
    public jit_generator
{
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_scale_shift_kernel_f32)
    jit_uni_scale_shift_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
        : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
        assert(desc.alg_kind == alg_kind::depthwise_scale_shift);
        assert(isa == sse42 || isa == avx2 || isa == avx512_common);

        bool isFlat = (desc.src_desc.format == nchw && desc.dst_desc.format == nchw) ||
                      (desc.src_desc.format == ncdhw && desc.dst_desc.format == ncdhw);

        Reg64 param = abi_param1;

        const int block_size = isa == avx512_common ? 16 : 8;
        const int main_loop_step = (isFlat || desc.src_desc.format == nc) ? block_size : 1;

        this->preamble();

        mov(reg_from, ptr[param + GET_OFF(from)]);
        mov(reg_to, ptr[param + GET_OFF(to)]);
        mov(reg_scale, ptr[param + GET_OFF(weights)]);
        mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
        if (with_bias_)
            mov(reg_shift, ptr[param + GET_OFF(bias)]);

        Label main_loop_label;
        Label tail_loop_label;
        Label tail_loop_flat_label;
        Label exit_label;

        int repeats = isa == sse42 ? 2 : 1;
        for (int i = 0; i < repeats; i++) {
            if (isFlat) {
                uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
                if (with_bias_)
                    uni_vbroadcastss(get_shift_reg(i), ptr[reg_shift]);
                else
                    uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
            } else {
                uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
                if (with_bias_)
                    uni_vmovups(get_shift_reg(i), ptr[reg_shift + i*4*sizeof(float)]);
                else
                    uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
            }
        }

        if (isFlat) {
            uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
            if (with_bias_)
                uni_vbroadcastss(xmm_shift, ptr[reg_shift]);
            else
                uni_vpxor(xmm_shift, xmm_shift, xmm_shift);
        }

        L(main_loop_label); {
            cmp(reg_work_amount, main_loop_step-1);
            jle(isFlat ? tail_loop_flat_label : tail_loop_label, T_NEAR);

            int repeats = isa == sse42 ? 2 : 1;
            for (int i = 0; i < repeats; i++) {
                uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);
                uni_vmovups(vmm_dst, get_shift_reg(i));
                uni_vfmadd231ps(vmm_dst, vmm_src, get_scale_reg(i));
                uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
            }

            add(reg_from, block_size*sizeof(float));
            add(reg_to, block_size*sizeof(float));
            sub(reg_work_amount, main_loop_step);

            jmp(main_loop_label, T_NEAR);
        }

        L(tail_loop_label); {
            cmp(reg_work_amount, 0);
            jle(exit_label, T_NEAR);

            movss(xmm_src, ptr[reg_from]);
            movss(xmm_shift, ptr[reg_shift]);
            movss(xmm_scale, ptr[reg_scale]);
            uni_vmovups(xmm_dst, xmm_shift);
            uni_vfmadd231ps(xmm_dst, xmm_src, xmm_scale);
            movss(ptr[reg_to], xmm_dst);

            add(reg_from, 1*sizeof(float));
            add(reg_to, 1*sizeof(float));
            add(reg_shift, 1*sizeof(float));
            add(reg_scale, 1*sizeof(float));
            dec(reg_work_amount);

            jmp(tail_loop_label, T_NEAR);
        }

        L(tail_loop_flat_label); {
            cmp(reg_work_amount, 0);
            jle(exit_label, T_NEAR);

            movss(xmm_src, ptr[reg_from]);
            uni_vmovups(xmm_dst, xmm_shift);
            uni_vfmadd231ps(xmm_dst, xmm_src, xmm_scale);
            movss(ptr[reg_to], xmm_dst);

            add(reg_from, 1*sizeof(float));
            add(reg_to, 1*sizeof(float));
            dec(reg_work_amount);

            jmp(tail_loop_flat_label, T_NEAR);
        }

        L(exit_label);

        this->postamble();

        ker_ = (decltype(ker_))this->getCode();
    }

private:
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
                                             isa == avx2, Ymm, Zmm>::type;

    inline Vmm get_scale_reg(int idx) { return Vmm(idx + 2); }
    inline Vmm get_shift_reg(int idx) { return Vmm(idx + 4); }

    Reg64 reg_from = r8;
    Reg64 reg_to = r9;
    Reg64 reg_work_amount = r10;
    Reg64 reg_scale = r11;
    Reg64 reg_shift = r12;

    Vmm vmm_src = Vmm(0);
    Vmm vmm_dst = Vmm(1);

    Xmm xmm_src = Xmm(0);
    Xmm xmm_dst = Xmm(1);
    Xmm xmm_scale = Xmm(6);
    Xmm xmm_shift = Xmm(7);
};

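/* Same looping structure as the scale/shift kernel above, but computes
 * dst[i] = src[i] > 0 ? src[i] : src[i] * scale[c]. */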
template <cpu_isa_t isa>
struct jit_uni_prelu_kernel_f32 : public jit_uni_depthwise_kernel_f32,
    public jit_generator
{
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_prelu_kernel_f32)
    jit_uni_prelu_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
        : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
        assert(desc.alg_kind == alg_kind::depthwise_prelu);
        assert(isa == sse42 || isa == avx2 || isa == avx512_common);

        bool isFlat = (desc.src_desc.format == nchw && desc.dst_desc.format == nchw) ||
                      (desc.src_desc.format == ncdhw && desc.dst_desc.format == ncdhw);

        Reg64 param = abi_param1;

        const int block_size = isa == avx512_common ? 16 : 8;
        const int main_loop_step = (isFlat || desc.src_desc.format == nc) ? block_size : 1;

        this->preamble();

        mov(reg_from, ptr[param + GET_OFF(from)]);
        mov(reg_to, ptr[param + GET_OFF(to)]);
        mov(reg_scale, ptr[param + GET_OFF(weights)]);
        mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);

        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);

        int repeats = isa == sse42 ? 2 : 1;
        for (int i = 0; i < repeats; i++) {
            if (isFlat) {
                uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
            } else {
                uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
            }
        }

        if (isFlat) {
            uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
        }

        Label main_loop_label;
        Label tail_loop_label;
        Label tail_loop_flat_label;
        Label exit_label;

        L(main_loop_label); {
            cmp(reg_work_amount, main_loop_step-1);
            jle(isFlat ? tail_loop_flat_label : tail_loop_label, T_NEAR);

            for (int i = 0; i < repeats; i++) {
                uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);

                if (isa == sse42) {
                    pxor(vmm_mask, vmm_mask);
                    cmpps(vmm_mask, vmm_src, _cmp_gt_os);
                    movups(vmm_dst, vmm_src);
                    mulps(vmm_src, get_scale_reg(i));
                    blendvps(vmm_dst, vmm_src);
                } else if (isa == avx2) {
                    vcmpgtps(vmm_mask, vmm_src, vmm_zero);
                    vmulps(vmm_dst, vmm_src, get_scale_reg(i));
                    vblendvps(vmm_dst, vmm_dst, vmm_src, vmm_mask);
                } else if (isa == avx512_common) {
                    Opmask kmask = Opmask(7);
                    vmovups(vmm_dst, vmm_src);
                    vcmpps(kmask, vmm_src, vmm_zero, _cmp_lt_os);
                    vmulps(vmm_dst | kmask, vmm_src, get_scale_reg(i));
                }

                uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
            }

            add(reg_from, block_size*sizeof(float));
            add(reg_to, block_size*sizeof(float));
            sub(reg_work_amount, main_loop_step);

            jmp(main_loop_label, T_NEAR);
        }

        L(tail_loop_label); {
            cmp(reg_work_amount, 0);
            jle(exit_label, T_NEAR);

            movss(xmm_src, ptr[reg_from]);
            movss(xmm_scale, ptr[reg_scale]);

            pxor(xmm_mask, xmm_mask);
            cmpps(xmm_mask, xmm_src, _cmp_gt_os);
            movups(xmm_dst, xmm_src);
            mulps(xmm_src, xmm_scale);
            blendvps(xmm_dst, xmm_src);

            movss(ptr[reg_to], xmm_dst);

            add(reg_from, 1*sizeof(float));
            add(reg_to, 1*sizeof(float));
            add(reg_scale, 1*sizeof(float));
            dec(reg_work_amount);

            jmp(tail_loop_label, T_NEAR);
        }

        L(tail_loop_flat_label); {
            cmp(reg_work_amount, 0);
            jle(exit_label, T_NEAR);

            movss(xmm_src, ptr[reg_from]);

            pxor(xmm_mask, xmm_mask);
            cmpps(xmm_mask, xmm_src, _cmp_gt_os);
            movups(xmm_dst, xmm_src);
            mulps(xmm_src, xmm_scale);
            blendvps(xmm_dst, xmm_src);

            movss(ptr[reg_to], xmm_dst);

            add(reg_from, 1*sizeof(float));
            add(reg_to, 1*sizeof(float));
            dec(reg_work_amount);

            jmp(tail_loop_flat_label, T_NEAR);
        }

        L(exit_label);

        this->postamble();

        ker_ = (decltype(ker_))this->getCode();
    }

private:
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
                                             isa == avx2, Ymm, Zmm>::type;

    inline Vmm get_scale_reg(int idx) { return Vmm(idx + 4); }

    Reg64 reg_from = r8;
    Reg64 reg_to = r9;
    Reg64 reg_work_amount = r10;
    Reg64 reg_scale = r11;

    Vmm vmm_mask = Vmm(0);
    Vmm vmm_src = Vmm(1);
    Vmm vmm_zero = Vmm(2);
    Vmm vmm_dst = Vmm(3);

    Xmm xmm_mask = Xmm(0);
    Xmm xmm_src = Xmm(1);
    Xmm xmm_dst = Xmm(3);
    Xmm xmm_scale = Xmm(4);

    const unsigned char _cmp_gt_os = 6;
    const unsigned char _cmp_lt_os = 1;
};

} /* namespace */

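/* The descriptor is accepted only for f32 data in a layout that matches
 * the ISA: the blocked format with the native SIMD width (8c or 16c) or
 * the corresponding planar/2D format, identical for src and dst. */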
template <cpu_isa_t isa>
status_t jit_uni_depthwise_fwd_t<isa>::pd_t::init() {
    using namespace alg_kind;

    memory_format_t desired_blk_fmt, desired_pln_fmt;
    if (desc()->src_desc.ndims == 5) {
        desired_blk_fmt = isa == avx512_common ? nCdhw16c : nCdhw8c;
        desired_pln_fmt = ncdhw;
    } else if (desc()->src_desc.ndims == 4) {
        desired_blk_fmt = isa == avx512_common ? nChw16c : nChw8c;
        desired_pln_fmt = nchw;
    } else {
        desired_blk_fmt = nc;
        desired_pln_fmt = nc;
    }

    assert(engine()->kind() == engine_kind::cpu);
    bool ok = true && mayiuse(isa)
        && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
                prop_kind::forward_inference)
        && utils::everyone_is(data_type::f32, desc()->src_desc.data_type, desc()->dst_desc.data_type)
        && desc()->src_desc.format == desc()->dst_desc.format
        && utils::one_of(desc()->src_desc.format, desired_blk_fmt, desired_pln_fmt)
        && utils::one_of(desc()->dst_desc.format, desired_blk_fmt, desired_pln_fmt)
        && utils::one_of(desc()->weights_desc.format, x)
        && IMPLICATION(this->with_bias(), x == desc()->bias_desc.format)
        && attr()->has_default_values();

    return ok ? status::success : status::unimplemented;
}

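/* The kernel always reads full SIMD vectors of weights/bias, so when the
 * channel count is not a multiple of the SIMD width the constructor
 * allocates zero-padded copies that execute_forward() fills on each run. */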
template <cpu_isa_t isa>
jit_uni_depthwise_fwd_t<isa>::jit_uni_depthwise_fwd_t(const pd_t *apd,
        const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr),
      padded_weights_(nullptr), padded_bias_(nullptr) {
    const auto &desc = *pd()->desc();
    switch (desc.alg_kind) {
        case alg_kind::depthwise_scale_shift:
            kernel_ = new jit_uni_scale_shift_kernel_f32<isa>(desc, pd()->with_bias()); break;
        case alg_kind::depthwise_prelu:
            kernel_ = new jit_uni_prelu_kernel_f32<isa>(desc, pd()->with_bias()); break;
        default: assert(!"unknown depthwise alg_kind");
    }

    const int simd_w = isa == avx512_common ? 16 : 8;
    const memory_desc_wrapper data_d(pd()->src_pd());
    const int c_without_padding = data_d.dims()[1];
    const int c_padded = rnd_up(c_without_padding, simd_w);

    if (pd()->want_padded_weights()) {
        padded_weights_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
        for (int oc = c_without_padding; oc < c_padded; ++oc)
            padded_weights_[oc] = 0;

        if (pd()->with_bias()) {
            padded_bias_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
            for (int oc = c_without_padding; oc < c_padded; ++oc)
                padded_bias_[oc] = 0;
        }
    }
}

template <cpu_isa_t isa>
jit_uni_depthwise_fwd_t<isa>::~jit_uni_depthwise_fwd_t() {
    delete kernel_;
    free(padded_weights_);
    free(padded_bias_);
}

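/* Dispatch: one kernel invocation per (batch, channel block, d, h) row.
 * For blocked and planar layouts the work_amount is the row width W; for
 * the nc layout it is the (possibly partial) channel block itself. */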
template <cpu_isa_t isa>
void jit_uni_depthwise_fwd_t<isa>::execute_forward() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<data_t *>(this->memory());

    const memory_desc_wrapper data_d(pd()->src_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
    const memory_desc_wrapper bias_d(pd()->weights_pd(1));

    const int MB = pd()->MB();
    const int C = pd()->C();
    const int D = pd()->D();
    const int H = pd()->H();
    const int W = pd()->W();

    const int simd_w = isa == avx512_common ? 16 : 8;
    const int ch_block_size = (data_d.format() == nchw) || (data_d.format() == ncdhw) ? 1 : simd_w;
    const int CB = div_up(C, ch_block_size);

    if (pd()->want_padded_weights()) {
        for (int oc = 0; oc < C; ++oc)
            padded_weights_[oc] = weights[oc];
        weights = padded_weights_;

        if (pd()->with_bias()) {
            for (int oc = 0; oc < C; ++oc)
                padded_bias_[oc] = bias[oc];
            bias = padded_bias_;
        }
    }

    parallel_nd(MB, CB, D, H,
        [&](int mb, int cb, int d, int h) {
        auto arg = jit_args();

        size_t data_off = data_d.ndims() == 4
                          ? data_d.blk_off(mb, cb, h)
                          : data_d.ndims() == 5
                            ? data_d.blk_off(mb, cb, d, h)
                            : data_d.blk_off(mb, cb * ch_block_size);

        arg.from    = &src[data_off];
        arg.to      = &dst[data_off];
        arg.weights = &weights[weights_d.blk_off(cb * ch_block_size)];
        if (bias)
            arg.bias = &bias[bias_d.blk_off(cb * ch_block_size)];
        arg.work_amount = data_d.format() == nc ? nstl::min(ch_block_size, C - cb * ch_block_size) : (size_t)W;

        (*kernel_)(&arg);
    });
}

template struct jit_uni_depthwise_fwd_t<sse42>;
template struct jit_uni_depthwise_fwd_t<avx2>;
template struct jit_uni_depthwise_fwd_t<avx512_common>;

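/* jit_uni_dw_conv_row_f32: row-oriented 3x3 depthwise convolution kernel
 * used as a fused post-operation of a preceding convolution. It consumes
 * up to three input rows (top/middle/bottom) and produces one output row. */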
#define GET_OFF_DW(field) offsetof(jit_conv_call_s, field)

template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::load_src(int ur_w) {
    int repeats = isa == sse42 ? 2 : 1;
    for (int i = 0; i < repeats; i++) {
        for (int ow = 0; ow < ur_w; ow++) {
            Vmm vmm_acc = get_acc_reg(i*ur_w + ow);

            uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
        }
    }
}

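/* Accumulates up to three kernel rows, guarded at runtime by reg_kh so
 * that vertically-padded border rows are skipped. The u8 path multiplies
 * and adds in 32-bit integers; the f32 path uses FMA. */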
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::apply_filter(int ur_w, int kw_size) {
    auto load_src = [=](Vmm vmm_src, const Xbyak::Address &op) {
        if (jcp.src_dt == data_type::u8) {
            uni_vpmovzxbd(vmm_src, op);
        } else {
            uni_vmovups(vmm_src, op);
        }
    };

    auto load_ker = [=](Vmm vmm_ker, const Xbyak::Address &op) {
        if (jcp.src_dt == data_type::u8) {
            uni_vpmovsxbd(vmm_ker, op);
        } else {
            uni_vmovups(vmm_ker, op);
        }
    };

    auto compute = [=](Vmm vmm_acc, Vmm vmm_src, Vmm vmm_ker) {
        if (jcp.src_dt == data_type::u8) {
            uni_vpmulld(vmm_src, vmm_src, vmm_ker);
            uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
        } else {
            uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
        }
    };

    int ch_blk = jcp.ch_block;
    int stride_w = jcp.stride_w;

    Label exit_label;

    int repeats = isa == sse42 ? 2 : 1;

    cmp(reg_kh, 1);
    jl(exit_label, T_NEAR);
    for (int i = 0; i < repeats; i++) {
        for (int kw = 0; kw < kw_size; kw++) {
            int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);

            Vmm vmm_ker = get_ker_reg(0);
            load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);

            for (int ow = 0; ow < ur_w; ow++) {
                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);

                Vmm vmm_src = get_src_reg(0);
                load_src(vmm_src, ptr[aux_reg_input0 + inp_off * jcp.typesize_in]);

                Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
                compute(vmm_acc, vmm_src, vmm_ker);
            }
        }
    }
    add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);

    cmp(reg_kh, 2);
    jl(exit_label, T_NEAR);
    for (int i = 0; i < repeats; i++) {
        for (int kw = 0; kw < kw_size; kw++) {
            int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);

            Vmm vmm_ker = get_ker_reg(0);
            load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);

            for (int ow = 0; ow < ur_w; ow++) {
                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);

                Vmm vmm_src = get_src_reg(0);
                load_src(vmm_src, ptr[aux_reg_input1 + inp_off * jcp.typesize_in]);

                Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
                compute(vmm_acc, vmm_src, vmm_ker);
            }
        }
    }
    add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);

    cmp(reg_kh, 3);
    jl(exit_label, T_NEAR);
    for (int i = 0; i < repeats; i++) {
        for (int kw = 0; kw < kw_size; kw++) {
            int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);

            Vmm vmm_ker = get_ker_reg(0);
            load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);

            for (int ow = 0; ow < ur_w; ow++) {
                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);

                Vmm vmm_src = get_src_reg(0);
                load_src(vmm_src, ptr[aux_reg_input2 + inp_off * jcp.typesize_in]);

                Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
                compute(vmm_acc, vmm_src, vmm_ker);
            }
        }
    }

    L(exit_label);
}

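/* Loads a value of the given type and widens it to f32 lanes; the scalar
 * path goes through a GPR so a single element can be placed in lane 0. */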
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) {
    Xmm xmm_in = Xmm(vmm_in.getIdx());

    switch (type_in) {
        case data_type::f32:
        case data_type::s32:
            if (scalar_load) {
                mov(reg_tmp_32, op);
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vmovups(vmm_in, op);
            }
            break;
        case data_type::s8:
            if (scalar_load) {
                movsx(reg_tmp_32, op);
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vpmovsxbd(vmm_in, op);
            }
            break;
        case data_type::u8:
            if (scalar_load) {
                movzx(reg_tmp_32, op);
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vpmovzxbd(vmm_in, op);
            }
            break;
        default: assert(!"unsupported data type");
    }

    if (type_in != data_type::f32)
        uni_vcvtdq2ps(vmm_in, vmm_in);
}

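/* Converts integer accumulators to f32 if needed, then applies bias, the
 * optional sum post-op (accumulating the existing output values), and any
 * eltwise/depthwise injectors that follow the fused convolution entry. */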
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::apply_postprocessing(int ur_w, int oc_step) {
    int repeats = isa == sse42 ? 2 : 1;

    for (int r = 0; r < repeats; r++) {
        for (int ow = 0; ow < ur_w; ow++) {
            if (jcp.src_dt == data_type::u8) {
                uni_vcvtdq2ps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow));
            }

            if (jcp.with_bias) {
                int b_off = r * (jcp.ch_block / 2);
                cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias + b_off * jcp.typesize_bia], false);
                uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_bias);
            }
        }
    }

    if (jcp.with_sum) {
        for (int r = 0; r < repeats; r++) {
            int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - r * jcp.ch_block / 2) : oc_step;
            bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;

            for (int ow = 0; ow < ur_w; ow++) {
                if (is_scalar_store) {
                    if (isa == avx512_common) {
                        int o_off = ow * ow_stride_;

                        Vmm vmm_in = vmm_sum | ktail_mask | T_z;

                        cvt2ps(jcp.dst_dt, vmm_in, ptr[reg_output + o_off * jcp.typesize_out], false);
                        uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
                    } else {
                        for (int oc = 0; oc < tail_size; oc++) {
                            int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc;

                            uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
                            cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);

                            if (oc >= jcp.ch_block / 2) {
                                vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01);
                            }
                            uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2)));

                            uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
                        }
                    }
                } else {
                    int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2);

                    uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
                    cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false);

                    uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
                }
            }
        }
    }

    const auto &p = attr_.post_ops_;
    int eltwise_inj_idx = 0;
    int depthwise_inj_idx = 0;
    int start_idx = p.find(primitive_kind::convolution) + 1;
    for (int i = start_idx; i < p.len_; i++) {
        auto& post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors[eltwise_inj_idx]->compute_vector_range(4, 4 + repeats * ur_w);
            eltwise_inj_idx++;
        } else if (post_op.is_depthwise()) {
            mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
            mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));

            add(reg_d_weights, reg_oc_off);
            add(reg_d_bias, reg_oc_off);

            depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4, 4 + ur_w, reg_d_weights, reg_d_bias);

            if (repeats == 2) {
                add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
                add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));

                depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4 + ur_w, 4 + 2 * ur_w, reg_d_weights, reg_d_bias);
            }

            depthwise_inj_idx++;
        }
    }
}

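/* Stores one vector (or one scalar lane) to memory, narrowing f32/s32
 * accumulators to s8/u8 with saturating packs where required. */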
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::store_dst_typed(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
    Ymm ymm_dst = Ymm(vmm_dst.getIdx());
    Xmm xmm_dst = Xmm(vmm_dst.getIdx());

    switch (jcp.dst_dt) {
        case data_type::f32:
        case data_type::s32:
            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_32);
            } else {
                uni_vmovups(op, vmm_dst);
            }
            break;
        case data_type::s8:
            uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);

            if (isa != sse42 && !scalar_store)
                vpermq(ymm_dst, ymm_dst, 0x08);

            uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);

            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_8);
            } else {
                if (isa != sse42)
                    vmovq(op, xmm_dst);
                else
                    movd(op, xmm_dst);
            }
            break;
        case data_type::u8:
        case data_type::bin:
            uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);

            if (isa != sse42 && !scalar_store)
                vpermq(ymm_dst, ymm_dst, 0x08);

            uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);

            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_8);
            } else {
                if (isa != sse42)
                    vmovq(op, xmm_dst);
                else
                    movd(op, xmm_dst);
            }
            break;
        default:
            assert(!"unknown dst_dt");
    }
}

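/* Writes the accumulators back: rounds and converts to the destination
 * type, and for the binarization post-op packs per-channel threshold
 * comparison results into output bit masks instead. */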
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
    int nbits = 8;
    int repeats = isa == sse42 && oc_step > (jcp.ch_block / 2) ? 2 : 1;

    if (isa == avx512_common && oc_step != jcp.ch_block) {
        int mask = (1 << oc_step) - 1;
        mov(reg_tmp_32, mask);
        kmovw(ktail_mask, reg_tmp_32);
    }

    for (int i = 0; i < repeats; i++) {
        for (int ow = 0; ow < ur_w; ow++) {
            Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
            if (jcp.dst_dt != data_type::f32 && jcp.dst_dt != data_type::bin) {
                if (attr_.round_mode_ == round_mode::nearest)
                    uni_vcvtps2dq(vmm_dst, vmm_dst);
                else if (attr_.round_mode_ == round_mode::down) {
                    uni_vroundps(vmm_dst, vmm_dst, 1);
                    uni_vcvtps2dq(vmm_dst, vmm_dst);
                } else
                    assert(!"unimplemented");
            }
        }
    }

    if (jcp.with_binarization) {
        int output_step = div_up(ow_stride_, nbits);

        const auto &p = attr_.post_ops_;
        int binarization_idx = p.find(primitive_kind::binarization);

        push(reg_bias);

        mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
        mov(reg_b_out_mask, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.output_mask_data));
        add(reg_b_weights, reg_oc_off);
        add(reg_b_out_mask, reg_oc_off);

        for (int ow = 0; ow < ur_w; ow++) {
            for (int i = 0; i < repeats; i++) {
                int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
                mov(reg_b_mask, (1 << tail_size) - 1);
                uni_vmovups(vmm_thr, ptr[reg_b_weights + i * (jcp.ch_block / 2) * sizeof(float)]);
                uni_vmovups(vmm_out_mask, ptr[reg_b_out_mask + i * (jcp.ch_block / 2) * sizeof(float)]);

                Vmm vmm_dst = get_acc_reg(i * ur_w + ow);

                if (isa == avx512_common) {
                    vcmpps(bin_mask0, vmm_dst, vmm_thr, _cmp_gt_os);
                    vptestmd(bin_mask1, vmm_out_mask, vmm_out_mask);
                    kxnorw(bin_mask0, bin_mask0, bin_mask1);
                } else {
                    uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
                    uni_vpcmpeqd(vmm_dst, vmm_dst, vmm_out_mask);
                }

                if (i == 0) {
                    if (isa == avx512_common) {
                        kmovw(reg_tmp_32, bin_mask0);
                    } else {
                        uni_vmovmskps(reg_tmp_32, vmm_dst);
                    }
                    and_(reg_tmp_64, reg_b_mask);
                } else {
                    uni_vmovmskps(reg_tmp2_32, vmm_dst);
                    and_(reg_tmp2_64, reg_b_mask);
                    shl(reg_tmp2_32, 4);
                    or_(reg_tmp_32, reg_tmp2_32);
                }

                if (i == repeats - 1) {
                    const size_t o_off = ow * output_step;
                    if (isa == avx512_common && oc_step > nbits) {
                        mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_16);
                    } else {
                        mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
                    }
                }
            }
        }

        pop(reg_bias);
    } else {
        for (int i = 0; i < repeats; i++) {
            int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
            bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;
            if (is_scalar_store) {
                for (int ow = 0; ow < ur_w; ow++) {
                    Vmm vmm_dst = get_acc_reg(i * ur_w + ow);

                    if (isa == avx512_common) {
                        int o_off = ow * ow_stride_;

                        store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst | ktail_mask, false);
                    } else {
                        for (int oc = 0; oc < tail_size; oc++) {
                            int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc;
                            store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);

                            if (isa == sse42) {
                                psrldq(vmm_dst, jcp.typesize_out);
                            } else {
                                Ymm ymm_dst = Ymm(vmm_dst.getIdx());

                                vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
                                vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
                            }
                        }
                    }
                }
            } else {
                for (int ow = 0; ow < ur_w; ow++) {
                    int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2);
                    Vmm vmm_dst = get_acc_reg(i * ur_w + ow);

                    store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
                }
            }
        }
    }
}

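/* Walks one output row: a single left-padded column first, then jcp.ur_w
 * unrolled columns at a time, then single-column tail and, when ow > 1, a
 * final right-padded column with a narrower effective kernel width. */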
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::loop_body(int oc_step) {
    Label left_pad_label;
    Label right_pad_label;
    Label unrolled_w_label;
    Label tail_w_label;
    Label exit_label;

    int output_step = jcp.with_binarization ? div_up(ow_stride_, 8) : ow_stride_;

    L(left_pad_label); {
        int ur_w = 1;
        int kw = jcp.iw == 1 ? jcp.kw - 2 : jcp.kw - 1;

        mov(aux_reg_input0, reg_input0);
        mov(aux_reg_input1, reg_input1);
        mov(aux_reg_input2, reg_input2);
        mov(aux_reg_kernel, reg_kernel);
        add(aux_reg_kernel, jcp.ch_block*jcp.typesize_in);

        load_src(ur_w);
        apply_filter(ur_w, kw);
        apply_postprocessing(ur_w, oc_step);
        store_dst(ur_w, oc_step);

        add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
        add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
        add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
        add(reg_output, jcp.typesize_out * ur_w * output_step);

        sub(reg_ur_w, ur_w);
    }

    L(unrolled_w_label); {
        int ur_w = jcp.ur_w;
        int kw = jcp.kw;

        cmp(reg_ur_w, ur_w);
        jle(tail_w_label, T_NEAR);

        mov(aux_reg_input0, reg_input0);
        mov(aux_reg_input1, reg_input1);
        mov(aux_reg_input2, reg_input2);
        mov(aux_reg_kernel, reg_kernel);

        load_src(ur_w);
        apply_filter(ur_w, kw);
        apply_postprocessing(ur_w, oc_step);
        store_dst(ur_w, oc_step);

        add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_output, jcp.typesize_out * ur_w * output_step);

        sub(reg_ur_w, ur_w);
        jmp(unrolled_w_label, T_NEAR);
    }

    L(tail_w_label); {
        int ur_w = 1;
        int kw = jcp.kw;

        cmp(reg_ur_w, ur_w);
        if (jcp.ow > 1)
            jle(right_pad_label, T_NEAR);
        else
            jle(exit_label, T_NEAR);

        mov(aux_reg_input0, reg_input0);
        mov(aux_reg_input1, reg_input1);
        mov(aux_reg_input2, reg_input2);
        mov(aux_reg_kernel, reg_kernel);

        load_src(ur_w);
        apply_filter(ur_w, kw);
        apply_postprocessing(ur_w, oc_step);
        store_dst(ur_w, oc_step);

        add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_output, jcp.typesize_out * ur_w * output_step);

        sub(reg_ur_w, ur_w);
        jmp(tail_w_label, T_NEAR);
    }

    if (jcp.ow > 1) {
        L(right_pad_label); {
            int ur_w = 1;
            int kw = jcp.kw - ((jcp.stride_w == 1) ? 1 : jcp.iw % jcp.stride_w);

            mov(aux_reg_input0, reg_input0);
            mov(aux_reg_input1, reg_input1);
            mov(aux_reg_input2, reg_input2);
            mov(aux_reg_kernel, reg_kernel);

            load_src(ur_w);
            apply_filter(ur_w, kw);
            apply_postprocessing(ur_w, oc_step);
            store_dst(ur_w, oc_step);

            sub(reg_ur_w, ur_w);
        }
    }

    L(exit_label);
}

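/* Builds the kernel: instantiates injectors for the post-ops that follow
 * the fused convolution entry, emits a full-channel-block body and, if
 * the channel count has a remainder, a separate tail body. */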
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::generate() {
    const auto &p = attr_.post_ops_;
    int start_idx = p.find(primitive_kind::convolution) + 1;
    for (int i = start_idx; i < p.len_; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta
            ));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
                    this,
                    post_op.depthwise.alg
            ));
        }
    }

    this->preamble();

    mov(reg_input0, ptr[this->param1 + GET_OFF_DW(src_row0)]);
    mov(reg_input1, ptr[this->param1 + GET_OFF_DW(src_row1)]);
    mov(reg_input2, ptr[this->param1 + GET_OFF_DW(src_row2)]);
    mov(reg_output, ptr[this->param1 + GET_OFF_DW(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF_DW(filt)]);
    if (jcp.with_bias)
        mov(reg_bias, ptr[this->param1 + GET_OFF_DW(bias)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF_DW(kh_padding)]);
    mov(reg_ur_w, ptr[this->param1 + GET_OFF_DW(ur_w)]);
    mov(reg_oc_work, ptr[this->param1 + GET_OFF_DW(oc_work)]);
    mov(reg_oc_off, ptr[this->param1 + GET_OFF_DW(oc_off)]);

    Label tail_label;
    Label exit_label;

    cmp(reg_oc_work, jcp.ch_block);
    jl(tail_label, T_NEAR);

    loop_body(jcp.ch_block);
    jmp(exit_label, T_NEAR);

    L(tail_label);

    if (jcp.oc % jcp.ch_block != 0)
        loop_body(jcp.oc % jcp.ch_block);

    L(exit_label);

    this->postamble();

    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
}

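/* A post-op chain after the fused convolution may contain sum, eltwise,
 * depthwise and binarization entries: at most one sum (which must come
 * first) and at most one binarization (which must come last and cannot
 * be combined with sum). */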
template <cpu_isa_t isa>
bool jit_uni_dw_conv_row_f32<isa>::post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    int start_idx = p.find(primitive_kind::convolution) + 1;

    auto all_post_ops_supported = [&]() {
        bool ok = true;

        for (int i = start_idx; i < p.len_; i++) {
            ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise,
                                                       primitive_kind::binarization);
        }
        return ok;
    };
    auto contain = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind, start_idx, -1) != -1; };
    auto position = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind, start_idx, -1); };
    auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind, start_idx, -1); };

    return all_post_ops_supported() &&
           count(primitive_kind::sum) <= 1 &&
           count(primitive_kind::binarization) <= 1 &&
           IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == start_idx) &&
           IMPLICATION(contain(primitive_kind::binarization), position(primitive_kind::binarization) == p.len_-1) &&
           IMPLICATION(contain(primitive_kind::binarization), !contain(primitive_kind::sum));
}

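/* The three init_conf() overloads below fill jcp_dw from the dw_conv
 * post-op entry of the preceding convolution (jit_1x1_conv_conf_t,
 * jit_conv_conf_t, or jit_bin_conv_conf_t); only 3x3 kernels and
 * f32/u8 inputs are supported. */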
template <cpu_isa_t isa>
status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_1x1_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
        const primitive_attr_t &attr) {
    if (!mayiuse(isa)) return status::unimplemented;
    const int simd_w = isa == avx512_common ? 16 : 8;

    const auto &p = attr.post_ops_;

    int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;

    jcp_dw.ch_block = simd_w;
    jcp_dw.with_bias = true;

    jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
    jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
    jcp_dw.ic = jcp.oc;
    jcp_dw.oc = jcp.oc;
    jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
    jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
    jcp_dw.oh = jcp.dw_conv_oh;
    jcp_dw.ow = jcp.dw_conv_ow;
    jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
    jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
    jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
    jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;

    if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
        return status::unimplemented;

    if (!post_ops_ok(jcp_dw, attr))
        return status::unimplemented;

    jcp_dw.ur_w = 4;

    jcp_dw.src_dt = jcp.dst_dt;
    jcp_dw.dst_dt = jcp.dw_conv_dst_dt;
    jcp_dw.bia_dt = jcp.bia_dt;
    jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt);
    jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt);
    jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt);

    if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
        return status::unimplemented;

    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
        const primitive_attr_t &attr) {
    if (!mayiuse(isa)) return status::unimplemented;
    const int simd_w = isa == avx512_common ? 16 : 8;

    const auto &p = attr.post_ops_;

    int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;

    jcp_dw.ch_block = simd_w;
    jcp_dw.with_bias = true;

    jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
    jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
    jcp_dw.ic = jcp.oc;
    jcp_dw.oc = jcp.oc;
    jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
    jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
    jcp_dw.oh = jcp.dw_conv_oh;
    jcp_dw.ow = jcp.dw_conv_ow;
    jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
    jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
    jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
    jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;

    if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
        return status::unimplemented;

    if (!post_ops_ok(jcp_dw, attr))
        return status::unimplemented;

    jcp_dw.ur_w = 4;

    jcp_dw.src_dt = jcp.dst_dt;
    jcp_dw.dst_dt = jcp.dw_conv_dst_dt;
    jcp_dw.bia_dt = jcp.bia_dt;
    jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt);
    jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt);
    jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt);

    if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
        return status::unimplemented;

    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_bin_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
        const primitive_attr_t &attr) {
    if (!mayiuse(isa)) return status::unimplemented;
    const int simd_w = isa == avx512_common ? 16 : 8;

    const auto &p = attr.post_ops_;

    int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
    jcp_dw.with_binarization = p.find(primitive_kind::binarization, dw_conv_ind) != -1;

    jcp_dw.ch_block = simd_w;
    jcp_dw.with_bias = true;

    jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
    jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
    jcp_dw.ic = jcp.oc;
    jcp_dw.oc = jcp.oc;
    jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
    jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
    jcp_dw.oh = jcp.dw_conv_oh;
    jcp_dw.ow = jcp.dw_conv_ow;
    jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
    jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
    jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
    jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;

    if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
        return status::unimplemented;

    if (!post_ops_ok(jcp_dw, attr))
        return status::unimplemented;

    jcp_dw.ur_w = 4;

    jcp_dw.src_dt = mkldnn_f32;
    jcp_dw.dst_dt = jcp_dw.with_binarization ? mkldnn_bin : mkldnn_f32;
    jcp_dw.bia_dt = mkldnn_f32;
    jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt);
    jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt);
    jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt);

    if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
        return status::unimplemented;

    return status::success;
}

template struct jit_uni_dw_conv_row_f32<avx512_common>;
template struct jit_uni_dw_conv_row_f32<avx2>;
template struct jit_uni_dw_conv_row_f32<sse42>;

}
}
}