Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_pooling.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18
19 #include "c_types_map.hpp"
20 #include "jit_uni_pooling.hpp"
21 #include "type_helpers.hpp"
22 #include "nstl.hpp"
23
24 namespace mkldnn {
25 namespace impl {
26 namespace cpu {
27
28 template <cpu_isa_t isa>
29 void jit_uni_pooling_fwd_t<isa>::execute_forward() const {
30     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
31     auto dst = reinterpret_cast<data_t*>(this->memory(0));
32     auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
33         reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
34
35     const memory_desc_wrapper src_d(pd()->src_pd());
36     const memory_desc_wrapper dst_d(pd()->dst_pd());
37     const memory_desc_wrapper indices_d(pd()->workspace_pd());
38     const size_t ind_dt_size = indices
39         ? types::data_type_size(indices_d.data_type()) : 0;
40
41     const auto &jpp = pd()->jpp_;
42     int mb = pd()->MB();
43
44     auto ker = [&](int n, int b_c, int oh) {
45         auto arg = jit_pool_call_s();
46
47         const int ij = oh * jpp.stride_h;
48         const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
49         const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
50         const int ih = nstl::max(ij - jpp.t_pad, 0);
51
52         arg.src = &src[src_d.blk_off(n, b_c, ih)];
53         arg.dst = &dst[dst_d.blk_off(n, b_c, oh)];
54         if (indices) {
55             const size_t ind_off = indices_d.blk_off(n, b_c, oh);
56             arg.indices = &indices[ind_off * ind_dt_size];
57         }
58         arg.oh = oh == 0;
59         arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
60         arg.kh_padding_shift = i_t_overflow*jpp.kw;
61         arg.kw_padding = 0;
62         arg.ker_area_h = pd()->desc()->alg_kind == alg_kind::pooling_avg_exclude_padding
63              ?  (float)(jpp.kh - nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
64                 nstl::max(0, jpp.t_pad - oh*jpp.stride_h))
65              :  (float)(jpp.kh - nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih - jpp.b_pad));
66
67         (*kernel_)(&arg);
68     };
69
70     parallel_nd(mb, jpp.nb_c, jpp.oh,
71         [&](int n, int b_c, int oh) {
72         ker(n, b_c, oh);
73     });
74 }
75
76 template <cpu_isa_t isa>
77 void jit_uni_pooling_fwd_t<isa>::execute_forward_3d() const {
78     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
79     auto dst = reinterpret_cast<data_t*>(this->memory(0));
80     auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
81         reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
82
83     const memory_desc_wrapper src_d(pd()->src_pd());
84     const memory_desc_wrapper dst_d(pd()->dst_pd());
85     const memory_desc_wrapper indices_d(pd()->workspace_pd());
86     const size_t ind_dt_size = indices
87         ? types::data_type_size(indices_d.data_type()) : 0;
88
89     const auto &jpp = pd()->jpp_;
90     int mb = pd()->MB();
91
92     auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
93             int d_b_overflow) {
94         auto arg = jit_pool_call_s();
95
96         const int ij = oh * jpp.stride_h;
97         const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
98         const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
99         const int ih = nstl::max(ij - jpp.t_pad, 0);
100
101         arg.src = &src[src_d.blk_off(n, b_c, id, ih)];
102         arg.dst = &dst[dst_d.blk_off(n, b_c, od, oh)];
103         if (indices) {
104             const size_t ind_off = indices_d.blk_off(n, b_c, od, oh);
105             arg.indices = &indices[ind_off * ind_dt_size];
106         }
107         arg.oh = (oh + od == 0);
108         arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
109         arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
110         arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh;
111         arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw;
112         arg.kw_padding = 0;
113         arg.ker_area_h = (float)(jpp.kh -
114             nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
115             nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd -
116             nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) -
117             nstl::max(0, jpp.f_pad - od*jpp.stride_d));
118
119
120         (*kernel_)(&arg);
121     };
122
123     parallel_nd(mb, jpp.nb_c, jpp.od,
124         [&](int n, int b_c, int od) {
125         const int ik = od * jpp.stride_d;
126         const int d_t_overflow = nstl::max(0, jpp.f_pad-ik);
127         const int d_b_overflow = nstl::max(jpp.id, ik+jpp.kd-jpp.f_pad)
128             -jpp.id;
129         const int id = nstl::max(ik - jpp.f_pad, 0);
130         for (int oh = 0; oh < jpp.oh; ++oh) {
131             ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow);
132         }
133     });
134 }
135
136
137 template <cpu_isa_t isa>
138 void jit_uni_pooling_bwd_t<isa>::execute_backward() const {
139     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
140     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
141     auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
142         reinterpret_cast<const char*>(this->input_memory(1)) : nullptr;
143
144     const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
145     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
146     const memory_desc_wrapper indices_d(pd()->workspace_pd());
147     const size_t ind_dt_size = indices
148         ? types::data_type_size(indices_d.data_type()) : 0;
149
150     const auto &jpp = pd()->jpp_;
151     int mb = pd()->MB();
152
153     auto ker = [&](int n, int b_c, int oh) {
154         auto arg = jit_pool_call_s();
155
156         const int ij = oh * jpp.stride_h;
157         const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
158         const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
159         const int ih = nstl::max(ij - jpp.t_pad, 0);
160
161         arg.src = &diff_src[diff_src_d.blk_off(n, b_c, ih)];
162         arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, oh)];
163         if (indices) {
164             const size_t ind_off = indices_d.blk_off(n, b_c, oh);
165             arg.indices = &indices[ind_off * ind_dt_size];
166         }
167         arg.oh = (oh == 0);
168         arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
169         arg.kh_padding_shift = i_t_overflow*jpp.kw;
170         arg.kw_padding = 0;
171         arg.ker_area_h = (float)(jpp.kh -
172             nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
173             nstl::max(0, jpp.t_pad - oh*jpp.stride_h));
174
175         (*kernel_)(&arg);
176     };
177
178     parallel_nd(mb, jpp.nb_c, [&](int n, int b_c) {
179         for (int oh = 0; oh < jpp.oh; ++oh) {
180             ker(n, b_c, oh);
181         }
182     });
183 }
184
185 template <cpu_isa_t isa>
186 void jit_uni_pooling_bwd_t<isa>::execute_backward_3d() const {
187     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
188     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
189     auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
190         reinterpret_cast<const char*>(this->input_memory(1)) : nullptr;
191
192     const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
193     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
194     const memory_desc_wrapper indices_d(pd()->workspace_pd());
195     const size_t ind_dt_size = indices
196         ? types::data_type_size(indices_d.data_type()) : 0;
197
198     const auto &jpp = pd()->jpp_;
199     int mb = pd()->MB();
200
201     auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
202             int d_b_overflow, int zero_size, int kd) {
203         auto arg = jit_pool_call_s();
204
205         const int ij = oh * jpp.stride_h;
206         const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
207         const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
208         const int ih = nstl::max(ij - jpp.t_pad, 0);
209
210         arg.src = &diff_src[diff_src_d.blk_off(n, b_c, id + kd, ih)];
211         arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, od, oh)];
212         if (indices) {
213             const size_t ind_off = indices_d.blk_off(n, b_c, od, oh);
214             arg.indices = &indices[ind_off * ind_dt_size];
215         }
216         arg.oh = zero_size;
217         arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
218         arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
219         arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh
220             + kd * jpp.kw * jpp.kh;
221         arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw;
222         arg.kw_padding = 0;
223         arg.ker_area_h = (float)(jpp.kh -
224             nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
225             nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd -
226             nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) -
227             nstl::max(0, jpp.f_pad - od*jpp.stride_d));
228
229         (*kernel_)(&arg);
230     };
231
232     if (jpp.simple_alg) {
233
234         parallel_nd(mb, jpp.nb_c, jpp.od,
235             [&](int n, int b_c, int od) {
236             const int ik = od * jpp.stride_d;
237             const int d_t_overflow = nstl::max(0, jpp.f_pad - ik);
238             const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd
239                     - jpp.f_pad) - jpp.id;
240             const int id = nstl::max(ik - jpp.f_pad, 0);
241             int zero_s = jpp.stride_d - d_t_overflow - (nstl::max(
242                     jpp.id, ik + jpp.stride_d - jpp.f_pad) - jpp.id);
243             for (int oh = 0; oh < jpp.oh; ++oh) {
244                 ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
245                         (oh == 0) ? zero_s : 0, 0);
246             }
247         });
248     } else {
249         ptrdiff_t nelems = (ptrdiff_t)mb * (ptrdiff_t)jpp.c
250             * (ptrdiff_t)jpp.id * (ptrdiff_t)jpp.ih * (ptrdiff_t)jpp.iw;
251
252         parallel_nd(nelems, [&](ptrdiff_t i) { diff_src[i] = 0.f; });
253
254         for (int kd = 0; kd < jpp.kd; ++kd) {
255             parallel_nd(mb, jpp.nb_c, [&](int n, int b_c) {
256                 for (int od = 0; od < jpp.od; ++od) {
257                     const int ik = od * jpp.stride_d;
258                     const int d_t_overflow = nstl::max(0, jpp.f_pad-ik);
259                     const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd
260                             - jpp.f_pad) - jpp.id;
261                     if (kd >= jpp.kd - d_t_overflow - d_b_overflow)
262                         continue;
263                     const int id = nstl::max(ik - jpp.f_pad, 0);
264                     for (int oh = 0; oh < jpp.oh; ++oh) {
265                         ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
266                                 0, kd);
267                     }
268                 }
269             });
270         }
271     }
272 }
273
274
275 template struct jit_uni_pooling_fwd_t<sse42>;
276 template struct jit_uni_pooling_bwd_t<sse42>;
277 template struct jit_uni_pooling_fwd_t<avx>;
278 template struct jit_uni_pooling_bwd_t<avx>;
279 template struct jit_uni_pooling_fwd_t<avx512_common>;
280 template struct jit_uni_pooling_bwd_t<avx512_common>;
281
282 }
283 }
284 }
285
286 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s