1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "memory_pd.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "type_helpers.hpp"
25 #include "cpu_memory.hpp"
31 using namespace mkldnn::impl;
32 using namespace mkldnn::impl::data_type;
33 using namespace mkldnn::impl::status;
34 using namespace mkldnn::impl::memory_format;
// Zeroes the tail of the last dim-1 (channel) block of a blocked data tensor.
// Enabled only for the nChw8c, nChw16c and nCdhw16c formats.
//
// m_d  - descriptor wrapper; the function reads dims(), ndims(),
//        blocking_desc().padding_dims and blk_off() from it.
// data - base pointer of the tensor storage; modified in place.
//
// Precondition (asserted): dims[1] is not a multiple of the block size, so
// the last block really has a tail to clear.
36 template <data_type_t dt, memory_format_t fmt>
37 typename utils::enable_if<fmt == nChw8c || fmt == nChw16c || fmt == nCdhw16c
38 >::type typed_zero_pad_data(const memory_desc_wrapper &m_d,
39 typename prec_traits<dt>::type *data) {
// Block size along dim 1: 8 for nChw8c, 16 for the two *16c formats.
40 constexpr int blksize = fmt == nChw8c ? 8 : 16;
42 const auto &dims = m_d.dims();
43 const auto &pdims = m_d.blocking_desc().padding_dims;
// C is the index of the last block along dim 1 (computed from the padded
// size); c_tail_start is the first in-block position past the real data.
45 const int C = pdims[1] / blksize - 1;
46 const int c_tail_start = dims[1] % blksize;
47 assert(c_tail_start != 0);
// Product of all dims from index 3 onwards (dims[2] is iterated explicitly).
48 const size_t sp_rest = utils::array_product(dims + 3, m_d.ndims() - 3);
50 # pragma omp parallel for collapse(2) schedule(static)
51 for (int n = 0; n < dims[0]; ++n)
52 for (int sp0 = 0; sp0 < dims[2]; ++sp0) {
// Start of the last dim-1 block for this (n, sp0) pair.
53 auto *d = &data[m_d.blk_off(n, C, sp0)];
// Each of the sp_rest positions is addressed blksize elements after the
// previous one; only entries [c_tail_start, blksize) are cleared.
54 for (size_t sp = 0; sp < sp_rest; ++sp) {
55 for (int c = c_tail_start; c < blksize; ++c)
56 d[sp * blksize + c] = 0;
// NOTE(review): the closing braces of the loops and of the function are not
// visible in this listing.
// Zero-padding of the output-channel tail for weights formats that are
// blocked along a single dimension (the first non-group one): Ohwi8o, Oihw16o,
// Ohwi16o, Oidhw16o, Odhwi16o and their grouped (g-prefixed) counterparts.
//
// m_d  - descriptor wrapper; dims(), blocking_desc().padding_dims and
//        blk_off() are read from it.
// data - base pointer of the weights storage.
61 template <data_type_t dt, memory_format_t fmt>
62 typename utils::enable_if<false
63 || fmt == Ohwi8o || fmt == Oihw16o || fmt == Ohwi16o || fmt == Oidhw16o
64 || fmt == Odhwi16o || fmt == gOhwi8o || fmt == gOihw16o || fmt == gOhwi16o
65 || fmt == gOidhw16o || fmt == gOdhwi16o
66 >::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
67 typename prec_traits<dt>::type *data) {
// 1 when the format has a leading group dimension; used below as an index
// shift into dims / pdims.
68 constexpr int w_groups = false
69 || fmt == gOhwi8o || fmt == gOihw16o || fmt == gOhwi16o
70 || fmt == gOidhw16o || fmt == gOdhwi16o;
// 1 for the formats with an extra "d" dimension; shifts the H / W indices.
72 constexpr int is_3d = false
73 || fmt == Oidhw16o || fmt == Odhwi16o
74 || fmt == gOidhw16o || fmt == gOdhwi16o;
// Block size: 8 for the two *8o formats, 16 for every other enabled format.
76 constexpr int blksize = fmt == Ohwi8o || fmt == gOhwi8o ? 8 : 16;
78 const auto &dims = m_d.dims();
79 const auto &pdims = m_d.blocking_desc().padding_dims;
// Loop extents. G and D collapse to 1 when the format has no group / depth
// dimension. NB_OC is the number of blocks along the padded blocked dimension.
81 const int G = w_groups ? dims[0] : 1;
82 const int NB_OC = pdims[w_groups + 0] / blksize;
83 const int IC = dims[w_groups + 1];
84 const int D = is_3d ? dims[w_groups + 2] : 1;
85 const int H = dims[w_groups + 2 + is_3d];
86 const int W = dims[w_groups + 3 + is_3d];
// Number of padded entries at the end of the last block.
88 const int oc_tail = pdims[w_groups + 0] - dims[w_groups + 0];
90 # pragma omp parallel for collapse(5)
91 for (int g = 0; g < G; ++g)
92 for (int ic = 0; ic < IC; ++ic)
93 for (int d = 0; d < D; ++d)
94 for (int h = 0; h < H; ++h)
95 for (int w = 0; w < W; ++w) {
// Offset of the last block (NB_OC - 1) for the current (g, ic, d, h, w).
// NOTE(review): the statement that opens this conditional expression is not
// visible in this listing. blk_off is instantiated with !w_groups; presumably
// that makes it ignore the leading g argument for non-grouped formats --
// confirm against memory_desc_wrapper.
97 ? m_d.blk_off<!w_groups>(g, NB_OC - 1, ic, d, h, w)
98 : m_d.blk_off<!w_groups>(g, NB_OC - 1, ic, h, w) ];
// Walks the last oc_tail positions of the block.
// NOTE(review): the loop body is not visible here; presumably it stores zero
// at each position -- confirm.
99 for (int oc = blksize - oc_tail; oc < blksize; ++oc)
// Zero-padding of the input-channel tail for the oIhw8i and oIhw16i weights
// formats (blocked along dimension 1 only).
//
// m_d  - descriptor wrapper; dims(), blocking_desc().padding_dims and
//        blk_off() are read from it.
// data - base pointer of the weights storage.
104 template <data_type_t dt, memory_format_t fmt>
105 typename utils::enable_if<fmt == oIhw8i || fmt == oIhw16i>::type
106 typed_zero_pad_weights(const memory_desc_wrapper &m_d,
107 typename prec_traits<dt>::type *data) {
108 constexpr int blksize = fmt == oIhw8i ? 8 : 16;
// Both enabled formats are non-grouped and have no depth dimension. The
// constants are kept so the index arithmetic below has the same shape as in
// the sibling overloads.
110 constexpr int w_groups = 0;
111 constexpr int is_3d = 0;
113 const auto &dims = m_d.dims();
114 const auto &pdims = m_d.blocking_desc().padding_dims;
// With w_groups == 0 and is_3d == 0, G and D are always 1 here.
116 const int G = w_groups ? dims[0] : 1;
117 const int OC = dims[w_groups + 0];
// Number of blocks along the padded dimension 1.
118 const int NB_IC = pdims[w_groups + 1] / blksize;
119 const int D = is_3d ? dims[w_groups + 2] : 1;
120 const int H = dims[w_groups + 2 + is_3d];
121 const int W = dims[w_groups + 3 + is_3d];
// Number of padded entries at the end of the last dimension-1 block.
123 const int ic_tail = pdims[w_groups + 1] - dims[w_groups + 1];
125 # pragma omp parallel for collapse(5)
126 for (int g = 0; g < G; ++g)
127 for (int oc = 0; oc < OC; ++oc)
128 for (int d = 0; d < D; ++d)
129 for (int h = 0; h < H; ++h)
130 for (int w = 0; w < W; ++w) {
// Offset of the last dimension-1 block (NB_IC - 1) for the current
// (g, oc, d, h, w).
// NOTE(review): the statement that opens this conditional expression is not
// visible in this listing.
132 ? m_d.blk_off<!w_groups>(g, oc, NB_IC - 1, d, h, w)
133 : m_d.blk_off<!w_groups>(g, oc, NB_IC - 1, h, w) ];
// Walks the last ic_tail positions of the block.
// NOTE(review): the loop body is not visible here; presumably it stores zero
// at each position -- confirm.
134 for (int ic = blksize - ic_tail; ic < blksize; ++ic)
// Zero-padding for weights formats that are blocked along both of the first
// two non-group dimensions (OIhw16i16o, OIhw8i16o2i, OIhw4i16o4i, IOhw16o16i,
// ... and their grouped / 3D variants; see the enable_if list).
//
// m_d  - descriptor wrapper; dims(), blocking_desc().padding_dims and
//        blk_off() are read from it.
// data - base pointer of the weights storage.
//
// NOTE(review): blksize below is fixed at 16, yet the enable_if list also
// admits OIhw8i8o, OIhw8o8i, gOIhw8i8o and gOIhw8o8i, whose names suggest an
// 8x8 block. If so, NB_OC / NB_IC and the in-block indices are computed with
// the wrong block size for those formats -- confirm and fix.
139 template <data_type_t dt, memory_format_t fmt>
140 typename utils::enable_if<false
141 || fmt == IOhw16o16i || fmt == gIOhw16o16i
142 || fmt == OIdhw16i16o || fmt == OIdhw16o16i || fmt == OIhw8i8o
143 || fmt == OIhw16i16o || fmt == OIhw4i16o4i || fmt == OIhw8i16o2i
144 || fmt == OIdhw8i16o2i || fmt == OIhw8o16i2o || fmt == OIhw8o8i
145 || fmt == OIhw16o16i || fmt == gOIhw8i8o
146 || fmt == gOIhw16i16o || fmt == gOIhw4i16o4i || fmt == gOIhw8i16o2i
147 || fmt == gOIdhw8i16o2i || fmt == gOIhw8o16i2o || fmt == gOIhw8o8i
148 || fmt == gOIhw16o16i || fmt == gOIdhw16i16o || fmt == gOIdhw16o16i
149 >::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
150 typename prec_traits<dt>::type *data) {
151 using data_t = typename prec_traits<dt>::type;
// 1 when the format has a leading group dimension; used as an index shift.
// NOTE(review): gOidhw16o appears in this list but not in the enable_if list
// above, so that term can never be true in this overload.
152 constexpr int w_groups = false
153 || fmt == gOIhw8i8o || fmt == gOIhw16i16o || fmt == gOIhw4i16o4i
154 || fmt == gOIhw8i16o2i || fmt == gOIdhw8i16o2i || fmt == gOIhw8o16i2o
155 || fmt == gOIhw8o8i || fmt == gOIhw16o16i || fmt == gIOhw16o16i
156 || fmt == gOIdhw16i16o || fmt == gOIdhw16o16i || fmt == gOidhw16o
// 1 for the formats with an extra "d" dimension; shifts the H / W indices.
159 constexpr int is_3d = false
160 || fmt == OIdhw16i16o || fmt == OIdhw16o16i || fmt == OIdhw8i16o2i
161 || fmt == gOIdhw8i16o2i || fmt == gOIdhw16i16o || fmt == gOIdhw16o16i;
163 constexpr int blksize = 16;
165 const auto &dims = m_d.dims();
166 const auto &pdims = m_d.blocking_desc().padding_dims;
// Loop extents; NB_OC / NB_IC are block counts along the two padded blocked
// dimensions. G and D collapse to 1 without a group / depth dimension.
168 const int G = w_groups ? dims[0] : 1;
169 const int NB_OC = pdims[w_groups + 0] / blksize;
170 const int NB_IC = pdims[w_groups + 1] / blksize;
171 const int D = is_3d ? dims[w_groups + 2] : 1;
172 const int H = dims[w_groups + 2 + is_3d];
173 const int W = dims[w_groups + 3 + is_3d];
// Maps an (ic, oc) position inside one block to its linear offset within that
// block; the formula depends on the inner interleaving of the format.
// NOTE(review): the *8i8o formats match none of the explicit branches and
// therefore use the final `oc * blksize + ic` formula -- verify that this is
// the intended ordering for them.
175 auto index = [&](const int ic, const int oc) {
176 if (utils::one_of(fmt,
177 OIhw8i16o2i, gOIhw8i16o2i,
178 OIdhw8i16o2i, gOIdhw8i16o2i))
179 return ((ic / 2) * blksize * 2 + 2 * oc + ic % 2);
180 else if (utils::one_of(fmt, OIhw4i16o4i, gOIhw4i16o4i))
181 return ((ic / 4) * blksize * 4 + oc * 4 + ic % 4);
182 else if (utils::one_of(fmt, OIhw8o16i2o, gOIhw8o16i2o))
183 return ((oc / 2) * blksize * 2 + 2 * ic + oc % 2);
184 else if (utils::one_of(fmt,
185 OIhw16i16o, gOIhw16i16o,
186 OIdhw16i16o, gOIdhw16i16o))
187 return (ic * blksize + oc);
189 return (oc * blksize + ic);
// Clears the padded part of a single block pointed to by d:
//   - for oc below blksize - oc_tail, only the last ic_tail ic positions;
//   - for the remaining oc_tail oc values, every ic position.
// NOTE(review): the declaration of `oc` used by both loops is not visible in
// this listing.
192 auto ker = [&](data_t *d, const int oc_tail, const int ic_tail) {
194 for (; oc < blksize - oc_tail; ++oc) {
195 for (int ic = blksize - ic_tail; ic < blksize; ++ic)
196 d[index(ic, oc)] = 0;
198 for (; oc < blksize; ++oc)
199 for (int ic = 0; ic < blksize; ++ic)
200 d[index(ic, oc)] = 0;
// Number of padded entries at the end of the last block in each dimension.
203 const int oc_tail = pdims[w_groups + 0] - dims[w_groups + 0];
204 const int ic_tail = pdims[w_groups + 1] - dims[w_groups + 1];
// First sweep: for every block along the first blocked dimension, address
// the last block of the second one (NB_IC - 1).
// NOTE(review): the opening of the offset expression and the statement that
// uses the resulting pointer are not visible in this listing.
207 # pragma omp parallel for collapse(5)
208 for (int g = 0; g < G; ++g)
209 for (int nb_oc = 0; nb_oc < NB_OC; ++nb_oc)
210 for (int d = 0; d < D; ++d)
211 for (int h = 0; h < H; ++h)
212 for (int w = 0; w < W; ++w) {
214 ? m_d.blk_off<!w_groups>(g, nb_oc, NB_IC - 1, d, h, w)
215 : m_d.blk_off<!w_groups>(g, nb_oc, NB_IC - 1, h, w) ];
// Second sweep: for every block along the second blocked dimension, address
// the last block of the first one (NB_OC - 1).
// NOTE(review): as above, the surrounding statements are not visible here.
221 # pragma omp parallel for collapse(5)
222 for (int g = 0; g < G; ++g)
223 for (int nb_ic = 0; nb_ic < NB_IC; ++nb_ic)
224 for (int d = 0; d < D; ++d)
225 for (int h = 0; h < H; ++h)
226 for (int w = 0; w < W; ++w) {
228 ? m_d.blk_off<!w_groups>(g, NB_OC - 1, nb_ic, d, h, w)
229 : m_d.blk_off<!w_groups>(g, NB_OC - 1, nb_ic, h, w) ];
// Zeroes the tail of the last dim-0 (group) block for the Goihw8g and
// Goihw16g weights formats.
//
// m_d  - descriptor wrapper; dims(), ndims(), blocking_desc().padding_dims
//        and blk_off() are read from it.
// data - base pointer of the weights storage; modified in place.
//
// Precondition (asserted): dims[0] is not a multiple of the block size.
235 template <data_type_t dt, memory_format_t fmt>
236 typename utils::enable_if<fmt == Goihw8g || fmt == Goihw16g>::type
237 typed_zero_pad_weights(const memory_desc_wrapper &m_d,
238 typename prec_traits<dt>::type *data) {
239 constexpr int blksize = fmt == Goihw8g ? 8 : 16;
241 const auto &dims = m_d.dims();
242 const auto &pdims = m_d.blocking_desc().padding_dims;
// G is the index of the last block along dim 0 (from the padded size);
// g_tail_start is the first in-block position past the real data.
244 const int G = pdims[0] / blksize - 1;
245 const int g_tail_start = dims[0] % blksize;
246 assert(g_tail_start != 0);
// Product of every dimension except dim 0, cast to a signed type so it can
// bound the signed loop counter below.
247 const ptrdiff_t sz_rest
248 = (ptrdiff_t)utils::array_product(dims + 1, m_d.ndims() - 1);
// Start of the last dim-0 block.
250 auto *d = &data[m_d.blk_off(G)];
252 # pragma omp parallel for schedule(static)
// Each of the sz_rest positions is addressed blksize elements after the
// previous one; only entries [g_tail_start, blksize) are cleared.
253 for (ptrdiff_t s = 0; s < sz_rest; ++s) {
254 for (int g = g_tail_start; g < blksize; ++g)
255 d[s * blksize + g] = 0;
// Dispatches zero-padding for data type dt to the format-specific
// typed_zero_pad_data / typed_zero_pad_weights overload matching the memory
// format of this object (conf_).
//
// Returns success from the matching MAYBE_* branch, or unimplemented when the
// format matches none of the entries listed below.
259 template <data_type_t dt>
260 status_t cpu_memory_t::typed_zero_pad() {
261 const memory_desc_wrapper mpd(&conf_);
263 // FIXME: guard this check for non-blocked layout
// Compares the two element counts reported by nelems(false) / nelems(true).
// NOTE(review): presumably these are the counts without and with padding, so
// equality means there is nothing to zero; the statement guarded by this `if`
// is not visible in this listing -- confirm.
264 if (mpd.nelems(false) == mpd.nelems(true))
// Raw storage reinterpreted as the element type associated with dt.
267 auto *data = (typename prec_traits<dt>::type *)data_;
268 const auto fmt = mpd.format();
// MAYBE_DATA(f): if the runtime format equals f, run the data overload
// instantiated for f and return success.
271 # define MAYBE_DATA(f) if (fmt == f) \
272 { typed_zero_pad_data<dt, f>(mpd, data); return success; }
275 MAYBE_DATA(nCdhw16c);
// MAYBE_WEIGHTS(f): same pattern for the weights overloads. The template
// argument f selects the overload through each overload's enable_if list.
278 # define MAYBE_WEIGHTS(f) if (fmt == f) \
279 { typed_zero_pad_weights<dt, f>(mpd, data); return success; }
280 MAYBE_WEIGHTS(OIdhw16i16o);
281 MAYBE_WEIGHTS(OIdhw16o16i);
282 MAYBE_WEIGHTS(Oidhw16o);
283 MAYBE_WEIGHTS(Odhwi16o);
284 MAYBE_WEIGHTS(oIhw8i);
285 MAYBE_WEIGHTS(oIhw16i);
286 MAYBE_WEIGHTS(OIhw8i8o);
287 MAYBE_WEIGHTS(OIhw16i16o);
288 MAYBE_WEIGHTS(OIhw4i16o4i);
289 MAYBE_WEIGHTS(OIhw8i16o2i);
290 MAYBE_WEIGHTS(OIdhw8i16o2i);
291 MAYBE_WEIGHTS(OIhw8o16i2o);
292 MAYBE_WEIGHTS(OIhw8o8i);
293 MAYBE_WEIGHTS(OIhw16o16i);
294 MAYBE_WEIGHTS(IOhw16o16i);
295 MAYBE_WEIGHTS(Oihw16o);
296 MAYBE_WEIGHTS(Ohwi8o);
297 MAYBE_WEIGHTS(Ohwi16o);
298 MAYBE_WEIGHTS(gOIhw8i8o);
299 MAYBE_WEIGHTS(gOIhw16i16o);
300 MAYBE_WEIGHTS(gOIhw4i16o4i);
301 MAYBE_WEIGHTS(gOIhw8i16o2i);
302 MAYBE_WEIGHTS(gOIdhw8i16o2i);
303 MAYBE_WEIGHTS(gOIhw8o16i2o);
304 MAYBE_WEIGHTS(gOIhw8o8i);
305 MAYBE_WEIGHTS(gOIhw16o16i);
306 MAYBE_WEIGHTS(gIOhw16o16i);
307 MAYBE_WEIGHTS(gOihw16o);
308 MAYBE_WEIGHTS(gOhwi8o);
309 MAYBE_WEIGHTS(gOhwi16o);
310 MAYBE_WEIGHTS(gOIdhw16i16o);
311 MAYBE_WEIGHTS(gOIdhw16o16i);
312 MAYBE_WEIGHTS(gOidhw16o);
313 MAYBE_WEIGHTS(gOdhwi16o);
314 MAYBE_WEIGHTS(Goihw8g);
315 MAYBE_WEIGHTS(Goihw16g);
316 # undef MAYBE_WEIGHTS
// No listed format matched.
318 return unimplemented;
// Entry point for zero-padding this memory object: decides whether zeroing
// can be skipped and otherwise forwards to typed_zero_pad<dt>() for the data
// type reported by the descriptor.
//
// Returns success when zeroing is skipped, the status of typed_zero_pad<dt>()
// for the handled data types (f32, s32, s16, s8, u8), and unimplemented for
// any other data type.
321 status_t cpu_memory_t::zero_pad() {
322 memory_desc_wrapper md(&conf_);
// Zeroing is skipped when the descriptor is not a blocking descriptor.
// NOTE(review): further terms of this condition are not visible in this
// listing.
323 const bool skip_zeroing = false
326 || !md.is_blocking_desc();
327 if (skip_zeroing) return success;
// Runtime data type -> compile-time template argument.
329 switch (md.data_type()) {
330 case f32: return typed_zero_pad<f32>();
331 case s32: return typed_zero_pad<s32>();
332 case s16: return typed_zero_pad<s16>();
333 case s8: return typed_zero_pad<s8>();
334 case u8: return typed_zero_pad<u8>();
// Any other data type trips the assert where asserts are enabled, and
// otherwise reports unimplemented.
335 default: assert(!"memory is undefined"); return unimplemented;
// Every switch branch returns, so this trailing return is only a fallback.
337 return unimplemented;