/*******************************************************************************
* Copyright 2018-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "mkldnn_types.h"
#include "mkldnn_thread.hpp"
#include "nstl.hpp"
#include "utils.hpp"
#include "jit_generator.hpp"

#include "jit_uni_depthwise.hpp"

#define GET_OFF(field) offsetof(jit_args, field)

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace Xbyak;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;

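// Argument block passed to the generated depthwise kernels; field offsets are
// taken with GET_OFF, so this layout must match what the JIT code reads.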
struct jit_args {
    const float *from;
    const float *to;
    const float *weights;
    const float *bias;
    size_t work_amount;
};

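// Common base for the generated kernels: holds the descriptor and the pointer
// to the JIT-ed entry point that operator() dispatches to.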
struct jit_uni_depthwise_kernel_f32 : public c_compatible {
    const depthwise_desc_t &desc_;
    void (*ker_)(const jit_args *);
    bool with_bias_;

    void operator()(const jit_args *args) { assert(ker_); ker_(args); }

    jit_uni_depthwise_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
        : desc_(desc), ker_(nullptr), with_bias_(with_bias) {}
    virtual ~jit_uni_depthwise_kernel_f32() {}
};

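// Number of auxiliary vector registers the injector needs per algorithm:
// scale_shift needs a scratch register only on SSE4.2 (two-operand forms
// clobber a source), prelu needs a mask register and a temporary.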
template <cpu_isa_t isa>
int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg) {
    switch (depthwise_alg) {
        case alg_kind::depthwise_scale_shift: return isa == sse42 ? 1 : 0;
        case alg_kind::depthwise_prelu: return 2;
        default: assert(!"unsupported depthwise algorithm");
    }

    return 0;
}

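// Pick the auxiliary registers from outside the caller's [start_idx, end_idx)
// range when possible, fall back to the head of that range otherwise, and
// spill every chosen register to the stack so it can be restored later.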
template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::injector_preamble(size_t start_idx, size_t end_idx) {
    preserved_vecs_count = 0;
    vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(depthwise_alg);

    for (size_t i = 0; i < vecs_count; i++) {
        if (preserved_vecs_count >= vecs_to_preserve)
            break;

        if (i < start_idx || i >= end_idx) {
            preserved_vec_idxs[preserved_vecs_count] = i;
            preserved_vecs_count++;
        }
    }

    start_idx_tail = start_idx;
    size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
    for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
        preserved_vec_idxs[preserved_vecs_count] = start_idx + i;
        preserved_vecs_count++;
        start_idx_tail = start_idx + i + 1;
    }

    h->sub(h->rsp, preserved_vecs_count * vlen);
    for (size_t i = 0; i < preserved_vecs_count; ++i)
        h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i]));

    assign_regs();
}

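// If some auxiliaries had to be taken from inside the caller's range, restore
// them, rebase the auxiliary set onto vectors that were already processed in
// the first compute_body pass, and spill again for the remaining head.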
template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx, size_t end_idx) {
    size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
    int idx_off = (vecs_to_preserve - tail_vecs_to_preserve);

    if (tail_vecs_to_preserve > 0) {
        h->add(h->rsp, idx_off * vlen);
        for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
            h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), h->ptr[h->rsp + i * vlen]);

        for (size_t i = 0; i < tail_vecs_to_preserve; ++i) {
            preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
        }

        for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
            h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[idx_off + i]));
        h->sub(h->rsp, idx_off * vlen);

        assign_regs();
    }
}

template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::injector_postamble() {
    for (size_t i = 0; i < preserved_vecs_count; ++i)
        h->uni_vmovups(Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]);
    h->add(h->rsp, preserved_vecs_count * vlen);
}

template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::assign_regs() {
    vmm_mask = Vmm(preserved_vec_idxs[0]);
    vmm_aux0 = Vmm(preserved_vec_idxs[1]);
}

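// scale_shift: src = src * weights + bias. The SSE4.2 path stages the
// operands through vmm_mask because mulps/addps overwrite their first source.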
template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::scale_shift_compute_vector(const Vmm &vmm_src,
        const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
    if (isa == sse42) {
        h->movups(vmm_mask, h->ptr[p_weights]);
        h->mulps(vmm_src, vmm_mask);
        h->movups(vmm_mask, h->ptr[p_bias]);
        h->addps(vmm_src, vmm_mask);
    } else {
        h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights]);
        h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias]);
    }
}

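// prelu: src = src > 0 ? src : src * weights. The mask is built differently
// per ISA: cmpps/blendvps on SSE4.2, vblendvps on AVX2, and a k-mask with a
// masked vmulps on AVX-512.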
template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::prelu_compute_vector(const Vmm &vmm_src,
        const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
    const unsigned char _cmp_gt_os = 6;
    const unsigned char _cmp_lt_os = 1;

    if (isa == sse42) {
        h->pxor(vmm_mask, vmm_mask);
        h->cmpps(vmm_mask, vmm_src, _cmp_gt_os);
        h->movups(vmm_aux0, h->ptr[p_weights]);
        h->mulps(vmm_aux0, vmm_src);
        h->blendvps(vmm_src, vmm_aux0);
    } else if (isa == avx2) {
        h->vxorps(vmm_mask, vmm_mask, vmm_mask);
        h->vcmpgtps(vmm_mask, vmm_src, vmm_mask);
        h->vmulps(vmm_aux0, vmm_src, h->ptr[p_weights]);
        h->vblendvps(vmm_src, vmm_aux0, vmm_src, vmm_mask);
    } else if (isa == avx512_common) {
        h->vxorpd(vmm_mask, vmm_mask, vmm_mask);
        h->vmovups(vmm_aux0, vmm_src);
        h->vcmpps(k_mask, vmm_src, vmm_mask, _cmp_lt_os);
        h->vmulps(vmm_src | k_mask, vmm_aux0, h->ptr[p_weights]);
    }
}

template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::compute_body(size_t start_idx, size_t end_idx,
        const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
    for (size_t idx = start_idx; idx < end_idx; idx++) {
        switch (depthwise_alg) {
            case alg_kind::depthwise_scale_shift:
                scale_shift_compute_vector(Vmm(idx), p_weights, p_bias); break;
            case alg_kind::depthwise_prelu:
                prelu_compute_vector(Vmm(idx), p_weights, p_bias); break;
            default: assert(!"unsupported depthwise algorithm");
        }
    }
}

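// Two-pass scheme: first handle the vectors that kept their registers, then,
// after injector_preamble_tail reshuffles the auxiliaries, the head of the
// range that was borrowed for scratch space.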
template <cpu_isa_t isa>
void jit_uni_depthwise_injector_f32<isa>::compute_vector_range(int start_idx, int end_idx,
        const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
    injector_preamble(start_idx, end_idx);
    compute_body(start_idx_tail, end_idx, p_weights, p_bias);
    injector_preamble_tail(start_idx, end_idx);
    compute_body(start_idx, start_idx_tail, p_weights, p_bias);
    injector_postamble();
}

template struct jit_uni_depthwise_injector_f32<avx512_common>;
template struct jit_uni_depthwise_injector_f32<avx2>;
template struct jit_uni_depthwise_injector_f32<sse42>;

/* jit kernels */
namespace {

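// Scale/shift kernel: dst[i] = src[i] * scale + shift over work_amount
// elements. For the flat (nchw) layout a single scalar scale/shift pair is
// broadcast; for blocked layouts a per-channel vector is loaded. On SSE4.2
// two 4-wide repeats emulate one 8-wide channel block.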
template <cpu_isa_t isa>
struct jit_uni_scale_shift_kernel_f32 : public jit_uni_depthwise_kernel_f32,
    public jit_generator
{
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_scale_shift_kernel_f32)
    jit_uni_scale_shift_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
        : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
        assert(desc.alg_kind == alg_kind::depthwise_scale_shift);
        assert(isa == sse42 || isa == avx2 || isa == avx512_common);

        bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw;

        Reg64 param = abi_param1;

        const int block_size = isa == avx512_common ? 16 : 8;
        const int main_loop_step = isFlat ? block_size : 1;

        this->preamble();

        mov(reg_from, ptr[param + GET_OFF(from)]);
        mov(reg_to, ptr[param + GET_OFF(to)]);
        mov(reg_scale, ptr[param + GET_OFF(weights)]);
        mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
        if (with_bias_)
            mov(reg_shift, ptr[param + GET_OFF(bias)]);

        Label main_loop_label;
        Label tail_loop_label;
        Label exit_label;

        int repeats = isa == sse42 ? 2 : 1;
        for (int i = 0; i < repeats; i++) {
            if (isFlat) {
                uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
                if (with_bias_)
                    uni_vbroadcastss(get_shift_reg(i), ptr[reg_shift]);
                else
                    uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
            } else {
                uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
                if (with_bias_)
                    uni_vmovups(get_shift_reg(i), ptr[reg_shift + i*4*sizeof(float)]);
                else
                    uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
            }
        }

        if (isFlat) {
            uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
            if (with_bias_)
                uni_vbroadcastss(xmm_shift, ptr[reg_shift]);
            else
                uni_vpxor(xmm_shift, xmm_shift, xmm_shift);
        }

        L(main_loop_label); {
            cmp(reg_work_amount, main_loop_step-1);
            jle(tail_loop_label, T_NEAR);

            for (int i = 0; i < repeats; i++) {
                uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);
                uni_vmovups(vmm_dst, get_shift_reg(i));
                uni_vfmadd231ps(vmm_dst, vmm_src, get_scale_reg(i));
                uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
            }

            add(reg_from, block_size*sizeof(float));
            add(reg_to, block_size*sizeof(float));
            sub(reg_work_amount, main_loop_step);

            jmp(main_loop_label, T_NEAR);
        }

        L(tail_loop_label); {
            cmp(reg_work_amount, 0);
            jle(exit_label, T_NEAR);

            movss(xmm_src, ptr[reg_from]);
            uni_vmovups(xmm_dst, xmm_shift);
            uni_vfmadd231ps(xmm_dst, xmm_src, xmm_scale);
            movss(ptr[reg_to], xmm_dst);

            add(reg_from, 1*sizeof(float));
            add(reg_to, 1*sizeof(float));
            dec(reg_work_amount);

            jmp(tail_loop_label, T_NEAR);
        }

        L(exit_label);

        this->postamble();

        ker_ = (decltype(ker_))this->getCode();
    }

private:
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
                                             isa == avx2, Ymm, Zmm>::type;

    inline Vmm get_scale_reg(int idx) { return Vmm(idx + 2); }
    inline Vmm get_shift_reg(int idx) { return Vmm(idx + 4); }

    Reg64 reg_from = r8;
    Reg64 reg_to = r9;
    Reg64 reg_work_amount = r10;
    Reg64 reg_scale = r11;
    Reg64 reg_shift = r12;

    Vmm vmm_src = Vmm(0);
    Vmm vmm_dst = Vmm(1);

    Xmm xmm_src = Xmm(0);
    Xmm xmm_dst = Xmm(1);
    Xmm xmm_scale = Xmm(6);
    Xmm xmm_shift = Xmm(7);
};

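// PReLU kernel: dst[i] = src[i] > 0 ? src[i] : src[i] * scale, with the same
// flat/blocked handling of the per-channel scale as the kernel above.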
template <cpu_isa_t isa>
struct jit_uni_prelu_kernel_f32 : public jit_uni_depthwise_kernel_f32,
    public jit_generator
{
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_prelu_kernel_f32)
    jit_uni_prelu_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
        : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
        assert(desc.alg_kind == alg_kind::depthwise_prelu);
        assert(isa == sse42 || isa == avx2 || isa == avx512_common);

        bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw;

        Reg64 param = abi_param1;

        const int block_size = isa == avx512_common ? 16 : 8;
        const int main_loop_step = isFlat ? block_size : 1;

        this->preamble();

        mov(reg_from, ptr[param + GET_OFF(from)]);
        mov(reg_to, ptr[param + GET_OFF(to)]);
        mov(reg_scale, ptr[param + GET_OFF(weights)]);
        mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);

        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);

        int repeats = isa == sse42 ? 2 : 1;
        for (int i = 0; i < repeats; i++) {
            if (isFlat) {
                uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
            } else {
                uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
            }
        }

        if (isFlat) {
            uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
        }

        Label main_loop_label;
        Label tail_loop_label;
        Label exit_label;

        L(main_loop_label); {
            cmp(reg_work_amount, main_loop_step-1);
            jle(tail_loop_label, T_NEAR);

            for (int i = 0; i < repeats; i++) {
                uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);

                if (isa == sse42) {
                    pxor(vmm_mask, vmm_mask);
                    cmpps(vmm_mask, vmm_src, _cmp_gt_os);
                    movups(vmm_dst, vmm_src);
                    mulps(vmm_src, get_scale_reg(i));
                    blendvps(vmm_dst, vmm_src);
                } else if (isa == avx2) {
                    vcmpgtps(vmm_mask, vmm_src, vmm_zero);
                    vmulps(vmm_dst, vmm_src, get_scale_reg(i));
                    vblendvps(vmm_dst, vmm_dst, vmm_src, vmm_mask);
                } else if (isa == avx512_common) {
                    Opmask kmask = Opmask(7);
                    vmovups(vmm_dst, vmm_src);
                    vcmpps(kmask, vmm_src, vmm_zero, _cmp_lt_os);
                    vmulps(vmm_dst | kmask, vmm_src, get_scale_reg(i));
                }

                uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
            }

            add(reg_from, block_size*sizeof(float));
            add(reg_to, block_size*sizeof(float));
            sub(reg_work_amount, main_loop_step);

            jmp(main_loop_label, T_NEAR);
        }

        L(tail_loop_label); {
            cmp(reg_work_amount, 0);
            jle(exit_label, T_NEAR);

            movss(xmm_src, ptr[reg_from]);

            pxor(xmm_mask, xmm_mask);
            cmpps(xmm_mask, xmm_src, _cmp_gt_os);
            movups(xmm_dst, xmm_src);
            mulps(xmm_src, xmm_scale);
            blendvps(xmm_dst, xmm_src);

            movss(ptr[reg_to], xmm_dst);

            add(reg_from, 1*sizeof(float));
            add(reg_to, 1*sizeof(float));
            dec(reg_work_amount);

            jmp(tail_loop_label, T_NEAR);
        }

        L(exit_label);

        this->postamble();

        ker_ = (decltype(ker_))this->getCode();
    }

private:
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
                                             isa == avx2, Ymm, Zmm>::type;

    inline Vmm get_scale_reg(int idx) { return Vmm(idx + 4); }

    Reg64 reg_from = r8;
    Reg64 reg_to = r9;
    Reg64 reg_work_amount = r10;
    Reg64 reg_scale = r11;

    Vmm vmm_mask = Vmm(0);
    Vmm vmm_src = Vmm(1);
    Vmm vmm_zero = Vmm(2);
    Vmm vmm_dst = Vmm(3);

    Xmm xmm_mask = Xmm(0);
    Xmm xmm_src = Xmm(1);
    Xmm xmm_dst = Xmm(3);
    Xmm xmm_scale = Xmm(4);

    const unsigned char _cmp_gt_os = 6;
    const unsigned char _cmp_lt_os = 1;
};

} /* namespace */

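// The primitive accepts f32 forward propagation on nchw or the ISA's blocked
// layout (nChw8c/nChw16c), with x-format weights and bias and default
// attributes only.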
template <cpu_isa_t isa>
status_t jit_uni_depthwise_fwd_t<isa>::pd_t::init() {
    using namespace alg_kind;

    auto desired_blk_fmt = isa == avx512_common ? nChw16c : nChw8c;

    assert(engine()->kind() == engine_kind::cpu);
    bool ok = true && mayiuse(isa)
        && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
                prop_kind::forward_inference)
        && utils::everyone_is(data_type::f32, desc()->src_desc.data_type, desc()->dst_desc.data_type)
        && desc()->src_desc.format == desc()->dst_desc.format
        && utils::one_of(desc()->src_desc.format, desired_blk_fmt, nchw)
        && utils::one_of(desc()->dst_desc.format, desired_blk_fmt, nchw)
        && utils::one_of(desc()->weights_desc.format, x)
        && IMPLICATION(this->with_bias(), x == desc()->bias_desc.format)
        && attr()->has_default_values();

    return ok ? status::success : status::unimplemented;
}

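// When the channel count is not a multiple of the SIMD width, scratch buffers
// rounded up to it are allocated once here, with the tail zero-filled so the
// kernel can always read full vectors.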
template <cpu_isa_t isa>
jit_uni_depthwise_fwd_t<isa>::jit_uni_depthwise_fwd_t(const pd_t *apd,
        const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr),
      padded_weights_(nullptr), padded_bias_(nullptr) {
    const auto &desc = *pd()->desc();
    switch (desc.alg_kind) {
        case alg_kind::depthwise_scale_shift:
            kernel_ = new jit_uni_scale_shift_kernel_f32<isa>(desc, pd()->with_bias()); break;
        case alg_kind::depthwise_prelu:
            kernel_ = new jit_uni_prelu_kernel_f32<isa>(desc, pd()->with_bias()); break;
        default: assert(!"unknown depthwise alg_kind");
    }

    const int simd_w = isa == avx512_common ? 16 : 8;
    const memory_desc_wrapper data_d(pd()->src_pd());
    const int c_without_padding = data_d.dims()[1];
    const int c_padded = rnd_up(c_without_padding, simd_w);

    if (pd()->want_padded_weights()) {
        padded_weights_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
        for (int oc = c_without_padding; oc < c_padded; ++oc)
            padded_weights_[oc] = 0;

        if (pd()->with_bias()) {
            padded_bias_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
            for (int oc = c_without_padding; oc < c_padded; ++oc)
                padded_bias_[oc] = 0;
        }
    }
}

template <cpu_isa_t isa>
jit_uni_depthwise_fwd_t<isa>::~jit_uni_depthwise_fwd_t() {
    delete kernel_;
    free(padded_weights_);
    free(padded_bias_);
}

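// One kernel call processes a row of W elements; work is parallelized over
// batch, channel blocks, and rows. The padded weights/bias copies are
// refreshed on each execution.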
template <cpu_isa_t isa>
void jit_uni_depthwise_fwd_t<isa>::execute_forward() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<data_t *>(this->memory());

    const memory_desc_wrapper data_d(pd()->src_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
    const memory_desc_wrapper bias_d(pd()->weights_pd(1));

    const int N = data_d.dims()[0];
    const int C = data_d.dims()[1];
    const int H = data_d.dims()[2];
    const int W = data_d.dims()[3];

    const int simd_w = isa == avx512_common ? 16 : 8;
    const int ch_block_size = data_d.format() == nchw ? 1 : simd_w;
    const int CB = div_up(C, ch_block_size);

    if (pd()->want_padded_weights()) {
        for (int oc = 0; oc < C; ++oc)
            padded_weights_[oc] = weights[oc];
        weights = padded_weights_;

        if (pd()->with_bias()) {
            for (int oc = 0; oc < C; ++oc)
                padded_bias_[oc] = bias[oc];
            bias = padded_bias_;
        }
    }

    parallel_nd(N, CB, H,
        [&](int n, int cb, int h) {
        auto arg = jit_args();

        arg.from    = &src[data_d.blk_off(n, cb, h)];
        arg.to      = &dst[data_d.blk_off(n, cb, h)];
        arg.weights = &weights[weights_d.blk_off(cb * ch_block_size)];
        if (bias)
            arg.bias = &bias[bias_d.blk_off(cb * ch_block_size)];
        arg.work_amount = (size_t)W;

        (*kernel_)(&arg);
    });
}

template struct jit_uni_depthwise_fwd_t<sse42>;
template struct jit_uni_depthwise_fwd_t<avx2>;
template struct jit_uni_depthwise_fwd_t<avx512_common>;

#define GET_OFF_DW(field) offsetof(jit_conv_call_s, field)

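// Zero the accumulator registers for ur_w output pixels (two 4-wide halves
// per pixel on SSE4.2).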
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::load_src(int ur_w) {
    int repeats = isa == sse42 ? 2 : 1;
    for (int i = 0; i < repeats; i++) {
        for (int ow = 0; ow < ur_w; ow++) {
            Vmm vmm_acc = get_acc_reg(i*ur_w + ow);

            uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
        }
    }
}

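// Accumulate an up-to-3-row depthwise filter over ur_w output pixels. For u8
// inputs the products are accumulated as s32 (pmulld/paddd); for f32 an FMA
// is used. reg_kh guards rows that fall into vertical padding.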
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::apply_filter(int ur_w, int kw_size) {
    auto load_src = [=](Vmm vmm_src, const Xbyak::Address &op) {
        if (jcp.src_dt == data_type::u8) {
            uni_vpmovzxbd(vmm_src, op);
        } else {
            uni_vmovups(vmm_src, op);
        }
    };

    auto load_ker = [=](Vmm vmm_ker, const Xbyak::Address &op) {
        if (jcp.src_dt == data_type::u8) {
            uni_vpmovsxbd(vmm_ker, op);
        } else {
            uni_vmovups(vmm_ker, op);
        }
    };

    auto compute = [=](Vmm vmm_acc, Vmm vmm_src, Vmm vmm_ker) {
        if (jcp.src_dt == data_type::u8) {
            uni_vpmulld(vmm_src, vmm_src, vmm_ker);
            uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
        } else {
            uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
        }
    };

    int ch_blk = jcp.ch_block;
    int stride_w = jcp.stride_w;

    Label exit_label;

    int repeats = isa == sse42 ? 2 : 1;

    cmp(reg_kh, 1);
    jl(exit_label, T_NEAR);
    for (int i = 0; i < repeats; i++) {
        for (int kw = 0; kw < kw_size; kw++) {
            int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);

            Vmm vmm_ker = get_ker_reg(0);
            load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);

            for (int ow = 0; ow < ur_w; ow++) {
                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);

                Vmm vmm_src = get_src_reg(0);
                load_src(vmm_src, ptr[aux_reg_input0 + inp_off * jcp.typesize_in]);

                Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
                compute(vmm_acc, vmm_src, vmm_ker);
            }
        }
    }
    add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);

    cmp(reg_kh, 2);
    jl(exit_label, T_NEAR);
    for (int i = 0; i < repeats; i++) {
        for (int kw = 0; kw < kw_size; kw++) {
            int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);

            Vmm vmm_ker = get_ker_reg(0);
            load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);

            for (int ow = 0; ow < ur_w; ow++) {
                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);

                Vmm vmm_src = get_src_reg(0);
                load_src(vmm_src, ptr[aux_reg_input1 + inp_off * jcp.typesize_in]);

                Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
                compute(vmm_acc, vmm_src, vmm_ker);
            }
        }
    }
    add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);

    cmp(reg_kh, 3);
    jl(exit_label, T_NEAR);
    for (int i = 0; i < repeats; i++) {
        for (int kw = 0; kw < kw_size; kw++) {
            int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);

            Vmm vmm_ker = get_ker_reg(0);
            load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);

            for (int ow = 0; ow < ur_w; ow++) {
                int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);

                Vmm vmm_src = get_src_reg(0);
                load_src(vmm_src, ptr[aux_reg_input2 + inp_off * jcp.typesize_in]);

                Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
                compute(vmm_acc, vmm_src, vmm_ker);
            }
        }
    }

    L(exit_label);
}

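// Load a scalar or a full vector of the given type and convert it to f32 in
// place when the source is an integer type.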
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) {
    Xmm xmm_in = Xmm(vmm_in.getIdx());

    switch (type_in) {
        case data_type::f32:
        case data_type::s32:
            if (scalar_load) {
                mov(reg_tmp_32, op);
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vmovups(vmm_in, op);
            }
            break;
        case data_type::s8:
            if (scalar_load) {
                movsx(reg_tmp_32, op);
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vpmovsxbd(vmm_in, op);
            }
            break;
        case data_type::u8:
            if (scalar_load) {
                movzx(reg_tmp_32, op);
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vpmovzxbd(vmm_in, op);
            }
            break;
        default: assert(!"unsupported data type");
    }

    if (type_in != data_type::f32)
        uni_vcvtdq2ps(vmm_in, vmm_in);
}

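// Post-processing on the accumulators: convert the u8 path's s32 sums to f32,
// add bias, optionally accumulate the previous destination contents (sum
// post-op, with a scalar path for channel tails), then run the eltwise and
// depthwise injectors that follow the fused convolution in the post-op chain.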
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::apply_postprocessing(int ur_w, int oc_step) {
    int repeats = isa == sse42 ? 2 : 1;

    for (int r = 0; r < repeats; r++) {
        for (int ow = 0; ow < ur_w; ow++) {
            if (jcp.src_dt == data_type::u8) {
                uni_vcvtdq2ps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow));
            }

            if (jcp.with_bias) {
                int b_off = r * (jcp.ch_block / 2);
                cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias + b_off * jcp.typesize_bia], false);
                uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_bias);
            }
        }
    }

    if (jcp.with_sum) {
        for (int r = 0; r < repeats; r++) {
            int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - r * jcp.ch_block / 2) : oc_step;
            bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;

            for (int ow = 0; ow < ur_w; ow++) {
                if (is_scalar_store) {
                    for (int oc = 0; oc < tail_size; oc++) {
                        int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc;

                        uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
                        cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);

                        if (oc >= jcp.ch_block / 2) {
                            vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01);
                        }
                        uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2)));

                        uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
                    }
                } else {
                    int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2);

                    uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
                    cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false);

                    uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
                }
            }
        }
    }

    const auto &p = attr_.post_ops_;
    int eltwise_inj_idx = 0;
    int depthwise_inj_idx = 0;
    int start_idx = p.find(primitive_kind::convolution) + 1;
    for (int i = start_idx; i < p.len_; i++) {
        auto& post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors[eltwise_inj_idx]->compute_vector_range(4, 4 + repeats * ur_w);
            eltwise_inj_idx++;
        } else if (post_op.is_depthwise()) {
            mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
            mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));

            add(reg_d_weights, reg_oc_off);
            add(reg_d_bias, reg_oc_off);

            depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4, 4 + ur_w, reg_d_weights, reg_d_bias);

            if (repeats == 2) {
                add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
                add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));

                depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4 + ur_w, 4 + 2 * ur_w, reg_d_weights, reg_d_bias);
            }

            depthwise_inj_idx++;
        }
    }
}

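// Store one vector (or one scalar lane) of results, packing s32 down to s8/u8
// with saturation when the destination type requires it.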
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::store_dst_typed(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
    Ymm ymm_dst = Ymm(vmm_dst.getIdx());
    Xmm xmm_dst = Xmm(vmm_dst.getIdx());

    switch (jcp.dst_dt) {
        case data_type::f32:
        case data_type::s32:
            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_32);
            } else {
                uni_vmovups(op, vmm_dst);
            }
            break;
        case data_type::s8:
            uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);

            if (isa != sse42 && !scalar_store)
                vpermq(ymm_dst, ymm_dst, 0x08);

            uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);

            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_8);
            } else {
                if (isa != sse42)
                    vmovq(op, xmm_dst);
                else
                    movd(op, xmm_dst);
            }
            break;
        case data_type::u8:
        case data_type::bin:
            uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);

            if (isa != sse42 && !scalar_store)
                vpermq(ymm_dst, ymm_dst, 0x08);

            uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);

            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_8);
            } else {
                if (isa != sse42)
                    vmovq(op, xmm_dst);
                else
                    movd(op, xmm_dst);
            }
            break;
        default:
            assert(!"unknown dst_dt");
    }
}

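// Round and convert the accumulators to the destination type, then either
// pack comparison results into a bit mask (binarization post-op) or store the
// values, element by element when handling a channel tail.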
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
    int repeats = isa == sse42 && oc_step > (jcp.ch_block / 2) ? 2 : 1;

    for (int i = 0; i < repeats; i++) {
        for (int ow = 0; ow < ur_w; ow++) {
            Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
            if (jcp.dst_dt != data_type::f32 && jcp.dst_dt != data_type::bin) {
                if (attr_.round_mode_ == round_mode::nearest)
                    uni_vcvtps2dq(vmm_dst, vmm_dst);
                else if (attr_.round_mode_ == round_mode::down) {
                    uni_vroundps(vmm_dst, vmm_dst, 1);
                    uni_vcvtps2dq(vmm_dst, vmm_dst);
                } else
                    assert(!"unimplemented");
            }
        }
    }

    if (jcp.with_binarization) {
        int output_step = div_up(ow_stride_, 8);

        const auto &p = attr_.post_ops_;
        int binarization_idx = p.find(primitive_kind::binarization);

        mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
        add(reg_b_weights, reg_oc_off);

        for (int ow = 0; ow < ur_w; ow++) {
            for (int i = 0; i < repeats; i++) {
                int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
                mov(reg_b_mask, (1 << tail_size) - 1);
                uni_vmovups(vmm_thr, ptr[reg_b_weights + i * (jcp.ch_block / 2) * sizeof(float)]);

                Vmm vmm_dst = get_acc_reg(i * ur_w + ow);

                uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);

                if (i == 0) {
                    uni_vmovmskps(reg_tmp_32, vmm_dst);
                    and_(reg_tmp_64, reg_b_mask);
                } else {
                    uni_vmovmskps(reg_tmp2_32, vmm_dst);
                    and_(reg_tmp2_64, reg_b_mask);
                    shl(reg_tmp2_32, 4);
                    or_(reg_tmp_32, reg_tmp2_32);
                }

                if (i == repeats - 1) {
                    const size_t o_off = ow * output_step;
                    mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
                }
            }
        }
    } else {
        for (int i = 0; i < repeats; i++) {
            int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
            bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;
            if (is_scalar_store) {
                for (int ow = 0; ow < ur_w; ow++) {
                    Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
                    Ymm ymm_dst = Ymm(vmm_dst.getIdx());

                    for (int oc = 0; oc < tail_size; oc++) {
                        int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc;
                        store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);

                        if (isa == sse42) {
                            psrldq(vmm_dst, jcp.typesize_out);
                        } else {
                            vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
                            vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
                        }
                    }
                }
            } else {
                for (int ow = 0; ow < ur_w; ow++) {
                    int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2);
                    Vmm vmm_dst = get_acc_reg(i * ur_w + ow);

                    store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
                }
            }
        }
    }
}

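// Emit the output-row loop: one left-padding pixel with a shortened filter,
// the unrolled main body (jcp.ur_w pixels at a time), a scalar tail, and, for
// rows wider than one pixel, a right-padding pixel.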
template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::loop_body(int oc_step) {
    Label left_pad_label;
    Label right_pad_label;
    Label unrolled_w_label;
    Label tail_w_label;
    Label exit_label;

    int output_step = jcp.with_binarization ? div_up(ow_stride_, 8) : ow_stride_;

    L(left_pad_label); {
        int ur_w = 1;
        int kw = jcp.iw == 1 ? jcp.kw - 2 : jcp.kw - 1;

        mov(aux_reg_input0, reg_input0);
        mov(aux_reg_input1, reg_input1);
        mov(aux_reg_input2, reg_input2);
        mov(aux_reg_kernel, reg_kernel);
        add(aux_reg_kernel, jcp.ch_block*jcp.typesize_in);

        load_src(ur_w);
        apply_filter(ur_w, kw);
        apply_postprocessing(ur_w, oc_step);
        store_dst(ur_w, oc_step);

        add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
        add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
        add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
        add(reg_output, jcp.typesize_out * ur_w * output_step);

        sub(reg_ur_w, ur_w);
    }

    L(unrolled_w_label); {
        int ur_w = jcp.ur_w;
        int kw = jcp.kw;

        cmp(reg_ur_w, ur_w);
        jle(tail_w_label, T_NEAR);

        mov(aux_reg_input0, reg_input0);
        mov(aux_reg_input1, reg_input1);
        mov(aux_reg_input2, reg_input2);
        mov(aux_reg_kernel, reg_kernel);

        load_src(ur_w);
        apply_filter(ur_w, kw);
        apply_postprocessing(ur_w, oc_step);
        store_dst(ur_w, oc_step);

        add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_output, jcp.typesize_out * ur_w * output_step);

        sub(reg_ur_w, ur_w);
        jmp(unrolled_w_label, T_NEAR);
    }

    L(tail_w_label); {
        int ur_w = 1;
        int kw = jcp.kw;

        cmp(reg_ur_w, ur_w);
        if (jcp.ow > 1)
            jle(right_pad_label, T_NEAR);
        else
            jle(exit_label, T_NEAR);

        mov(aux_reg_input0, reg_input0);
        mov(aux_reg_input1, reg_input1);
        mov(aux_reg_input2, reg_input2);
        mov(aux_reg_kernel, reg_kernel);

        load_src(ur_w);
        apply_filter(ur_w, kw);
        apply_postprocessing(ur_w, oc_step);
        store_dst(ur_w, oc_step);

        add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_output, jcp.typesize_out * ur_w * output_step);

        sub(reg_ur_w, ur_w);
        jmp(tail_w_label, T_NEAR);
    }

    if (jcp.ow > 1) {
        L(right_pad_label); {
            int ur_w = 1;
            int kw = jcp.kw - ((jcp.stride_w == 1) ? 1 : jcp.iw % jcp.stride_w);

            mov(aux_reg_input0, reg_input0);
            mov(aux_reg_input1, reg_input1);
            mov(aux_reg_input2, reg_input2);
            mov(aux_reg_kernel, reg_kernel);

            load_src(ur_w);
            apply_filter(ur_w, kw);
            apply_postprocessing(ur_w, oc_step);
            store_dst(ur_w, oc_step);

            sub(reg_ur_w, ur_w);
        }
    }

    L(exit_label);
}

template <cpu_isa_t isa>
void jit_uni_dw_conv_row_f32<isa>::generate() {
    const auto &p = attr_.post_ops_;
    int start_idx = p.find(primitive_kind::convolution) + 1;
    for (int i = start_idx; i < p.len_; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta
            ));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
                    this,
                    post_op.depthwise.alg
            ));
        }
    }

    this->preamble();

    mov(reg_input0, ptr[this->param1 + GET_OFF_DW(src_row0)]);
    mov(reg_input1, ptr[this->param1 + GET_OFF_DW(src_row1)]);
    mov(reg_input2, ptr[this->param1 + GET_OFF_DW(src_row2)]);
    mov(reg_output, ptr[this->param1 + GET_OFF_DW(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF_DW(filt)]);
    if (jcp.with_bias)
        mov(reg_bias, ptr[this->param1 + GET_OFF_DW(bias)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF_DW(kh_padding)]);
    mov(reg_ur_w, ptr[this->param1 + GET_OFF_DW(ur_w)]);
    mov(reg_oc_work, ptr[this->param1 + GET_OFF_DW(oc_work)]);
    mov(reg_oc_off, ptr[this->param1 + GET_OFF_DW(oc_off)]);

    Label tail_label;
    Label exit_label;

    cmp(reg_oc_work, jcp.ch_block);
    jl(tail_label, T_NEAR);

    loop_body(jcp.ch_block);
    jmp(exit_label, T_NEAR);

    L(tail_label);

    if (jcp.oc % jcp.ch_block != 0)
        loop_body(jcp.oc % jcp.ch_block);

    L(exit_label);

    this->postamble();

    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
}

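// Only post-ops that come after the fused convolution entry are validated:
// up to three of sum / eltwise / depthwise / binarization in the combinations
// enumerated below.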
template <cpu_isa_t isa>
bool jit_uni_dw_conv_row_f32<isa>::post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
    auto is_binarization = [&](int idx) { return p.entry_[idx].is_binarization(); };

    int start_idx = p.find(primitive_kind::convolution) + 1;

    switch (p.len_ - start_idx) {
    case 0: return true; // no post_ops
    case 1: return is_simple(start_idx) || is_sum(start_idx) || is_binarization(start_idx);
    case 2: return (is_sum(start_idx) && is_simple(start_idx+1)) || (is_simple(start_idx) && is_simple(start_idx+1)) ||
                   (is_simple(start_idx) && is_binarization(start_idx+1));
    case 3: return (is_sum(start_idx) && is_simple(start_idx+1) && is_simple(start_idx+2));
    default: return false;
    }
}

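// The three init_conf overloads below fill the fused depthwise convolution
// config from a 1x1, a regular, or a binary convolution configuration; only
// 3x3 depthwise kernels are supported.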
template <cpu_isa_t isa>
status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_1x1_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
        const primitive_attr_t &attr) {
    if (!mayiuse(isa)) return status::unimplemented;
    const int simd_w = isa == avx512_common ? 16 : 8;

    const auto &p = attr.post_ops_;

    int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;

    jcp_dw.ch_block = simd_w;
    jcp_dw.with_bias = true;

    jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
    jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
    jcp_dw.ic = jcp.oc;
    jcp_dw.oc = jcp.oc;
    jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
    jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
    jcp_dw.oh = jcp.dw_conv_oh;
    jcp_dw.ow = jcp.dw_conv_ow;
    jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
    jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
    jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
    jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;

    if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
        return status::unimplemented;

    if (!post_ops_ok(jcp_dw, attr))
        return status::unimplemented;

    jcp_dw.ur_w = 4;

    jcp_dw.src_dt = jcp.src_dt;
    jcp_dw.dst_dt = jcp.dst_dt;
    jcp_dw.bia_dt = jcp.bia_dt;
    jcp_dw.typesize_in = (int)types::data_type_size(jcp.src_dt);
    jcp_dw.typesize_bia = (int)types::data_type_size(jcp.bia_dt);
    jcp_dw.typesize_out = (int)types::data_type_size(jcp.dst_dt);

    if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
        return status::unimplemented;

    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
        const primitive_attr_t &attr) {
    if (!mayiuse(isa)) return status::unimplemented;
    const int simd_w = isa == avx512_common ? 16 : 8;

    const auto &p = attr.post_ops_;

    int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;

    jcp_dw.ch_block = simd_w;
    jcp_dw.with_bias = true;

    jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
    jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
    jcp_dw.ic = jcp.oc;
    jcp_dw.oc = jcp.oc;
    jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
    jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
    jcp_dw.oh = jcp.dw_conv_oh;
    jcp_dw.ow = jcp.dw_conv_ow;
    jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
    jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
    jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
    jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;

    if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
        return status::unimplemented;

    if (!post_ops_ok(jcp_dw, attr))
        return status::unimplemented;

    jcp_dw.ur_w = 4;

    jcp_dw.src_dt = jcp.dst_dt;
    jcp_dw.dst_dt = jcp.dst_dt;
    jcp_dw.bia_dt = jcp.bia_dt;
    jcp_dw.typesize_in = (int)types::data_type_size(jcp.src_dt);
    jcp_dw.typesize_bia = (int)types::data_type_size(jcp.bia_dt);
    jcp_dw.typesize_out = (int)types::data_type_size(jcp.dst_dt);

    if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
        return status::unimplemented;

    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_bin_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
        const primitive_attr_t &attr) {
    if (!mayiuse(isa)) return status::unimplemented;
    const int simd_w = isa == avx512_common ? 16 : 8;

    const auto &p = attr.post_ops_;

    int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
    jcp_dw.with_binarization = p.find(primitive_kind::binarization, dw_conv_ind) != -1;

    jcp_dw.ch_block = simd_w;
    jcp_dw.with_bias = true;

    jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
    jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
    jcp_dw.ic = jcp.oc;
    jcp_dw.oc = jcp.oc;
    jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
    jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
    jcp_dw.oh = jcp.dw_conv_oh;
    jcp_dw.ow = jcp.dw_conv_ow;
    jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
    jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
    jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
    jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;

    if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
        return status::unimplemented;

    if (!post_ops_ok(jcp_dw, attr))
        return status::unimplemented;

    jcp_dw.ur_w = 4;

    jcp_dw.src_dt = mkldnn_f32;
    jcp_dw.dst_dt = jcp_dw.with_binarization ? mkldnn_bin : mkldnn_f32;
    jcp_dw.bia_dt = mkldnn_f32;
    jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt);
    jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt);
    jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt);

    if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
        return status::unimplemented;

    return status::success;
}

template struct jit_uni_dw_conv_row_f32<avx512_common>;
template struct jit_uni_dw_conv_row_f32<avx2>;
template struct jit_uni_dw_conv_row_f32<sse42>;

} // namespace cpu
} // namespace impl
} // namespace mkldnn