1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
19 #include "c_types_map.hpp"
20 #include "jit_uni_pooling.hpp"
21 #include "type_helpers.hpp"
28 template <cpu_isa_t isa>
29 void jit_uni_pooling_fwd_t<isa>::execute_forward() const {
30 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
31 auto dst = reinterpret_cast<data_t*>(this->memory(0));
32 auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
33 reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
35 const memory_desc_wrapper src_d(pd()->src_pd());
36 const memory_desc_wrapper dst_d(pd()->dst_pd());
37 const memory_desc_wrapper indices_d(pd()->workspace_pd());
38 const size_t ind_dt_size = indices
39 ? types::data_type_size(indices_d.data_type()) : 0;
41 const auto &jpp = pd()->jpp_;
44 auto ker = [&](int n, int b_c, int oh) {
45 auto arg = jit_pool_call_s();
47 const int ij = oh * jpp.stride_h;
48 const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
49 const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
50 const int ih = nstl::max(ij - jpp.t_pad, 0);
52 arg.src = &src[src_d.blk_off(n, b_c, ih)];
53 arg.dst = &dst[dst_d.blk_off(n, b_c, oh)];
55 const size_t ind_off = indices_d.blk_off(n, b_c, oh);
56 arg.indices = &indices[ind_off * ind_dt_size];
59 arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
60 arg.kh_padding_shift = i_t_overflow*jpp.kw;
62 arg.ker_area_h = pd()->desc()->alg_kind == alg_kind::pooling_avg_exclude_padding
63 ? (float)(jpp.kh - nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
64 nstl::max(0, jpp.t_pad - oh*jpp.stride_h))
65 : (float)(jpp.kh - nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih - jpp.b_pad));
70 parallel_nd(mb, jpp.nb_c, jpp.oh,
71 [&](int n, int b_c, int oh) {
76 template <cpu_isa_t isa>
77 void jit_uni_pooling_fwd_t<isa>::execute_forward_3d() const {
78 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
79 auto dst = reinterpret_cast<data_t*>(this->memory(0));
80 auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
81 reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
83 const memory_desc_wrapper src_d(pd()->src_pd());
84 const memory_desc_wrapper dst_d(pd()->dst_pd());
85 const memory_desc_wrapper indices_d(pd()->workspace_pd());
86 const size_t ind_dt_size = indices
87 ? types::data_type_size(indices_d.data_type()) : 0;
89 const auto &jpp = pd()->jpp_;
92 auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
94 auto arg = jit_pool_call_s();
96 const int ij = oh * jpp.stride_h;
97 const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
98 const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
99 const int ih = nstl::max(ij - jpp.t_pad, 0);
101 arg.src = &src[src_d.blk_off(n, b_c, id, ih)];
102 arg.dst = &dst[dst_d.blk_off(n, b_c, od, oh)];
104 const size_t ind_off = indices_d.blk_off(n, b_c, od, oh);
105 arg.indices = &indices[ind_off * ind_dt_size];
107 arg.oh = (oh + od == 0);
108 arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
109 arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
110 arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh;
111 arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw;
113 arg.ker_area_h = (float)(jpp.kh -
114 nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
115 nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd -
116 nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) -
117 nstl::max(0, jpp.f_pad - od*jpp.stride_d));
123 parallel_nd(mb, jpp.nb_c, jpp.od,
124 [&](int n, int b_c, int od) {
125 const int ik = od * jpp.stride_d;
126 const int d_t_overflow = nstl::max(0, jpp.f_pad-ik);
127 const int d_b_overflow = nstl::max(jpp.id, ik+jpp.kd-jpp.f_pad)
129 const int id = nstl::max(ik - jpp.f_pad, 0);
130 for (int oh = 0; oh < jpp.oh; ++oh) {
131 ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow);
137 template <cpu_isa_t isa>
138 void jit_uni_pooling_bwd_t<isa>::execute_backward() const {
139 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
140 auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
141 auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
142 reinterpret_cast<const char*>(this->input_memory(1)) : nullptr;
144 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
145 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
146 const memory_desc_wrapper indices_d(pd()->workspace_pd());
147 const size_t ind_dt_size = indices
148 ? types::data_type_size(indices_d.data_type()) : 0;
150 const auto &jpp = pd()->jpp_;
153 auto ker = [&](int n, int b_c, int oh) {
154 auto arg = jit_pool_call_s();
156 const int ij = oh * jpp.stride_h;
157 const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
158 const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
159 const int ih = nstl::max(ij - jpp.t_pad, 0);
161 arg.src = &diff_src[diff_src_d.blk_off(n, b_c, ih)];
162 arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, oh)];
164 const size_t ind_off = indices_d.blk_off(n, b_c, oh);
165 arg.indices = &indices[ind_off * ind_dt_size];
168 arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
169 arg.kh_padding_shift = i_t_overflow*jpp.kw;
171 arg.ker_area_h = (float)(jpp.kh -
172 nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
173 nstl::max(0, jpp.t_pad - oh*jpp.stride_h));
178 parallel_nd(mb, jpp.nb_c, [&](int n, int b_c) {
179 for (int oh = 0; oh < jpp.oh; ++oh) {
185 template <cpu_isa_t isa>
186 void jit_uni_pooling_bwd_t<isa>::execute_backward_3d() const {
187 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
188 auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
189 auto indices = pd()->desc()->alg_kind == alg_kind::pooling_max ?
190 reinterpret_cast<const char*>(this->input_memory(1)) : nullptr;
192 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
193 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
194 const memory_desc_wrapper indices_d(pd()->workspace_pd());
195 const size_t ind_dt_size = indices
196 ? types::data_type_size(indices_d.data_type()) : 0;
198 const auto &jpp = pd()->jpp_;
201 auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
202 int d_b_overflow, int zero_size, int kd) {
203 auto arg = jit_pool_call_s();
205 const int ij = oh * jpp.stride_h;
206 const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
207 const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
208 const int ih = nstl::max(ij - jpp.t_pad, 0);
210 arg.src = &diff_src[diff_src_d.blk_off(n, b_c, id + kd, ih)];
211 arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, od, oh)];
213 const size_t ind_off = indices_d.blk_off(n, b_c, od, oh);
214 arg.indices = &indices[ind_off * ind_dt_size];
217 arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
218 arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
219 arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh
220 + kd * jpp.kw * jpp.kh;
221 arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw;
223 arg.ker_area_h = (float)(jpp.kh -
224 nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
225 nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd -
226 nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) -
227 nstl::max(0, jpp.f_pad - od*jpp.stride_d));
232 if (jpp.simple_alg) {
234 parallel_nd(mb, jpp.nb_c, jpp.od,
235 [&](int n, int b_c, int od) {
236 const int ik = od * jpp.stride_d;
237 const int d_t_overflow = nstl::max(0, jpp.f_pad - ik);
238 const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd
239 - jpp.f_pad) - jpp.id;
240 const int id = nstl::max(ik - jpp.f_pad, 0);
241 int zero_s = jpp.stride_d - d_t_overflow - (nstl::max(
242 jpp.id, ik + jpp.stride_d - jpp.f_pad) - jpp.id);
243 for (int oh = 0; oh < jpp.oh; ++oh) {
244 ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
245 (oh == 0) ? zero_s : 0, 0);
249 ptrdiff_t nelems = (ptrdiff_t)mb * (ptrdiff_t)jpp.c
250 * (ptrdiff_t)jpp.id * (ptrdiff_t)jpp.ih * (ptrdiff_t)jpp.iw;
252 parallel_nd(nelems, [&](ptrdiff_t i) { diff_src[i] = 0.f; });
254 for (int kd = 0; kd < jpp.kd; ++kd) {
255 parallel_nd(mb, jpp.nb_c, [&](int n, int b_c) {
256 for (int od = 0; od < jpp.od; ++od) {
257 const int ik = od * jpp.stride_d;
258 const int d_t_overflow = nstl::max(0, jpp.f_pad-ik);
259 const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd
260 - jpp.f_pad) - jpp.id;
261 if (kd >= jpp.kd - d_t_overflow - d_b_overflow)
263 const int id = nstl::max(ik - jpp.f_pad, 0);
264 for (int oh = 0; oh < jpp.oh; ++oh) {
265 ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
// Explicit instantiations of the JIT pooling primitives. The
// execute_forward*/execute_backward* member functions in this file are
// templates on cpu_isa_t, so instantiating the forward (fwd) and backward
// (bwd) structs here emits their definitions for each listed ISA:
// sse42, avx and avx512_common.
275 template struct jit_uni_pooling_fwd_t<sse42>;
276 template struct jit_uni_pooling_bwd_t<sse42>;
277 template struct jit_uni_pooling_fwd_t<avx>;
278 template struct jit_uni_pooling_bwd_t<avx>;
279 template struct jit_uni_pooling_fwd_t<avx512_common>;
280 template struct jit_uni_pooling_bwd_t<avx512_common>;
286 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s