updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_roi_pool_kernel_f32.hpp
1 /*******************************************************************************
2 * Copyright 2017 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_UNI_ROI_POOL_KERNEL_F32_HPP
18 #define CPU_JIT_UNI_ROI_POOL_KERNEL_F32_HPP
19
20 #include <cfloat>
21
22 #include "c_types_map.hpp"
23 #include "jit_generator.hpp"
24 #include "type_helpers.hpp"
25
26 #include "jit_primitive_conf.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace Xbyak;
33
34 template <cpu_isa_t isa>
35 struct jit_uni_roi_pool_kernel_f32: public jit_generator {
36     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_roi_pool_kernel_f32)
37
38     jit_uni_roi_pool_kernel_f32(jit_roi_pool_conf_t ajpp): jpp(ajpp)
39     {
40         this->generate();
41         jit_ker = (decltype(jit_ker))this->getCode();
42     }
43
44     jit_roi_pool_conf_t jpp;
45
46     void operator()(jit_roi_pool_call_s *arg) { jit_ker(arg); }
47     static status_t init_conf(jit_roi_pool_conf_t &jbp,
48             const roi_pooling_desc_t &pd, const memory_desc_wrapper &src_d,
49             const memory_desc_wrapper &dst_d);
50
51 private:
52     using Vmm = typename utils::conditional3<isa == sse42, Xmm, isa == avx2, Ymm, Zmm>::type;
53
54     Vmm vmm_mask = Vmm(0);
55     Vmm vmm_zero = Vmm(0);
56
57     Xmm xmm_yf = Xmm(0);
58     Vmm vmm_yf = Vmm(0);
59     Xmm xmm_xf = Xmm(1);
60     Vmm vmm_xf = Vmm(1);
61
62     Vmm get_acc_reg(int idx) { return Vmm(2*idx + 1); }
63     Vmm get_src_reg(int idx) { return Vmm(2*idx + 2); }
64
65     Opmask k_store_mask = Opmask(7);
66
67     const unsigned char _cmp_lt_os = 1;
68
69     using reg64_t = const Xbyak::Reg64;
70     reg64_t reg_input     = r8;
71     reg64_t aux_reg_input = rax;
72     reg64_t aux_reg_input1 = rdx;
73     reg64_t reg_output    = r9;
74     reg64_t reg_kh    = r10;
75     reg64_t reg_kw    = r11;
76
77     reg64_t h_iter = r14;
78     reg64_t w_iter = r15;
79
80     reg64_t reg_c_blocks = rbx;
81     reg64_t reg_bin_area = rdx;
82
83     reg64_t reg_yf = reg_kh;
84     reg64_t reg_xf = reg_kw;
85
86     reg64_t reg_yoff = h_iter;
87     reg64_t reg_xoff = r12;
88
89     void (*jit_ker)(jit_roi_pool_call_s *);
90
91     void roi_pool_max(int c_blocks);
92     void roi_pool_bilinear(int c_blocks);
93     void empty_roi(int c_blocks);
94     void loop_body(int c_blocks);
95
96     void generate();
97 };
98
99 }
100 }
101 }
102
103 #endif