1 /*******************************************************************************
2 * Copyright 2017 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_JIT_UNI_ROI_POOL_KERNEL_F32_HPP
18 #define CPU_JIT_UNI_ROI_POOL_KERNEL_F32_HPP
22 #include "c_types_map.hpp"
23 #include "jit_generator.hpp"
24 #include "type_helpers.hpp"
26 #include "jit_primitive_conf.hpp"
32 using namespace Xbyak;
34 template <cpu_isa_t isa>
35 struct jit_uni_roi_pool_kernel_f32: public jit_generator {
36 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_roi_pool_kernel_f32)
38 jit_uni_roi_pool_kernel_f32(jit_roi_pool_conf_t ajpp): jpp(ajpp)
41 jit_ker = (decltype(jit_ker))this->getCode();
// Configuration this kernel instance was generated for; initialized from the
// constructor argument `ajpp`.
44 jit_roi_pool_conf_t jpp;
// Runs the generated code by forwarding `arg` unchanged to the jit_ker entry
// point.
46 void operator()(jit_roi_pool_call_s *arg) { jit_ker(arg); }
// Builds a kernel configuration from the ROI-pooling descriptor and the
// source/destination memory descriptors, returning a status_t. Only the
// declaration is visible here.
// NOTE(review): `jbp` is taken by non-const reference, presumably as the
// output configuration -- confirm against the definition.
47 static status_t init_conf(jit_roi_pool_conf_t &jbp,
48 const roi_pooling_desc_t &pd, const memory_desc_wrapper &src_d,
49 const memory_desc_wrapper &dst_d);
// Vector register type selected by ISA: Xmm when isa == sse42, Ymm when
// isa == avx2, Zmm for any other ISA.
52 using Vmm = typename utils::conditional3<isa == sse42, Xmm, isa == avx2, Ymm, Zmm>::type;
// NOTE(review): vmm_mask and vmm_zero both name vector register 0, so they
// cannot hold different values at the same time. Presumably they are used in
// disjoint phases of the generated code -- confirm in the generator.
54 Vmm vmm_mask = Vmm(0);
55 Vmm vmm_zero = Vmm(0);
// Accumulator and source registers are interleaved starting at index 1:
// accumulator `idx` -> Vmm(2*idx + 1), source `idx` -> Vmm(2*idx + 2).
// Neither mapping yields register 0. Presumably callers bound `idx` so that
// 2*idx + 2 stays inside the ISA's vector register file -- confirm.
62 Vmm get_acc_reg(int idx) { return Vmm(2*idx + 1); }
63 Vmm get_src_reg(int idx) { return Vmm(2*idx + 2); }
// Opmask register k7; by its name, presumably used for masked stores.
65 Opmask k_store_mask = Opmask(7);
// Comparison predicate immediate. The value 1 matches the LT_OS
// ("less-than, ordered, signaling") encoding of (V)CMPPS -- confirm usage in
// the generator.
67 const unsigned char _cmp_lt_os = 1;
// General-purpose register assignments. reg64_t is const-qualified, so each
// name stays bound to the same physical register for the object's lifetime.
69 using reg64_t = const Xbyak::Reg64;
70 reg64_t reg_input = r8;
71 reg64_t aux_reg_input = rax;
// NOTE(review): aux_reg_input1 and reg_bin_area (below) are both rdx, so they
// must never be live at the same time. Presumably they serve different code
// paths -- confirm.
72 reg64_t aux_reg_input1 = rdx;
73 reg64_t reg_output = r9;
80 reg64_t reg_c_blocks = rbx;
81 reg64_t reg_bin_area = rdx;
// Aliases of registers declared earlier in this struct (reg_kh, reg_kw,
// h_iter are outside this excerpt). These rely on member declaration order
// for initialization; presumably the aliased registers are free whenever
// these names are in use -- confirm.
83 reg64_t reg_yf = reg_kh;
84 reg64_t reg_xf = reg_kw;
86 reg64_t reg_yoff = h_iter;
87 reg64_t reg_xoff = r12;
// Entry point of the generated code. The constructor assigns it from
// getCode(), and operator() invokes it.
89 void (*jit_ker)(jit_roi_pool_call_s *);
// Member functions declared here with bodies outside this excerpt. By name
// they presumably emit the max-pooling, bilinear, empty-ROI and per-block
// loop code paths; `c_blocks` is presumably the number of channel blocks
// handled -- confirm against the definitions.
91 void roi_pool_max(int c_blocks);
92 void roi_pool_bilinear(int c_blocks);
93 void empty_roi(int c_blocks);
94 void loop_body(int c_blocks);