c09dbe815b8ce9dfaaa9199156c691eb035defae
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_depthwise.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <mkldnn_types.h>
18 #include "mkldnn_types.h"
19 #include "mkldnn_thread.hpp"
20 #include "nstl.hpp"
21 #include "utils.hpp"
22 #include "jit_generator.hpp"
23
24 #include "jit_uni_depthwise.hpp"
25
26 #define GET_OFF(field) offsetof(jit_args, field)
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace Xbyak;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::utils;
35
36 struct jit_args {
37     const float *from;
38     const float *to;
39     const float *weights;
40     const float *bias;
41     size_t work_amount;
42 };
43
44 struct jit_uni_depthwise_kernel_f32 : public c_compatible {
45     const depthwise_desc_t &desc_;
46     void (*ker_)(const jit_args *);
47     bool with_bias_;
48
49     void operator()(const jit_args *args) { assert(ker_); ker_(args); }
50
51     jit_uni_depthwise_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
52         : desc_(desc), ker_(nullptr), with_bias_(with_bias) {}
53     virtual ~jit_uni_depthwise_kernel_f32() {}
54 };
55
56 template <cpu_isa_t isa>
57 int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg) {
58     switch (depthwise_alg) {
59         case alg_kind::depthwise_scale_shift: return 0;
60         case alg_kind::depthwise_prelu: return 2;
61         default: assert(!"unsupported depthwise algorithm");
62     }
63
64     return 0;
65 }
66
67 template <cpu_isa_t isa>
68 void jit_uni_depthwise_injector_f32<isa>::injector_preamble(size_t start_idx, size_t end_idx) {
69     preserved_vecs_count = 0;
70     vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(depthwise_alg);
71
72     for (size_t i = 0; i < vecs_count; i++) {
73         if (preserved_vecs_count >= vecs_to_preserve)
74             break;
75
76         if (i < start_idx || i >= end_idx) {
77             preserved_vec_idxs[preserved_vecs_count] = i;
78             preserved_vecs_count++;
79         }
80     }
81
82     start_idx_tail = start_idx;
83     size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
84     for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
85         preserved_vec_idxs[preserved_vecs_count] = start_idx + i;
86         preserved_vecs_count++;
87         start_idx_tail = start_idx + i + 1;
88     }
89
90     h->sub(h->rsp, preserved_vecs_count * vlen);
91     for (size_t i = 0; i < preserved_vecs_count; ++i)
92         h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i]));
93
94     assign_regs();
95 }
96
97 template <cpu_isa_t isa>
98 void jit_uni_depthwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx, size_t end_idx) {
99     size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
100     int idx_off = (vecs_to_preserve - tail_vecs_to_preserve);
101
102     if (tail_vecs_to_preserve > 0) {
103         h->add(h->rsp, idx_off * vlen);
104         for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
105             h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), h->ptr[h->rsp + i * vlen]);
106
107         for (size_t i = 0; i < tail_vecs_to_preserve; ++i) {
108             preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
109         }
110
111         for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
112             h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[idx_off + i]));
113         h->sub(h->rsp, idx_off * vlen);
114
115         assign_regs();
116     }
117 }
118
119 template <cpu_isa_t isa>
120 void jit_uni_depthwise_injector_f32<isa>::injector_postamble() {
121     for (size_t i = 0; i < preserved_vecs_count; ++i)
122         h->uni_vmovups(Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]);
123     h->add(h->rsp, preserved_vecs_count * vlen);
124 }
125
126 template <cpu_isa_t isa>
127 void jit_uni_depthwise_injector_f32<isa>::assign_regs() {
128     vmm_mask = Vmm(preserved_vec_idxs[0]);
129     vmm_aux0 = Vmm(preserved_vec_idxs[1]);
130 }
131
132 template <cpu_isa_t isa>
133 void jit_uni_depthwise_injector_f32<isa>::scale_shift_compute_vector(const Vmm &vmm_src,
134         const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
135     h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights]);
136     h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias]);
137 }
138
139 template <cpu_isa_t isa>
140 void jit_uni_depthwise_injector_f32<isa>::prelu_compute_vector(const Vmm &vmm_src,
141         const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
142     const unsigned char _cmp_gt_os = 6;
143     const unsigned char _cmp_lt_os = 1;
144
145     if (isa == sse42) {
146         h->pxor(vmm_mask, vmm_mask);
147         h->cmpps(vmm_mask, vmm_src, _cmp_gt_os);
148         h->movups(vmm_aux0, vmm_src);
149         h->mulps(vmm_aux0, h->ptr[p_weights]);
150         h->blendvps(vmm_src, vmm_aux0);
151     } else if (isa == avx2) {
152         h->vxorps(vmm_mask, vmm_mask, vmm_mask);
153         h->vcmpgtps(vmm_mask, vmm_src, vmm_mask);
154         h->vmulps(vmm_aux0, vmm_src, h->ptr[p_weights]);
155         h->vblendvps(vmm_src, vmm_aux0, vmm_src, vmm_mask);
156     } else if (isa == avx512_common) {
157         h->vxorpd(vmm_mask, vmm_mask, vmm_mask);
158         h->vmovups(vmm_aux0, vmm_src);
159         h->vcmpps(k_mask, vmm_src, vmm_mask, _cmp_lt_os);
160         h->vmulps(vmm_src | k_mask, vmm_aux0, h->ptr[p_weights]);
161     }
162 }
163
164 template <cpu_isa_t isa>
165 void jit_uni_depthwise_injector_f32<isa>::compute_body(size_t start_idx, size_t end_idx,
166         const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
167     for (size_t idx = start_idx; idx < end_idx; idx++) {
168         switch (depthwise_alg) {
169             case alg_kind::depthwise_scale_shift:
170                 scale_shift_compute_vector(Vmm(idx), p_weights, p_bias); break;
171             case alg_kind::depthwise_prelu:
172                 prelu_compute_vector(Vmm(idx), p_weights, p_bias); break;
173             default: assert(!"unsupported depthwise algorithm");
174         }
175     }
176 }
177
178 template <cpu_isa_t isa>
179 void jit_uni_depthwise_injector_f32<isa>::compute_vector_range(int start_idx, int end_idx,
180         const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
181     injector_preamble(start_idx, end_idx);
182     compute_body(start_idx_tail, end_idx, p_weights, p_bias);
183     injector_preamble_tail(start_idx, end_idx);
184     compute_body(start_idx, start_idx_tail, p_weights, p_bias);
185     injector_postamble();
186 }
187
188 template struct jit_uni_depthwise_injector_f32<avx512_common>;
189 template struct jit_uni_depthwise_injector_f32<avx2>;
190 template struct jit_uni_depthwise_injector_f32<sse42>;
191
192 /* jit kernels */
193 namespace {
194
195 template <cpu_isa_t isa>
196 struct jit_uni_scale_shift_kernel_f32 : public jit_uni_depthwise_kernel_f32,
197     public jit_generator
198 {
199     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_scale_shift_kernel_f32)
200     jit_uni_scale_shift_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
201         : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
202         assert(desc.alg_kind == alg_kind::depthwise_scale_shift);
203         assert(isa == sse42 || isa == avx2 || isa == avx512_common);
204
205         bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw ;
206
207         Reg64 param = abi_param1;
208
209         const int block_size = isa == avx512_common ? 16 : 8;
210         const int main_loop_step = isFlat ? block_size : 1;
211
212         this->preamble();
213
214         mov(reg_from, ptr[param + GET_OFF(from)]);
215         mov(reg_to, ptr[param + GET_OFF(to)]);
216         mov(reg_scale, ptr[param + GET_OFF(weights)]);
217         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
218         if (with_bias_)
219             mov(reg_shift, ptr[param + GET_OFF(bias)]);
220
221         Label main_loop_label;
222         Label tail_loop_label;
223         Label exit_label;
224
225         int repeats = isa == sse42 ? 2 : 1;
226         for (int i = 0; i < repeats; i++) {
227             if (isFlat) {
228                 uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
229                 if (with_bias_)
230                     uni_vbroadcastss(get_shift_reg(i), ptr[reg_shift]);
231                 else
232                     uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
233             } else {
234                 uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
235                 if (with_bias_)
236                     uni_vmovups(get_shift_reg(i), ptr[reg_shift + i*4*sizeof(float)]);
237                 else
238                     uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
239             }
240         }
241
242         if (isFlat) {
243             uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
244             if (with_bias_)
245                 uni_vbroadcastss(xmm_shift, ptr[reg_shift]);
246             else
247                 uni_vpxor(xmm_shift, xmm_shift, xmm_shift);
248         }
249
250         L(main_loop_label); {
251             cmp(reg_work_amount, main_loop_step-1);
252             jle(tail_loop_label, T_NEAR);
253
254             int repeats = isa == sse42 ? 2 : 1;
255             for (int i = 0; i < repeats; i++) {
256                 uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);
257                 uni_vmovups(vmm_dst, get_shift_reg(i));
258                 uni_vfmadd231ps(vmm_dst, vmm_src, get_scale_reg(i));
259                 uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
260             }
261
262             add(reg_from, block_size*sizeof(float));
263             add(reg_to, block_size*sizeof(float));
264             sub(reg_work_amount, main_loop_step);
265
266             jmp(main_loop_label, T_NEAR);
267         }
268
269         L(tail_loop_label); {
270             cmp(reg_work_amount, 0);
271             jle(exit_label, T_NEAR);
272
273             movss(xmm_src, ptr[reg_from]);
274             uni_vmovups(xmm_dst, xmm_shift);
275             uni_vfmadd231ps(xmm_dst, xmm_src, xmm_scale);
276             movss(ptr[reg_to], xmm_dst);
277
278             add(reg_from, 1*sizeof(float));
279             add(reg_to, 1*sizeof(float));
280             dec(reg_work_amount);
281
282             jmp(tail_loop_label, T_NEAR);
283         }
284
285         L(exit_label);
286
287         this->postamble();
288
289         ker_ = (decltype(ker_))this->getCode();
290     }
291
292 private:
293     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
294                                              isa == avx2, Ymm, Zmm>::type;
295
296     inline Vmm get_scale_reg(int idx) { return Vmm(idx + 2); }
297     inline Vmm get_shift_reg(int idx) { return Vmm(idx + 4); }
298
299     Reg64 reg_from = r8;
300     Reg64 reg_to = r9;
301     Reg64 reg_work_amount = r10;
302     Reg64 reg_scale = r11;
303     Reg64 reg_shift = r12;
304
305     Vmm vmm_src = Vmm(0);
306     Vmm vmm_dst = Vmm(1);
307
308     Xmm xmm_src = Xmm(0);
309     Xmm xmm_dst = Xmm(1);
310     Xmm xmm_scale = Xmm(6);
311     Xmm xmm_shift = Xmm(7);
312 };
313
314 template <cpu_isa_t isa>
315 struct jit_uni_prelu_kernel_f32 : public jit_uni_depthwise_kernel_f32,
316     public jit_generator
317 {
318     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_prelu_kernel_f32)
319     jit_uni_prelu_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
320         : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
321         assert(desc.alg_kind == alg_kind::depthwise_prelu);
322         assert(isa == sse42 || isa == avx2 || isa == avx512_common);
323
324         bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw;
325
326         Reg64 param = abi_param1;
327
328         const int block_size = isa == avx512_common ? 16 : 8;
329         const int main_loop_step = isFlat ? block_size : 1;
330
331         this->preamble();
332
333         mov(reg_from, ptr[param + GET_OFF(from)]);
334         mov(reg_to, ptr[param + GET_OFF(to)]);
335         mov(reg_scale, ptr[param + GET_OFF(weights)]);
336         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
337
338         uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
339
340         int repeats = isa == sse42 ? 2 : 1;
341         for (int i = 0; i < repeats; i++) {
342             if (isFlat) {
343                 uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
344             } else {
345                 uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
346             }
347         }
348
349         if (isFlat) {
350             uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
351         }
352
353         Label main_loop_label;
354         Label tail_loop_label;
355         Label exit_label;
356
357         L(main_loop_label); {
358             cmp(reg_work_amount, main_loop_step-1);
359             jle(tail_loop_label, T_NEAR);
360
361             for (int i = 0; i < repeats; i++) {
362                 uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);
363
364                 if (isa == sse42) {
365                     pxor(vmm_mask, vmm_mask);
366                     cmpps(vmm_mask, vmm_src, _cmp_gt_os);
367                     movups(vmm_dst, vmm_src);
368                     mulps(vmm_src, get_scale_reg(i));
369                     blendvps(vmm_dst, vmm_src);
370                 } else if (isa == avx2) {
371                     vcmpgtps(vmm_mask, vmm_src, vmm_zero);
372                     vmulps(vmm_dst, vmm_src, get_scale_reg(i));
373                     vblendvps(vmm_dst, vmm_dst, vmm_src, vmm_mask);
374                 } else if (isa == avx512_common) {
375                     Opmask kmask = Opmask(7);
376                     vmovups(vmm_dst, vmm_src);
377                     vcmpps(kmask, vmm_src, vmm_zero, _cmp_lt_os);
378                     vmulps(vmm_dst | kmask, vmm_src, get_scale_reg(i));
379                 }
380
381                 uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
382             }
383
384             add(reg_from, block_size*sizeof(float));
385             add(reg_to, block_size*sizeof(float));
386             sub(reg_work_amount, main_loop_step);
387
388             jmp(main_loop_label, T_NEAR);
389         }
390
391         L(tail_loop_label); {
392             cmp(reg_work_amount, 0);
393             jle(exit_label, T_NEAR);
394
395             movss(xmm_src, ptr[reg_from]);
396
397             pxor(xmm_mask, xmm_mask);
398             cmpps(xmm_mask, xmm_src, _cmp_gt_os);
399             movups(xmm_dst, xmm_src);
400             mulps(xmm_src, xmm_scale);
401             blendvps(xmm_dst, xmm_src);
402
403             movss(ptr[reg_to], xmm_dst);
404
405             add(reg_from, 1*sizeof(float));
406             add(reg_to, 1*sizeof(float));
407             dec(reg_work_amount);
408
409             jmp(tail_loop_label, T_NEAR);
410         }
411
412         L(exit_label);
413
414         this->postamble();
415
416         ker_ = (decltype(ker_))this->getCode();
417     }
418
419 private:
420     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
421                                              isa == avx2, Ymm, Zmm>::type;
422
423     inline Vmm get_scale_reg(int idx) { return Vmm(idx + 4); }
424
425     Reg64 reg_from = r8;
426     Reg64 reg_to = r9;
427     Reg64 reg_work_amount = r10;
428     Reg64 reg_scale = r11;
429
430     Vmm vmm_mask = Vmm(0);
431     Vmm vmm_src = Vmm(1);
432     Vmm vmm_zero = Vmm(2);
433     Vmm vmm_dst = Vmm(3);
434
435     Xmm xmm_mask = Xmm(0);
436     Xmm xmm_src = Xmm(1);
437     Xmm xmm_dst = Xmm(3);
438     Xmm xmm_scale = Xmm(4);
439
440     const unsigned char _cmp_gt_os = 6;
441     const unsigned char _cmp_lt_os = 1;
442 };
443
444 } /* namespace */
445
446 template <cpu_isa_t isa>
447 status_t jit_uni_depthwise_fwd_t<isa>::pd_t::init() {
448     using namespace alg_kind;
449
450     auto desired_blk_fmt = isa == avx512_common ? nChw16c : nChw8c;
451
452     assert(engine()->kind() == engine_kind::cpu);
453     bool ok = true && mayiuse(isa)
454         && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
455                 prop_kind::forward_inference)
456         && utils::everyone_is(data_type::f32, desc()->src_desc.data_type, desc()->dst_desc.data_type)
457         && desc()->src_desc.format == desc()->dst_desc.format
458         && utils::one_of(desc()->src_desc.format, desired_blk_fmt, nchw)
459         && utils::one_of(desc()->dst_desc.format, desired_blk_fmt, nchw)
460         && utils::one_of(desc()->weights_desc.format, x)
461         && utils::implication(this->with_bias(), x == desc()->bias_desc.format)
462         && attr()->has_default_values();
463
464     return ok ? status::success : status::unimplemented;
465 }
466
467 template <cpu_isa_t isa>
468 jit_uni_depthwise_fwd_t<isa>::jit_uni_depthwise_fwd_t(const pd_t *pd,
469         const input_vector &inputs, const output_vector &outputs)
470     : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr),
471       padded_weights_(nullptr), padded_bias_(nullptr) {
472     const auto &desc = *conf_.desc();
473     switch (desc.alg_kind) {
474         case alg_kind::depthwise_scale_shift:
475             kernel_ = new jit_uni_scale_shift_kernel_f32<isa>(desc, pd->with_bias()); break;
476         case alg_kind::depthwise_prelu:
477             kernel_ = new jit_uni_prelu_kernel_f32<isa>(desc, pd->with_bias()); break;
478         default: assert(!"unknown depthwise alg_kind");
479     }
480
481     const int simd_w = isa == avx512_common ? 16 : 8;
482     const memory_desc_wrapper data_d(conf_.src_pd());
483     const int c_without_padding = data_d.dims()[1];
484     const int c_padded = rnd_up(c_without_padding, simd_w);
485
486     if (conf_.want_padded_weights()) {
487         padded_weights_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
488         for (int oc = c_without_padding; oc < c_padded; ++oc)
489             padded_weights_[oc] = 0;
490
491         if (conf_.with_bias()) {
492             padded_bias_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
493             for (int oc = c_without_padding; oc < c_padded; ++oc)
494                 padded_bias_[oc] = 0;
495         }
496     }
497 }
498
499 template <cpu_isa_t isa>
500 jit_uni_depthwise_fwd_t<isa>::~jit_uni_depthwise_fwd_t() {
501     delete kernel_;
502     free(padded_weights_);
503     free(padded_bias_);
504 }
505
506 template <cpu_isa_t isa>
507 void jit_uni_depthwise_fwd_t<isa>::execute_forward() {
508     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
509     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
510     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
511     auto dst = reinterpret_cast<data_t *>(this->memory());
512
513     const memory_desc_wrapper data_d(conf_.src_pd());
514     const memory_desc_wrapper weights_d(conf_.weights_pd(0));
515     const memory_desc_wrapper bias_d(conf_.weights_pd(1));
516
517     const int N = data_d.dims()[0];
518     const int C = data_d.dims()[1];
519     const int H = data_d.dims()[2];
520     const int W = data_d.dims()[3];
521
522     const int simd_w = isa == avx512_common ? 16 : 8;
523     const int ch_block_size = data_d.format() == nchw ? 1 : simd_w;
524     const int CB = div_up(C, ch_block_size);
525
526     if (conf_.want_padded_weights()) {
527         for (int oc = 0; oc < C; ++oc)
528             padded_weights_[oc] = weights[oc];
529         weights = padded_weights_;
530
531         if (conf_.with_bias()) {
532             for (int oc = 0; oc < C; ++oc)
533                 padded_bias_[oc] = bias[oc];
534             bias = padded_bias_;
535         }
536     }
537
538     parallel_nd(N, CB, H,
539         [&](int n, int cb, int h) {
540         jit_args arg = {};
541
542         arg.from    = &src[data_d.blk_off(n, cb, h)];
543         arg.to      = &dst[data_d.blk_off(n, cb, h)];
544         arg.weights = &weights[weights_d.blk_off(cb * ch_block_size)];
545         if (bias)
546             arg.bias = &bias[bias_d.blk_off(cb * ch_block_size)];
547         arg.work_amount = (size_t)W;
548
549         (*kernel_)(&arg);
550     });
551 }
552
553 template struct jit_uni_depthwise_fwd_t<sse42>;
554 template struct jit_uni_depthwise_fwd_t<avx2>;
555 template struct jit_uni_depthwise_fwd_t<avx512_common>;
556
557
558 #define GET_OFF_DW(field) offsetof(jit_conv_call_s, field)
559
560 template <cpu_isa_t isa>
561 void jit_uni_dw_conv_row_f32<isa>::load_src(int ur_w) {
562     int repeats = isa == sse42 ? 2 : 1;
563     for (int i = 0; i < repeats; i++) {
564         for (int ow = 0; ow < ur_w; ow++) {
565             Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
566
567             if (this->jcp.with_bias)
568                 uni_vmovups(vmm_acc, vmmword[reg_bias + i*4*sizeof(float)]);
569             else
570                 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
571
572             int o_off = ow*jcp.ch_block + i*4;
573             if (this->jcp.with_sum)
574                 uni_vaddps(vmm_acc, vmm_acc,
575                            vmmword[reg_output + o_off*sizeof(float)]);
576         }
577     }
578 }
579
580 template <cpu_isa_t isa>
581 void jit_uni_dw_conv_row_f32<isa>::apply_filter(int ur_w, int kw_size) {
582     int ch_blk = jcp.ch_block;
583     int stride_w = jcp.stride_w;
584
585     Label exit_label;
586
587     int repeats = isa == sse42 ? 2 : 1;
588
589     cmp(reg_kh, 1);
590     jl(exit_label, T_NEAR);
591     for (int i = 0; i < repeats; i++) {
592         for (int kw = 0; kw < kw_size; kw++) {
593             int ker_off = kw * ch_blk + i*4;
594
595             Vmm vmm_ker = get_ker_reg(0);
596             uni_vmovups(vmm_ker, ptr[aux_reg_kernel
597                                      + ker_off * sizeof(float)]);
598
599             for (int ow = 0; ow < ur_w; ow++) {
600                 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
601
602                 Vmm vmm_src = get_src_reg(0);
603                 uni_vmovups(vmm_src, ptr[aux_reg_input0
604                                          + inp_off * sizeof(float)]);
605
606                 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
607                 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
608             }
609         }
610     }
611     add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
612
613     cmp(reg_kh, 2);
614     jl(exit_label, T_NEAR);
615     for (int i = 0; i < repeats; i++) {
616         for (int kw = 0; kw < kw_size; kw++) {
617             int ker_off = kw * ch_blk + i*4;
618
619             Vmm vmm_ker = get_ker_reg(0);
620             uni_vmovups(vmm_ker, ptr[aux_reg_kernel
621                                      + ker_off * sizeof(float)]);
622
623             for (int ow = 0; ow < ur_w; ow++) {
624                 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
625
626                 Vmm vmm_src = get_src_reg(0);
627                 uni_vmovups(vmm_src, ptr[aux_reg_input1
628                                          + inp_off * sizeof(float)]);
629
630                 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
631                 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
632             }
633         }
634     }
635     add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
636
637     cmp(reg_kh, 3);
638     jl(exit_label, T_NEAR);
639     for (int i = 0; i < repeats; i++) {
640         for (int kw = 0; kw < kw_size; kw++) {
641             int ker_off = kw * ch_blk + i*4;
642
643             Vmm vmm_ker = get_ker_reg(0);
644             uni_vmovups(vmm_ker, ptr[aux_reg_kernel
645                                      + ker_off * sizeof(float)]);
646
647             for (int ow = 0; ow < ur_w; ow++) {
648                 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
649
650                 Vmm vmm_src = get_src_reg(0);
651                 uni_vmovups(vmm_src, ptr[aux_reg_input2
652                                          + inp_off * sizeof(float)]);
653
654                 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
655                 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
656             }
657         }
658     }
659
660     L(exit_label);
661 }
662
663 template <cpu_isa_t isa>
664 void jit_uni_dw_conv_row_f32<isa>::apply_activation(int ur_w) {
665     if (this->jcp.with_eltwise) {
666         int repeats = isa == sse42 ? 2 : 1;
667         eltwise_injector->compute_vector_range(4, repeats * ur_w + 4);
668     }
669 }
670
671 template <cpu_isa_t isa>
672 void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w) {
673     int repeats = isa == sse42 ? 2 : 1;
674     for (int i = 0; i < repeats; i++) {
675         for (int ow = 0; ow < ur_w; ow++) {
676             int o_off = ow*jcp.ch_block + i*4;
677             Vmm vmm_dst = get_acc_reg(i*ur_w + ow);
678
679             uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst);
680         }
681     }
682 }
683
684 template <cpu_isa_t isa>
685 void jit_uni_dw_conv_row_f32<isa>::loop_body() {
686     Label left_pad_label;
687     Label right_pad_label;
688     Label unrolled_w_label;
689     Label tail_w_label;
690     Label exit_label;
691
692     L(left_pad_label); {
693         int ur_w = 1;
694         int kw = jcp.iw == 1 ? jcp.kw - 2 : jcp.kw - 1;
695
696         mov(aux_reg_input0, reg_input0);
697         mov(aux_reg_input1, reg_input1);
698         mov(aux_reg_input2, reg_input2);
699         mov(aux_reg_kernel, reg_kernel);
700         add(aux_reg_kernel, jcp.ch_block*sizeof(float));
701
702         load_src(ur_w);
703         apply_filter(ur_w, kw);
704         apply_activation(ur_w);
705         store_dst(ur_w);
706
707         add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
708         add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
709         add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
710
711         add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
712
713         sub(reg_ur_w, ur_w);
714     }
715
716     L(unrolled_w_label); {
717         int ur_w = jcp.ur_w;
718         int kw = jcp.kw;
719
720         cmp(reg_ur_w, ur_w);
721         jle(tail_w_label, T_NEAR);
722
723         mov(aux_reg_input0, reg_input0);
724         mov(aux_reg_input1, reg_input1);
725         mov(aux_reg_input2, reg_input2);
726         mov(aux_reg_kernel, reg_kernel);
727
728         load_src(ur_w);
729         apply_filter(ur_w, kw);
730         apply_activation(ur_w);
731         store_dst(ur_w);
732
733         add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
734         add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
735         add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
736         add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
737
738         sub(reg_ur_w, ur_w);
739         jmp(unrolled_w_label, T_NEAR);
740     }
741
742     L(tail_w_label); {
743         int ur_w = 1;
744         int kw = jcp.kw;
745
746         cmp(reg_ur_w, ur_w);
747         if (jcp.ow > 1)
748             jle(right_pad_label, T_NEAR);
749         else
750             jle(exit_label, T_NEAR);
751
752         mov(aux_reg_input0, reg_input0);
753         mov(aux_reg_input1, reg_input1);
754         mov(aux_reg_input2, reg_input2);
755         mov(aux_reg_kernel, reg_kernel);
756
757         load_src(ur_w);
758         apply_filter(ur_w, kw);
759         apply_activation(ur_w);
760         store_dst(ur_w);
761
762         add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
763         add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
764         add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
765         add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
766
767         sub(reg_ur_w, ur_w);
768         jmp(tail_w_label, T_NEAR);
769     }
770
771     if (jcp.ow > 1) {
772         L(right_pad_label); {
773             int ur_w = 1;
774             int kw = jcp.kw - ((jcp.stride_w == 1) ? 1 : jcp.iw % jcp.stride_w);
775
776             mov(aux_reg_input0, reg_input0);
777             mov(aux_reg_input1, reg_input1);
778             mov(aux_reg_input2, reg_input2);
779             mov(aux_reg_kernel, reg_kernel);
780
781             load_src(ur_w);
782             apply_filter(ur_w, kw);
783             apply_activation(ur_w);
784             store_dst(ur_w);
785
786             sub(reg_ur_w, ur_w);
787         }
788     }
789
790     L(exit_label);
791 }
792
793 template <cpu_isa_t isa>
794 void jit_uni_dw_conv_row_f32<isa>::generate()
795 {
796     this->preamble();
797
798     mov(reg_input0, ptr[this->param1 + GET_OFF_DW(src_row0)]);
799     mov(reg_input1, ptr[this->param1 + GET_OFF_DW(src_row1)]);
800     mov(reg_input2, ptr[this->param1 + GET_OFF_DW(src_row2)]);
801     mov(reg_output, ptr[this->param1 + GET_OFF_DW(dst)]);
802     mov(reg_kernel, ptr[this->param1 + GET_OFF_DW(filt)]);
803     if (jcp.with_bias)
804         mov(reg_bias, ptr[this->param1 + GET_OFF_DW(bias)]);
805     mov(reg_kh, ptr[this->param1 + GET_OFF_DW(kh_padding)]);
806     mov(reg_ur_w, ptr[this->param1 + GET_OFF_DW(ur_w)]);
807
808     loop_body();
809
810     this->postamble();
811
812     if (jcp.with_eltwise)
813         eltwise_injector->prepare_table();
814 }
815
816 template <cpu_isa_t isa>
817 status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_conv_conf_t &jcp,
818         int ic, int ih, int iw, int oh, int ow, int ker_h, int ker_w, int str_h, int str_w, alg_kind_t eltwise_alg,
819         float eltwise_alpha, float eltwise_beta, bool with_sum) {
820     if (!mayiuse(isa)) return status::unimplemented;
821     const int simd_w = isa == avx512_common ? 16 : 8;
822
823     jcp.kh = ker_h;
824     jcp.kw = ker_w;
825     jcp.ch_block = simd_w;
826     jcp.with_bias = true;
827     jcp.ic = ic;
828     jcp.oc = ic;
829     jcp.ih = ih;
830     jcp.iw = iw;
831     jcp.oh = oh;
832     jcp.ow = ow;
833     jcp.stride_h = str_h;
834     jcp.stride_w = str_w;
835
836     if (jcp.kh != 3 || jcp.kw != 3)
837         return  status::unimplemented;
838
839     jcp.ur_w = 4;
840
841     jcp.with_eltwise  = eltwise_alg != mkldnn_alg_kind_undef;
842     jcp.eltwise_alg   = eltwise_alg;
843     jcp.eltwise_alpha = eltwise_alpha;
844     jcp.eltwise_beta  = eltwise_beta;
845     jcp.with_sum = with_sum;
846
847     return status::success;
848 }
849
850 template struct jit_uni_dw_conv_row_f32<avx512_common>;
851 template struct jit_uni_dw_conv_row_f32<avx2>;
852 template struct jit_uni_dw_conv_row_f32<sse42>;
853
854 }
855 }
856 }