Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_binarization.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18 #include "mkldnn_thread.hpp"
19 #include "nstl.hpp"
20 #include "utils.hpp"
21 #include "jit_uni_binarization.hpp"
22
/* Byte offset of `field` within jit_args. The generated kernel uses it to
 * load its arguments relative to the jit_args pointer it receives. */
#define GET_OFF(field) offsetof(jit_args, field)
24
25 namespace mkldnn {
26 namespace impl {
27 namespace cpu {
28
29 using namespace Xbyak;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::utils;
32
/* Argument pack handed to the generated kernel. The kernel reads every field
 * by byte offset (see GET_OFF), so the member order and sizes below define
 * the layout the generated code relies on. */
struct jit_args {
    const float* from;     // input values; only read by the kernel
    uint8_t* to;           // packed output bits; the kernel stores through it, so it is non-const
    const float* weights;  // values each input element is compared against
    size_t work_amount;    // number of input elements to process in this call
};
39
40 struct jit_uni_binarization_kernel_f32 : public c_compatible {
41     const binarization_desc_t &desc_;
42     void (*ker_)(const jit_args *);
43
44     void operator()(const jit_args *args) { assert(ker_); ker_(args); }
45
46     jit_uni_binarization_kernel_f32(const binarization_desc_t &desc)
47         : desc_(desc), ker_(nullptr) {}
48     virtual ~jit_uni_binarization_kernel_f32() {}
49 };
50
51 /* jit kernels */
52 namespace {
53
54 template <cpu_isa_t isa>
55 struct jit_uni_bin_depthwise_kernel_f32 : public jit_uni_binarization_kernel_f32,
56     public jit_generator
57 {
58     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_bin_depthwise_kernel_f32)
59     jit_uni_bin_depthwise_kernel_f32(const binarization_desc_t &desc)
60         : jit_uni_binarization_kernel_f32(desc), jit_generator() {
61         assert(desc.alg_kind == alg_kind::binarization_depthwise);
62         assert(isa == sse42 || isa == avx2 || isa == avx512_common);
63
64         this->preamble();
65
66         mov(reg_from, ptr[param + GET_OFF(from)]);
67         mov(reg_to, ptr[param + GET_OFF(to)]);
68         mov(reg_weights, ptr[param + GET_OFF(weights)]);
69         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
70
71         const int nbits = 8;
72         int simd_w = isa == avx512_common ? 16 : 8;
73         const int C = desc.src_desc.dims[1];
74         const int tail_size = C % simd_w;
75
76         Label unrolled_loop_label;
77         Label main_loop_label;
78         Label tail_label;
79         Label exit_label;
80
81         L(unrolled_loop_label); {
82             int step = isa == sse42 ? nbits / 2 : isa == avx2 ? nbits : 2 * nbits;
83             const int ur_ch = isa == sse42 ? nbits : isa == avx2 ? nbits / 2 : nbits / 4;
84             const int unrolled_loop_step = ur_ch * step;
85
86             cmp(reg_work_amount, unrolled_loop_step);
87             jl(main_loop_label, T_NEAR);
88
89             xor_(reg_bin_32, reg_bin_32);
90             for (int ch = 0; ch < ur_ch; ch++) {
91                 uni_vmovups(vmm_src(0), ptr[reg_from + ch*step*sizeof(float)]);
92                 uni_vmovups(vmm_wei(0), ptr[reg_weights + ch*step*sizeof(float)]);
93                 if (isa == avx512_common) {
94                     vcmpps(k_mask, vmm_src(0), vmm_wei(0), _cmp_gt_os);
95                     kmovw(reg_src_32, k_mask);
96                 } else {
97                     uni_vcmpgtps(vmm_src(0), vmm_src(0), vmm_wei(0));
98                     uni_vmovmskps(reg_src_32, vmm_src(0));
99                 }
100                 shl(reg_src_32, ch * step);
101                 or_(reg_bin_32, reg_src_32);
102             }
103             mov(ptr[reg_to], reg_bin_32);
104
105             add(reg_from, unrolled_loop_step*sizeof(float));
106             add(reg_weights, unrolled_loop_step*sizeof(float));
107             add(reg_to, sizeof(uint32_t));
108             sub(reg_work_amount, unrolled_loop_step);
109
110             jmp(unrolled_loop_label, T_NEAR);
111         }
112
113         L(main_loop_label); {
114             int repeats = isa == sse42 ? 2 : 1;
115             int step = isa == sse42 ? nbits / 2 : isa == avx2 ? nbits : nbits * 2;
116             const int main_loop_step = step * repeats;
117
118             cmp(reg_work_amount, main_loop_step);
119             jl(tail_label, T_NEAR);
120
121             xor_(reg_bin_32, reg_bin_32);
122             for (int i = 0; i < repeats; i++) {
123                 uni_vmovups(vmm_src(0), ptr[reg_from + i*step*sizeof(float)]);
124                 uni_vmovups(vmm_wei(0), ptr[reg_weights + i*step*sizeof(float)]);
125                 if (isa == avx512_common) {
126                     vcmpps(k_mask, vmm_src(0), vmm_wei(0), _cmp_gt_os);
127                     kmovw(reg_src_32, k_mask);
128                 } else {
129                     uni_vcmpgtps(vmm_src(0), vmm_src(0), vmm_wei(0));
130                     uni_vmovmskps(reg_src_32, vmm_src(0));
131                 }
132                 shl(reg_src_32, i * step);
133                 or_(reg_bin_32, reg_src_32);
134             }
135             if (isa == avx512_common)
136                 mov(ptr[reg_to], reg_bin_16);
137             else        
138                 mov(ptr[reg_to], reg_bin_8);
139
140             add(reg_from, main_loop_step*sizeof(float));
141             add(reg_weights, main_loop_step*sizeof(float));
142             add(reg_to, isa == avx512_common ? sizeof(uint16_t) : sizeof(uint8_t));
143             sub(reg_work_amount, main_loop_step);
144
145             jmp(main_loop_label, T_NEAR);
146         }
147
148         L(tail_label); {
149             if (tail_size != 0) {
150                 xor_(reg_bin_32, reg_bin_32);
151                 for (int c = 0; c < tail_size; c++) {
152                     uni_vpxor(xmm_src(0), xmm_src(0), xmm_src(0));
153                     uni_vpxor(xmm_wei(0), xmm_wei(0), xmm_wei(0));
154
155                     movss(xmm_src(0), ptr[reg_from + c * sizeof(float)]);
156                     movss(xmm_wei(0), ptr[reg_weights + c * sizeof(float)]);
157                     uni_vcmpgtps(xmm_src(0), xmm_src(0), xmm_wei(0));
158                     uni_vmovmskps(reg_src_32, xmm_src(0));
159
160                     shl(reg_src_32, c);
161                     or_(reg_bin_32, reg_src_32);
162                 }
163                 if (isa == avx512_common && tail_size > nbits)
164                     mov(ptr[reg_to], reg_bin_16);
165                 else
166                     mov(ptr[reg_to], reg_bin_8);
167             }
168         }
169
170         L(exit_label);
171
172         this->postamble();
173
174         ker_ = (decltype(ker_))this->getCode();
175     }
176
177 private:
178     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
179                                              isa == avx2, Ymm, Zmm>::type;
180
181     inline Vmm vmm_src(int idx) { return Vmm(idx); }
182     inline Xmm xmm_src(int idx) { return Xmm(idx); }
183     inline Vmm vmm_wei(int idx) { return Vmm(idx + 4); }
184     inline Xmm xmm_wei(int idx) { return Xmm(idx + 4); }
185
186     Reg64 param = abi_param1;
187     Reg64 reg_from = r8;
188     Reg64 reg_to = r9;
189     Reg64 reg_work_amount = r10;
190     Reg64 reg_weights = r11;
191     Reg16 reg_bin_16 = r12w;
192     Reg32 reg_bin_32 = r12d;
193     Reg8 reg_bin_8 = r12b;
194     Reg32 reg_src_32 = r13d;
195     Reg64 reg_src_64 = r13;
196
197     const unsigned char _cmp_gt_os = 6;
198     Xbyak::Opmask k_mask = Xbyak::Opmask(1);
199 };
200
201 } /* namespace */
202
203 template <cpu_isa_t isa>
204 status_t jit_uni_binarization_fwd_t<isa>::pd_t::init() {
205     using namespace alg_kind;
206
207     auto desired_fmt = nhwc;
208
209     assert(engine()->kind() == engine_kind::cpu);
210     bool ok = true && mayiuse(isa)
211         && utils::one_of(desc()->prop_kind, prop_kind::forward_training, prop_kind::forward_inference)
212         && utils::everyone_is(data_type::f32, desc()->src_desc.data_type, desc()->weights_desc.data_type)
213         && utils::everyone_is(data_type::bin, desc()->dst_desc.data_type)
214         && desc()->src_desc.format == desc()->dst_desc.format
215         && utils::one_of(desc()->src_desc.format, desired_fmt)
216         && utils::one_of(desc()->dst_desc.format, desired_fmt)
217         && utils::one_of(desc()->weights_desc.format, x)
218         && attr()->has_default_values();
219
220     return ok ? status::success : status::unimplemented;
221 }
222
223 template <cpu_isa_t isa>
224 jit_uni_binarization_fwd_t<isa>::jit_uni_binarization_fwd_t(const pd_t *apd,
225         const input_vector &inputs, const output_vector &outputs)
226     : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
227     const auto &desc = *pd()->desc();
228     switch (desc.alg_kind) {
229         case alg_kind::binarization_depthwise:
230             kernel_ = new jit_uni_bin_depthwise_kernel_f32<isa>(desc); break;
231         default: assert(!"unknown binarization alg_kind");
232     }
233 }
234
235 template <cpu_isa_t isa>
236 jit_uni_binarization_fwd_t<isa>::~jit_uni_binarization_fwd_t() {
237     delete kernel_;
238 }
239
240 template <cpu_isa_t isa>
241 void jit_uni_binarization_fwd_t<isa>::execute_forward() const {
242     auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
243     auto weights = reinterpret_cast<const src_data_t*>(this->input_memory(1));
244     auto dst = reinterpret_cast<uint8_t*>(this->memory());
245
246     const memory_desc_wrapper src_d(pd()->src_pd());
247     const memory_desc_wrapper dst_d(pd()->dst_pd());
248     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
249
250     const int N = src_d.dims()[0];
251     const int C = src_d.dims()[1];
252     const int H = src_d.dims()[2];
253     const int W = src_d.dims()[3];
254
255     int nbits = 8;
256
257     parallel_nd(N, H, W,
258         [&](int n, int h, int w) {
259         auto arg = jit_args();
260
261         arg.from    = &src[src_d.blk_off(n, 0, h, w)];
262         arg.to      = &dst[dst_d.blk_off(n, 0, h, w) / nbits];
263         arg.weights = &weights[weights_d.blk_off(0)];
264         arg.work_amount = (size_t)C;
265
266         (*kernel_)(&arg);
267     });
268 }
269
270 template struct jit_uni_binarization_fwd_t<sse42>;
271 template struct jit_uni_binarization_fwd_t<avx2>;
272 template struct jit_uni_binarization_fwd_t<avx512_common>;
273
274 }
275 }
276 }