1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "memory_pd.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "type_helpers.hpp"
25 #include "format_traits.hpp"
27 #include "cpu_memory.hpp"
33 using namespace mkldnn::impl;
34 using namespace mkldnn::impl::data_type;
35 using namespace mkldnn::impl::status;
36 using namespace mkldnn::impl::memory_format;
38 using dk = data_kind_t;
39 using bf = block_format_t;
/* Zero-fills the channel-tail padding of a blocked plain-data tensor
 * (data_kind == dk::data, e.g. an nChw8c-like format): in the last channel
 * block, in-block positions [dims[1] % blksize, blksize) are padding and
 * are set to zero for every batch point and spatial position.
 * NOTE(review): this chunk is a truncated paste — the function-name line
 * and the closing braces of the lambda/function are not visible here. */
template <data_type_t dt, memory_format_t fmt>
typename utils::enable_if<format_traits<fmt>::data_kind == dk::data>::type
const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
    constexpr int blksize = format_traits<fmt>::blk_size; // channels per block
    const auto &dims = m_d.dims();                        // logical dims
    const auto &pdims = m_d.blocking_desc().padding_dims; // padded dims
    // Index of the last (partially filled) channel block.
    const int C = pdims[1] / blksize - 1;
    // First in-block channel index that is padding, not real data.
    const int c_tail_start = dims[1] % blksize;
    assert(c_tail_start != 0); // caller must ensure a tail actually exists
    // Number of spatial points beyond the leading spatial dim (dims[3..]).
    const size_t sp_rest = utils::array_product(dims + 3, m_d.ndims() - 3);
    // Parallelize over batch and the leading spatial dimension; the
    // remaining spatial points are contiguous blocks of blksize values.
    parallel_nd(dims[0], dims[2], [&](int n, int sp0) {
        auto *d = &data[m_d.blk_off(n, C, sp0)]; // base of the tail block here
        for (size_t sp = 0; sp < sp_rest; ++sp) {
            for (int c = c_tail_start; c < blksize; ++c)
                d[sp * blksize + c] = 0;
/* Zero-fills the oc-tail padding of weights blocked over the output
 * channel only (_4o/_8o/_16o, e.g. Oihw16o): in the last oc block,
 * in-block positions [blksize - oc_tail, blksize) are padding.
 * NOTE(review): truncated paste — the innermost loop body and the closing
 * braces of the lambda/function are not visible here. */
template <data_type_t dt, memory_format_t fmt>
typename utils::enable_if<false
    || format_traits<fmt>::blk_fmt == bf::_4o
    || format_traits<fmt>::blk_fmt == bf::_8o
    || format_traits<fmt>::blk_fmt == bf::_16o
    >::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
        typename prec_traits<dt>::type *data) {
    // 1 when the format carries a leading groups dimension (gwei), else 0;
    // used below to shift all dimension indices by one.
    static constexpr int w_groups = format_traits<fmt>::data_kind == dk::gwei;
    constexpr int is_1d = format_traits<fmt>::ndims_sp == 1;
    constexpr int is_3d = format_traits<fmt>::ndims_sp == 3;
    constexpr int blksize = format_traits<fmt>::blk_size;

    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    const int G = w_groups ? dims[0] : 1;
    const int NB_OC = pdims[w_groups + 0] / blksize; // number of oc blocks
    const int IC = dims[w_groups + 1];
    const int D = is_3d ? dims[w_groups + 2] : 1;
    const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
    const int W = dims[w_groups + 3 - is_1d + is_3d];

    // Number of padded (fake) output channels in the last oc block.
    const int oc_tail = pdims[w_groups + 0] - dims[w_groups + 0];

    parallel_nd(G, IC, D, H, W,
        [&](int g, int ic, int d, int h, int w) {
            // Base of the last oc block at this (g, ic, d, h, w) point.
            auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
                    g, NB_OC - 1, ic, d, h, w)];
            for (int oc = blksize - oc_tail; oc < blksize; ++oc)
/* Zero-fills the ic-tail padding of weights blocked over the input
 * channel only (_8i/_16i, e.g. oIhw16i): in the last ic block, in-block
 * positions [blksize - ic_tail, blksize) are padding.
 * NOTE(review): truncated paste — the innermost loop body and the closing
 * braces of the lambda/function are not visible here. */
template <data_type_t dt, memory_format_t fmt>
typename utils::enable_if<false
    || format_traits<fmt>::blk_fmt == bf::_8i
    || format_traits<fmt>::blk_fmt == bf::_16i
    >::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
        typename prec_traits<dt>::type *data) {
    // 1 when the format carries a leading groups dimension (gwei), else 0.
    static constexpr int w_groups = format_traits<fmt>::data_kind == dk::gwei;
    constexpr int is_1d = format_traits<fmt>::ndims_sp == 1;
    constexpr int is_3d = format_traits<fmt>::ndims_sp == 3;
    constexpr int blksize = format_traits<fmt>::blk_size;

    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    const int G = w_groups ? dims[0] : 1;
    const int OC = dims[w_groups + 0];
    const int NB_IC = pdims[w_groups + 1] / blksize; // number of ic blocks
    const int D = is_3d ? dims[w_groups + 2] : 1;
    const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
    // NOTE(review): unlike the _*o variant above, this omits "- is_1d".
    // The formats dispatched here (oIhw*, oIdhw*) are 2D/3D only, so it
    // appears unreachable for 1D — confirm if 1D _*i formats are added.
    const int W = dims[w_groups + 3 + is_3d];

    // Number of padded (fake) input channels in the last ic block.
    const int ic_tail = pdims[w_groups + 1] - dims[w_groups + 1];

    parallel_nd(G, OC, D, H, W,
        [&](int g, int oc, int d, int h, int w) {
            // Base of the last ic block at this (g, oc, d, h, w) point.
            auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
                    g, oc, NB_IC - 1, d, h, w)];
            for (int ic = blksize - ic_tail; ic < blksize; ++ic)
/* Zero-fills both oc- and ic-tail padding of weights blocked over BOTH
 * channel dimensions (blk_ndims == 2, e.g. OIhw16i16o): the ic tail of
 * the last ic block in every oc block, and the whole padded rows of the
 * last oc block, are cleared.
 * NOTE(review): truncated paste — the kernel lambda is missing its
 * `oc` initialization, `#undef`, and closers; the two parallel_nd
 * sections are missing their tail-guard `if`s, `ker(...)` calls and
 * closing braces. */
template <data_type_t dt, memory_format_t fmt>
typename utils::enable_if<
    block_format_traits<format_traits<fmt>::blk_fmt>::blk_ndims == 2>::type
typed_zero_pad_weights(const memory_desc_wrapper &m_d,
        typename prec_traits<dt>::type *data) {
    using data_t = typename prec_traits<dt>::type;
    // 1 when the format carries a leading groups dimension (gwei), else 0.
    static constexpr int w_groups = format_traits<fmt>::data_kind == dk::gwei;
    constexpr int is_1d = format_traits<fmt>::ndims_sp == 1;
    constexpr int is_3d = format_traits<fmt>::ndims_sp == 3;
    constexpr int blksize = format_traits<fmt>::blk_size;
    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    const int G = w_groups ? dims[0] : 1;
    const int NB_OC = pdims[w_groups + 0] / blksize; // number of oc blocks
    const int NB_IC = pdims[w_groups + 1] / blksize; // number of ic blocks
    const int D = is_3d ? dims[w_groups + 2] : 1;
    const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
    const int W = dims[w_groups + 3 - is_1d + is_3d];

    // Clears the padded region of one blksize x blksize (oc, ic) block;
    // blk_off maps in-block (oc, ic) to the format's physical offset.
    auto ker = [&](data_t *d, const int oc_tail, const int ic_tail) {
#       define blk_off OI_blk_off<format_traits<fmt>::blk_fmt>
        // Valid oc rows: clear only the padded ic tail.
        for (; oc < blksize - oc_tail; ++oc) {
            for (int ic = blksize - ic_tail; ic < blksize; ++ic)
                d[blk_off(oc, ic)] = 0;
        // Padded oc rows: clear the entire ic range.
        for (; oc < blksize; ++oc)
            for (int ic = 0; ic < blksize; ++ic)
                d[blk_off(oc, ic)] = 0;

    const int oc_tail = pdims[w_groups + 0] - dims[w_groups + 0];
    const int ic_tail = pdims[w_groups + 1] - dims[w_groups + 1];

    // Walk the last ic block of every oc block (ic-tail cleanup).
    parallel_nd(G, NB_OC, D, H, W,
        [&](int g, int nb_oc, int d, int h, int w) {
            auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
                    g, nb_oc, NB_IC - 1, d, h, w)];

    // Walk the last oc block of every ic block (oc-tail cleanup).
    parallel_nd(G, NB_IC, D, H, W,
        [&](int g, int nb_ic, int d, int h, int w) {
            auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
                    g, NB_OC - 1, nb_ic, d, h, w)];
/* Zero-fills the groups-tail padding of depthwise weights blocked over
 * the groups dimension (_8g/_16g, i.e. Goihw8g/Goihw16g): in the last
 * groups block, in-block positions [dims[0] % blksize, blksize) are
 * padding.
 * NOTE(review): truncated paste — the closing braces of the lambda and
 * function are not visible here. */
template <data_type_t dt, memory_format_t fmt>
typename utils::enable_if<false
    || format_traits<fmt>::blk_fmt == bf::_8g
    || format_traits<fmt>::blk_fmt == bf::_16g
    >::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
        typename prec_traits<dt>::type *data) {
    constexpr int blksize = format_traits<fmt>::blk_size;

    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    // Index of the last (partially filled) groups block.
    const int G = pdims[0] / blksize - 1;
    // First in-block group index that is padding, not real data.
    const int g_tail_start = dims[0] % blksize;
    assert(g_tail_start != 0); // caller must ensure a tail actually exists
    // Elements per group: product of all dims after the groups dim.
    const ptrdiff_t sz_rest
            = (ptrdiff_t)utils::array_product(dims + 1, m_d.ndims() - 1);

    auto *d = &data[m_d.blk_off(G)]; // base of the tail groups block

    // Each per-group element position holds a blksize-wide run of groups;
    // zero the padded tail of that run.
    parallel_nd(sz_rest, [&](ptrdiff_t s) {
        for (int g = g_tail_start; g < blksize; ++g)
            d[s * blksize + g] = 0;
/* Generic fallback: zero-fills the padding of any "blocked" layout by
 * walking logical element indices and testing each coordinate against the
 * padded dims. Slower than the specialized routines above; used when no
 * specialized variant matches the format.
 * NOTE(review): truncated paste — declarations of the `step` accumulator
 * and the per-element `idx` variable, parts of the loops, and the closing
 * braces are among the lines not visible here. */
template <data_type_t dt>
void typed_zero_pad_generic_blocked(const memory_desc_wrapper &m_d,
        typename prec_traits<dt>::type *data) {
    const int ndims = m_d.ndims();
    const auto &dims = m_d.dims();
    const auto &pdims = m_d.blocking_desc().padding_dims;

    // Total number of elements including the padded area.
    const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true);

    /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1]
     * | ---------------------
     * step <-- D_k+1 * ... * D_ndims-1
     */
    // Find the innermost run of dimensions with no padding: logical
    // indices there are dense, so elements can be handled `step` at a
    // time; step_dim is the outermost dim that still has padding.
    int step_dim = ndims - 1;
    for (; step_dim >= 0; --step_dim) {
        if (dims[step_dim] != pdims[step_dim]) break;
        step *= dims[step_dim];

    assert(step_dim >= 0 && "no zero padding is required");
    if (step_dim < 0) return; // fully dense: nothing to zero

    parallel_nd(nelems / step, [&](ptrdiff_t e1) {
        bool need_zero = false;

        // An element is padding iff any of its coordinates falls in the
        // half-open range [dims[d], pdims[d]).
        for (int d = step_dim; d >= 0; --d) {
            if (idx % pdims[d] >= dims[d]) {

        // off_l(..., true) maps a padded logical index to its physical
        // offset; zero the whole dense run of `step` elements.
        for (ptrdiff_t e0 = 0; e0 < step; ++e0)
            data[m_d.off_l(e1 * step + e0, true)] = 0;
257 template <data_type_t dt>
258 status_t cpu_memory_t::typed_zero_pad() const {
259 const memory_desc_wrapper mpd(pd());
261 // FIXME: guard this check for non-blocked layout
262 if (mpd.nelems(false) == mpd.nelems(true))
265 auto *data = (typename prec_traits<dt>::type *)data_;
266 const auto fmt = mpd.format();
269 # define MAYBE_DATA(f) if (fmt == f) \
270 { typed_zero_pad_data<dt, f>(mpd, data); return success; }
279 MAYBE_DATA(nCdhw16c);
282 # define MAYBE_WEIGHTS(f) if (fmt == f) \
283 { typed_zero_pad_weights<dt, f>(mpd, data); return success; }
284 MAYBE_WEIGHTS(OIdhw4i4o);
285 MAYBE_WEIGHTS(OIdhw8i8o);
286 MAYBE_WEIGHTS(OIdhw8o8i);
287 MAYBE_WEIGHTS(OIdhw16i16o);
288 MAYBE_WEIGHTS(OIdhw16o16i);
289 MAYBE_WEIGHTS(Oidhw4o);
290 MAYBE_WEIGHTS(Oidhw16o);
291 MAYBE_WEIGHTS(Odhwi16o);
292 MAYBE_WEIGHTS(Odhwi8o);
293 MAYBE_WEIGHTS(oIhw8i);
294 MAYBE_WEIGHTS(oIhw16i);
295 MAYBE_WEIGHTS(oIdhw8i);
296 MAYBE_WEIGHTS(oIdhw16i);
297 MAYBE_WEIGHTS(OIhw4i4o);
298 MAYBE_WEIGHTS(OIhw8i8o);
299 MAYBE_WEIGHTS(OIhw16i16o);
300 MAYBE_WEIGHTS(OIhw4i16o4i);
301 MAYBE_WEIGHTS(OIhw4i16o4i_s8s8);
302 MAYBE_WEIGHTS(OIw4i4o);
303 MAYBE_WEIGHTS(Owi8o);
304 MAYBE_WEIGHTS(OIw8i8o);
305 MAYBE_WEIGHTS(OIw8o8i);
306 MAYBE_WEIGHTS(OIw16i16o);
307 MAYBE_WEIGHTS(OIw16o16i);
308 MAYBE_WEIGHTS(Oiw4o);
309 MAYBE_WEIGHTS(Oiw16o);
310 MAYBE_WEIGHTS(Owi16o);
311 MAYBE_WEIGHTS(OIw8i16o2i);
312 MAYBE_WEIGHTS(OIw8o16i2o);
313 MAYBE_WEIGHTS(IOw16o16i);
314 MAYBE_WEIGHTS(OIhw8i16o2i);
315 MAYBE_WEIGHTS(OIdhw8i16o2i);
316 MAYBE_WEIGHTS(OIhw8o16i2o);
317 MAYBE_WEIGHTS(OIhw8o8i);
318 MAYBE_WEIGHTS(OIhw16o16i);
319 MAYBE_WEIGHTS(IOhw16o16i);
320 MAYBE_WEIGHTS(Oihw4o);
321 MAYBE_WEIGHTS(Oihw16o);
322 MAYBE_WEIGHTS(Ohwi8o);
323 MAYBE_WEIGHTS(Ohwi4o);
324 MAYBE_WEIGHTS(Ohwi16o);
325 MAYBE_WEIGHTS(gOIhw4o4i_s8s8);
326 MAYBE_WEIGHTS(gOIhw4o4i_s8s8);
327 MAYBE_WEIGHTS(gOIhw4i4o);
328 MAYBE_WEIGHTS(gOIhw8i8o);
329 MAYBE_WEIGHTS(gOIhw16i16o);
330 MAYBE_WEIGHTS(gOIhw4i16o4i);
331 MAYBE_WEIGHTS(gOIhw4i16o4i_s8s8);
332 MAYBE_WEIGHTS(gOIhw2i8o4i);
333 MAYBE_WEIGHTS(gOIhw2i8o4i_s8s8);
334 MAYBE_WEIGHTS(gOIw4i4o);
335 MAYBE_WEIGHTS(gOwi8o);
336 MAYBE_WEIGHTS(gOIw8i8o);
337 MAYBE_WEIGHTS(gOIw8o8i);
338 MAYBE_WEIGHTS(gOIw16i16o);
339 MAYBE_WEIGHTS(gOIw16o16i);
340 MAYBE_WEIGHTS(gOiw4o);
341 MAYBE_WEIGHTS(gOiw16o);
342 MAYBE_WEIGHTS(gOwi16o);
343 MAYBE_WEIGHTS(gOIw8i16o2i);
344 MAYBE_WEIGHTS(gOIw8o16i2o);
345 MAYBE_WEIGHTS(gIOw16o16i);
346 MAYBE_WEIGHTS(gOIhw8i16o2i);
347 MAYBE_WEIGHTS(gOIdhw8i16o2i);
348 MAYBE_WEIGHTS(gOIhw8o16i2o);
349 MAYBE_WEIGHTS(gOIhw8o8i);
350 MAYBE_WEIGHTS(gOIhw16o16i);
351 MAYBE_WEIGHTS(gIOhw16o16i);
352 MAYBE_WEIGHTS(gOihw4o);
353 MAYBE_WEIGHTS(gOihw16o);
354 MAYBE_WEIGHTS(gOhwi8o);
355 MAYBE_WEIGHTS(gOhwi4o);
356 MAYBE_WEIGHTS(gOhwi16o);
357 MAYBE_WEIGHTS(gOIdhw4i4o);
358 MAYBE_WEIGHTS(gOIdhw8i8o);
359 MAYBE_WEIGHTS(gOIdhw8o8i);
360 MAYBE_WEIGHTS(gOIdhw16i16o);
361 MAYBE_WEIGHTS(gOIdhw16o16i);
362 MAYBE_WEIGHTS(gOidhw4o);
363 MAYBE_WEIGHTS(gOidhw16o);
364 MAYBE_WEIGHTS(gOdhwi16o);
365 MAYBE_WEIGHTS(gOdhwi8o);
366 MAYBE_WEIGHTS(Goihw8g);
367 MAYBE_WEIGHTS(Goihw16g);
368 # undef MAYBE_WEIGHTS
370 // the last line of defence
371 if (types::format_normalize(fmt) == blocked) {
372 typed_zero_pad_generic_blocked<dt>(mpd, data);
376 return unimplemented;
/* Public entry point: dispatches zero-padding by the memory's data type.
 * Padding physically zeroes the "fake" elements introduced when logical
 * dims are rounded up to the blocking size, so compute kernels may safely
 * read whole blocks.
 * NOTE(review): truncated paste — some terms of the skip_zeroing
 * condition and the switch/function closing braces are not visible here. */
status_t cpu_memory_t::zero_pad() const {
    memory_desc_wrapper md(pd());
    // Nothing to zero when the descriptor has no blocking layout
    // (further skip conditions appear to be lost in the paste).
    const bool skip_zeroing = false
        || !md.is_blocking_desc();
    if (skip_zeroing) return success;

    switch (md.data_type()) {
    case f32: return typed_zero_pad<f32>();
    case s32: return typed_zero_pad<s32>();
    case s16: return typed_zero_pad<s16>();
    case s8: return typed_zero_pad<s8>();
    case u8: return typed_zero_pad<u8>();
    // binary data is stored in u8 lanes, so u8 padding is reused for it
    case bin: return typed_zero_pad<u8>();
    default: assert(!"memory is undefined"); return unimplemented;
    // fall-through safety net after the switch (closing brace lost above)
    return unimplemented;