1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
18 #include "mkldnn_thread.hpp"
21 #include "jit_uni_binarization.hpp"
23 #define GET_OFF(field) offsetof(jit_args, field)
29 using namespace Xbyak;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::utils;
// Abstract base for JIT-generated binarization kernels. Holds the primitive
// descriptor and a raw function pointer to the generated machine code;
// operator() dispatches into that code. ker_ is populated by the derived
// generator (see jit_uni_bin_depthwise_kernel_f32 below).
// NOTE(review): jit_args is declared elsewhere in this file (not visible in
// this excerpt) — GET_OFF() above computes member offsets into it.
40 struct jit_uni_binarization_kernel_f32 : public c_compatible {
41     const binarization_desc_t &desc_;
// Pointer to the jitted entry point; null until code generation completes.
42     void (*ker_)(const jit_args *);
// Invoke the generated kernel; asserts that code generation has happened.
44     void operator()(const jit_args *args) { assert(ker_); ker_(args); }
46     jit_uni_binarization_kernel_f32(const binarization_desc_t &desc)
47         : desc_(desc), ker_(nullptr) {}
// Virtual dtor: instances are deleted through this base by the primitive.
48     virtual ~jit_uni_binarization_kernel_f32() {}
// JIT generator for depthwise binarization: for each channel c it computes
// (src[c] > weights[c]) and packs the resulting 1-bit outcomes into the
// destination, which stores one bit per channel. Generated per-ISA
// (sse42 / avx2 / avx512_common).
// NOTE(review): this excerpt elides several lines of the original file
// (base-class list continuation, preamble/postamble, matching `else`
// branches, the tail_label declaration, closing braces) — comments below
// describe only what the visible code establishes.
54 template <cpu_isa_t isa>
55 struct jit_uni_bin_depthwise_kernel_f32 : public jit_uni_binarization_kernel_f32,
58     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_bin_depthwise_kernel_f32)
59     jit_uni_bin_depthwise_kernel_f32(const binarization_desc_t &desc)
60         : jit_uni_binarization_kernel_f32(desc), jit_generator() {
61         assert(desc.alg_kind == alg_kind::binarization_depthwise);
62         assert(isa == sse42 || isa == avx2 || isa == avx512_common);
// Load kernel arguments from the jit_args struct passed in abi_param1.
66         mov(reg_from, ptr[param + GET_OFF(from)]);
67         mov(reg_to, ptr[param + GET_OFF(to)]);
68         mov(reg_weights, ptr[param + GET_OFF(weights)]);
69         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
// Vector width in floats: 16 for AVX-512, 8 otherwise (Ymm; Xmm path uses
// a smaller per-iteration `step` below).
72         int simd_w = isa == avx512_common ? 16 : 8;
73         const int C = desc.src_desc.dims[1];
// Channels left over after whole-SIMD groups; handled by the scalar tail.
74         const int tail_size = C % simd_w;
76         Label unrolled_loop_label;
77         Label main_loop_label;
// --- Unrolled loop: packs 32 result bits per iteration and stores a full
// uint32_t (ur_ch * step == 32 for every ISA, assuming nbits == 8 —
// nbits is defined outside this excerpt, TODO confirm).
81         L(unrolled_loop_label); {
// Floats compared per vector op: presumably 4 (sse42) / 8 (avx2) / 16 (avx512).
82             int step = isa == sse42 ? nbits / 2 : isa == avx2 ? nbits : 2 * nbits;
// Number of vector groups unrolled so each iteration fills 32 bits.
83             const int ur_ch = isa == sse42 ? nbits : isa == avx2 ? nbits / 2 : nbits / 4;
84             const int unrolled_loop_step = ur_ch * step;
86             cmp(reg_work_amount, unrolled_loop_step);
87             jl(main_loop_label, T_NEAR);
// Accumulator for the packed bits of this 32-channel group.
89             xor_(reg_bin_32, reg_bin_32);
90             for (int ch = 0; ch < ur_ch; ch++) {
91                 uni_vmovups(vmm_src(0), ptr[reg_from + ch*step*sizeof(float)]);
92                 uni_vmovups(vmm_wei(0), ptr[reg_weights + ch*step*sizeof(float)]);
93                 if (isa == avx512_common) {
// AVX-512: compare into an opmask and move the 16-bit mask to a GPR.
94                     vcmpps(k_mask, vmm_src(0), vmm_wei(0), _cmp_gt_os);
95                     kmovw(reg_src_32, k_mask);
// Non-AVX-512 path (the `else` line is elided in this excerpt): compare
// lanes then extract the sign-bit mask of each lane into a GPR.
97                     uni_vcmpgtps(vmm_src(0), vmm_src(0), vmm_wei(0));
98                     uni_vmovmskps(reg_src_32, vmm_src(0));
// Shift this group's mask into its bit position and OR into the accumulator.
100                 shl(reg_src_32, ch * step);
101                 or_(reg_bin_32, reg_src_32);
// Store 32 packed bits at once.
103             mov(ptr[reg_to], reg_bin_32);
105             add(reg_from, unrolled_loop_step*sizeof(float));
106             add(reg_weights, unrolled_loop_step*sizeof(float));
107             add(reg_to, sizeof(uint32_t));
108             sub(reg_work_amount, unrolled_loop_step);
110             jmp(unrolled_loop_label, T_NEAR);
// --- Main loop: one (or, on sse42, two) vector compares per iteration;
// stores 16 bits on AVX-512, 8 bits otherwise.
113         L(main_loop_label); {
// sse42 needs two Xmm rounds to cover 8 channels.
114             int repeats = isa == sse42 ? 2 : 1;
115             int step = isa == sse42 ? nbits / 2 : isa == avx2 ? nbits : nbits * 2;
116             const int main_loop_step = step * repeats;
118             cmp(reg_work_amount, main_loop_step);
// tail_label is declared in an elided line of this excerpt.
119             jl(tail_label, T_NEAR);
121             xor_(reg_bin_32, reg_bin_32);
122             for (int i = 0; i < repeats; i++) {
123                 uni_vmovups(vmm_src(0), ptr[reg_from + i*step*sizeof(float)]);
124                 uni_vmovups(vmm_wei(0), ptr[reg_weights + i*step*sizeof(float)]);
125                 if (isa == avx512_common) {
126                     vcmpps(k_mask, vmm_src(0), vmm_wei(0), _cmp_gt_os);
127                     kmovw(reg_src_32, k_mask);
// Non-AVX-512 path (elided `else` in this excerpt).
129                     uni_vcmpgtps(vmm_src(0), vmm_src(0), vmm_wei(0));
130                     uni_vmovmskps(reg_src_32, vmm_src(0));
132                 shl(reg_src_32, i * step);
133                 or_(reg_bin_32, reg_src_32);
// AVX-512 produced 16 mask bits, others 8 — store the matching width
// (the intervening `else` line is elided in this excerpt).
135             if (isa == avx512_common)
136                 mov(ptr[reg_to], reg_bin_16);
138                 mov(ptr[reg_to], reg_bin_8);
140             add(reg_from, main_loop_step*sizeof(float));
141             add(reg_weights, main_loop_step*sizeof(float));
142             add(reg_to, isa == avx512_common ? sizeof(uint16_t) : sizeof(uint8_t));
143             sub(reg_work_amount, main_loop_step);
145             jmp(main_loop_label, T_NEAR);
// --- Tail: compare the remaining channels one float at a time, packing
// their bits into reg_bin_32 (the per-channel shift is on an elided line).
149         if (tail_size != 0) {
150             xor_(reg_bin_32, reg_bin_32);
151             for (int c = 0; c < tail_size; c++) {
// Zero the full Xmm first so movmskps only reflects lane 0.
152                 uni_vpxor(xmm_src(0), xmm_src(0), xmm_src(0));
153                 uni_vpxor(xmm_wei(0), xmm_wei(0), xmm_wei(0));
155                 movss(xmm_src(0), ptr[reg_from + c * sizeof(float)]);
156                 movss(xmm_wei(0), ptr[reg_weights + c * sizeof(float)]);
157                 uni_vcmpgtps(xmm_src(0), xmm_src(0), xmm_wei(0));
158                 uni_vmovmskps(reg_src_32, xmm_src(0));
161                 or_(reg_bin_32, reg_src_32);
// Tail wider than one byte only possible on AVX-512 (simd_w == 16);
// the `else` store of reg_bin_8 follows (its `else` line is elided).
163             if (isa == avx512_common && tail_size > nbits)
164                 mov(ptr[reg_to], reg_bin_16);
166                 mov(ptr[reg_to], reg_bin_8);
// Finalize code generation and expose the entry point.
174         ker_ = (decltype(ker_))this->getCode();
// Vmm is the widest register type for the ISA being generated.
178     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
179                                              isa == avx2, Ymm, Zmm>::type;
// Register allocation helpers: sources in vmm 0..3, weights in vmm 4..7.
181     inline Vmm vmm_src(int idx) { return Vmm(idx); }
182     inline Xmm xmm_src(int idx) { return Xmm(idx); }
183     inline Vmm vmm_wei(int idx) { return Vmm(idx + 4); }
184     inline Xmm xmm_wei(int idx) { return Xmm(idx + 4); }
// abi_param1 holds the jit_args* on entry.
186     Reg64 param = abi_param1;
189     Reg64 reg_work_amount = r10;
190     Reg64 reg_weights = r11;
// Three views (8/16/32-bit) of r12, used as the packed-bit accumulator.
191     Reg16 reg_bin_16 = r12w;
192     Reg32 reg_bin_32 = r12d;
193     Reg8 reg_bin_8 = r12b;
194     Reg32 reg_src_32 = r13d;
195     Reg64 reg_src_64 = r13;
// VCMPPS immediate 6 == _CMP_GT_OS (greater-than, ordered, signaling).
197     const unsigned char _cmp_gt_os = 6;
198     Xbyak::Opmask k_mask = Xbyak::Opmask(1);
// Primitive-descriptor init: accept only configurations this JIT kernel
// supports — forward prop, f32 src/weights, bit-packed (bin) dst, nhwc
// layouts, 1-D (x) weights, and no post-op attributes.
203 template <cpu_isa_t isa>
204 status_t jit_uni_binarization_fwd_t<isa>::pd_t::init() {
205     using namespace alg_kind;
// Kernel iterates channels contiguously per pixel, hence nhwc only.
207     auto desired_fmt = nhwc;
209     assert(engine()->kind() == engine_kind::cpu);
210     bool ok = true && mayiuse(isa)
211         && utils::one_of(desc()->prop_kind, prop_kind::forward_training, prop_kind::forward_inference)
212         && utils::everyone_is(data_type::f32, desc()->src_desc.data_type, desc()->weights_desc.data_type)
// Destination is the packed 1-bit data type.
213         && utils::everyone_is(data_type::bin, desc()->dst_desc.data_type)
214         && desc()->src_desc.format == desc()->dst_desc.format
215         && utils::one_of(desc()->src_desc.format, desired_fmt)
216         && utils::one_of(desc()->dst_desc.format, desired_fmt)
217         && utils::one_of(desc()->weights_desc.format, x)
218         && attr()->has_default_values();
220     return ok ? status::success : status::unimplemented;
// Primitive ctor: instantiate the JIT kernel matching the descriptor's
// algorithm kind (only binarization_depthwise is implemented).
223 template <cpu_isa_t isa>
224 jit_uni_binarization_fwd_t<isa>::jit_uni_binarization_fwd_t(const pd_t *apd,
225         const input_vector &inputs, const output_vector &outputs)
226     : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
227     const auto &desc = *pd()->desc();
228     switch (desc.alg_kind) {
229         case alg_kind::binarization_depthwise:
// kernel_ is owned by this primitive (released in the dtor below).
230             kernel_ = new jit_uni_bin_depthwise_kernel_f32<isa>(desc); break;
231         default: assert(!"unknown binarization alg_kind");
// Dtor: releases the JIT kernel allocated in the ctor
// (body elided in this excerpt — presumably `delete kernel_;`).
235 template <cpu_isa_t isa>
236 jit_uni_binarization_fwd_t<isa>::~jit_uni_binarization_fwd_t() {
// Forward execution: for every (n, h, w) pixel, run the JIT kernel over all
// C channels, comparing src against per-channel weights and writing packed
// bits to dst. The parallel loop construct wrapping the lambda is elided in
// this excerpt (presumably parallel_nd over N, H, W).
240 template <cpu_isa_t isa>
241 void jit_uni_binarization_fwd_t<isa>::execute_forward() const {
242     auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
243     auto weights = reinterpret_cast<const src_data_t*>(this->input_memory(1));
// dst holds bit-packed output, addressed in bytes.
244     auto dst = reinterpret_cast<uint8_t*>(this->memory());
246     const memory_desc_wrapper src_d(pd()->src_pd());
247     const memory_desc_wrapper dst_d(pd()->dst_pd());
248     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
// 4-D nhwc tensor (enforced by pd_t::init).
250     const int N = src_d.dims()[0];
251     const int C = src_d.dims()[1];
252     const int H = src_d.dims()[2];
253     const int W = src_d.dims()[3];
258         [&](int n, int h, int w) {
259             auto arg = jit_args();
// Kernel walks C contiguous channels starting at this pixel (nhwc layout).
261             arg.from = &src[src_d.blk_off(n, 0, h, w)];
// dst is bit-packed: convert the element (bit) offset to a byte offset.
262             arg.to = &dst[dst_d.blk_off(n, 0, h, w) / nbits];
263             arg.weights = &weights[weights_d.blk_off(0)];
264             arg.work_amount = (size_t)C;
// Explicit instantiations for every ISA the kernel generator supports.
270 template struct jit_uni_binarization_fwd_t<sse42>;
271 template struct jit_uni_binarization_fwd_t<avx2>;
272 template struct jit_uni_binarization_fwd_t<avx512_common>;