Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_memory.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "memory_pd.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "type_helpers.hpp"
23 #include "utils.hpp"
24
25 #include "format_traits.hpp"
26
27 #include "cpu_memory.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl;
34 using namespace mkldnn::impl::data_type;
35 using namespace mkldnn::impl::status;
36 using namespace mkldnn::impl::memory_format;
37
38 using dk = data_kind_t;
39 using bf = block_format_t;
40
41 template <data_type_t dt, memory_format_t fmt>
42 typename utils::enable_if<format_traits<fmt>::data_kind == dk::data>::type
43 typed_zero_pad_data(
44     const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
45     constexpr int blksize = format_traits<fmt>::blk_size;
46
47     const auto &dims = m_d.dims();
48     const auto &pdims = m_d.blocking_desc().padding_dims;
49
50     const int C = pdims[1] / blksize - 1;
51     const int c_tail_start = dims[1] % blksize;
52     assert(c_tail_start != 0);
53     const size_t sp_rest = utils::array_product(dims + 3, m_d.ndims() - 3);
54
55     parallel_nd(dims[0], dims[2], [&](int n, int sp0) {
56         auto *d = &data[m_d.blk_off(n, C, sp0)];
57         for (size_t sp = 0; sp < sp_rest; ++sp) {
58             for (int c = c_tail_start; c < blksize; ++c)
59                 d[sp * blksize + c] = 0;
60         }
61     });
62 }
63
64 template <data_type_t dt, memory_format_t fmt>
65 typename utils::enable_if<false
66 || format_traits<fmt>::blk_fmt == bf::_4o
67 || format_traits<fmt>::blk_fmt == bf::_8o
68 || format_traits<fmt>::blk_fmt == bf::_16o
69 >::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
70         typename prec_traits<dt>::type *data) {
71     static constexpr int w_groups = format_traits<fmt>::data_kind == dk::gwei;
72     constexpr int is_1d = format_traits<fmt>::ndims_sp == 1;
73     constexpr int is_3d = format_traits<fmt>::ndims_sp == 3;
74     constexpr int blksize = format_traits<fmt>::blk_size;
75
76     const auto &dims = m_d.dims();
77     const auto &pdims = m_d.blocking_desc().padding_dims;
78
79     const int G = w_groups ? dims[0] : 1;
80     const int NB_OC = pdims[w_groups + 0] / blksize;
81     const int IC = dims[w_groups + 1];
82     const int D = is_3d ? dims[w_groups + 2] : 1;
83     const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
84     const int W = dims[w_groups + 3 - is_1d + is_3d];
85
86     const int oc_tail = pdims[w_groups + 0] - dims[w_groups + 0];
87
88     parallel_nd(G, IC, D, H, W,
89         [&](int g, int ic, int d, int h, int w) {
90         auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
91                 g, NB_OC - 1, ic, d, h, w)];
92         for (int oc = blksize - oc_tail; oc < blksize; ++oc)
93             x[oc] = 0;
94     });
95 }
96
97 template <data_type_t dt, memory_format_t fmt>
98 typename utils::enable_if<false
99 || format_traits<fmt>::blk_fmt == bf::_8i
100 || format_traits<fmt>::blk_fmt == bf::_16i
101 >::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
102         typename prec_traits<dt>::type *data) {
103     static constexpr int w_groups = format_traits<fmt>::data_kind == dk::gwei;
104     constexpr int is_1d = format_traits<fmt>::ndims_sp == 1;
105     constexpr int is_3d = format_traits<fmt>::ndims_sp == 3;
106     constexpr int blksize = format_traits<fmt>::blk_size;
107
108     const auto &dims = m_d.dims();
109     const auto &pdims = m_d.blocking_desc().padding_dims;
110
111     const int G = w_groups ? dims[0] : 1;
112     const int OC = dims[w_groups + 0];
113     const int NB_IC = pdims[w_groups + 1] / blksize;
114     const int D = is_3d ? dims[w_groups + 2] : 1;
115     const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
116     const int W = dims[w_groups + 3 + is_3d];
117
118     const int ic_tail = pdims[w_groups + 1] - dims[w_groups + 1];
119
120     parallel_nd(G, OC, D, H, W,
121         [&](int g, int oc, int d, int h, int w) {
122         auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
123                 g, oc, NB_IC - 1, d, h, w)];
124         for (int ic = blksize - ic_tail; ic < blksize; ++ic)
125             x[ic] = 0;
126     });
127 }
128
129 template <data_type_t dt, memory_format_t fmt>
130 typename utils::enable_if<
131 block_format_traits<format_traits<fmt>::blk_fmt>::blk_ndims == 2>::type
132 typed_zero_pad_weights(const memory_desc_wrapper &m_d,
133         typename prec_traits<dt>::type *data) {
134     using data_t = typename prec_traits<dt>::type;
135     static constexpr int w_groups = format_traits<fmt>::data_kind == dk::gwei;
136     constexpr int is_1d = format_traits<fmt>::ndims_sp == 1;
137     constexpr int is_3d = format_traits<fmt>::ndims_sp == 3;
138     constexpr int blksize = format_traits<fmt>::blk_size;
139     const auto &dims = m_d.dims();
140     const auto &pdims = m_d.blocking_desc().padding_dims;
141
142     const int G = w_groups ? dims[0] : 1;
143     const int NB_OC = pdims[w_groups + 0] / blksize;
144     const int NB_IC = pdims[w_groups + 1] / blksize;
145     const int D = is_3d ? dims[w_groups + 2] : 1;
146     const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
147     const int W = dims[w_groups + 3 - is_1d + is_3d];
148
149     auto ker = [&](data_t *d, const int oc_tail, const int ic_tail) {
150 #       define blk_off OI_blk_off<format_traits<fmt>::blk_fmt>
151         int oc = 0;
152         for (; oc < blksize - oc_tail; ++oc) {
153             for (int ic = blksize - ic_tail; ic < blksize; ++ic)
154                 d[blk_off(oc, ic)] = 0;
155         }
156         for (; oc < blksize; ++oc)
157             for (int ic = 0; ic < blksize; ++ic)
158                 d[blk_off(oc, ic)] = 0;
159 #       undef blk_off
160     };
161
162     const int oc_tail = pdims[w_groups + 0] - dims[w_groups + 0];
163     const int ic_tail = pdims[w_groups + 1] - dims[w_groups + 1];
164
165     if (ic_tail) {
166         parallel_nd(G, NB_OC, D, H, W,
167             [&](int g, int nb_oc, int d, int h, int w) {
168             auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
169                     g, nb_oc, NB_IC - 1, d, h, w)];
170             ker(x, 0, ic_tail);
171         });
172     }
173
174     if (oc_tail) {
175         parallel_nd(G, NB_IC, D, H, W,
176             [&](int g, int nb_ic, int d, int h, int w) {
177             auto x = &data[wei_blk_off_like_gwei3D<fmt>(m_d,
178                     g, NB_OC - 1, nb_ic, d, h, w)];
179             ker(x, oc_tail, 0);
180         });
181     }
182 }
183
184 template <data_type_t dt, memory_format_t fmt>
185 typename utils::enable_if<false
186 || format_traits<fmt>::blk_fmt == bf::_8g
187 || format_traits<fmt>::blk_fmt == bf::_16g
188 >::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
189         typename prec_traits<dt>::type *data) {
190     constexpr int blksize = format_traits<fmt>::blk_size;
191
192     const auto &dims = m_d.dims();
193     const auto &pdims = m_d.blocking_desc().padding_dims;
194
195     const int G = pdims[0] / blksize - 1;
196     const int g_tail_start = dims[0] % blksize;
197     assert(g_tail_start != 0);
198     const ptrdiff_t sz_rest
199         = (ptrdiff_t)utils::array_product(dims + 1, m_d.ndims() - 1);
200
201     auto *d = &data[m_d.blk_off(G)];
202
203     parallel_nd(sz_rest, [&](ptrdiff_t s) {
204         for (int g = g_tail_start; g < blksize; ++g)
205             d[s * blksize + g] = 0;
206     });
207 }
208
209 template <data_type_t dt>
210 void typed_zero_pad_generic_blocked(const memory_desc_wrapper &m_d,
211         typename prec_traits<dt>::type *data) {
212     const int ndims = m_d.ndims();
213     const auto &dims = m_d.dims();
214     const auto &pdims = m_d.blocking_desc().padding_dims;
215
216     const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true);
217
218     /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1]
219      *            |  \                     /
220      *            |   ---------------------
221      *           has        contiguous
222      *         padding
223      *
224      * step     <-- D_k+1 * ... * D_ndims-1
225      * step_dim <-- k
226      */
227
228     ptrdiff_t step = 1;
229     int step_dim = ndims - 1;
230     for (; step_dim >= 0; --step_dim) {
231         if (dims[step_dim] != pdims[step_dim]) break;
232         step *= dims[step_dim];
233     }
234
235     assert(step_dim >= 0 && "no zero padding is required");
236     if (step_dim < 0) return;
237
238     parallel_nd(nelems / step, [&](ptrdiff_t e1) {
239         bool need_zero = false;
240
241         ptrdiff_t idx = e1;
242         for (int d = step_dim; d >= 0; --d) {
243             if (idx % pdims[d] >= dims[d]) {
244                 need_zero = true;
245                 break;
246             }
247             idx /= pdims[d];
248         }
249
250         if (need_zero) {
251             for (ptrdiff_t e0 = 0; e0 < step; ++e0)
252                 data[m_d.off_l(e1 * step + e0, true)] = 0;
253         }
254     });
255 }
256
257 template <data_type_t dt>
258 status_t cpu_memory_t::typed_zero_pad() const {
259     const memory_desc_wrapper mpd(pd());
260
261     // FIXME: guard this check for non-blocked layout
262     if (mpd.nelems(false) == mpd.nelems(true))
263         return success;
264
265     auto *data = (typename prec_traits<dt>::type *)data_;
266     const auto fmt = mpd.format();
267
268     /* data */
269 #   define MAYBE_DATA(f) if (fmt == f) \
270     { typed_zero_pad_data<dt, f>(mpd, data); return success; }
271     MAYBE_DATA(nCw4c);
272     MAYBE_DATA(nCw8c);
273     MAYBE_DATA(nCw16c);
274     MAYBE_DATA(nChw4c);
275     MAYBE_DATA(nChw8c);
276     MAYBE_DATA(nCdhw4c);
277     MAYBE_DATA(nCdhw8c);
278     MAYBE_DATA(nChw16c);
279     MAYBE_DATA(nCdhw16c);
280
281     /* weights */
282 #   define MAYBE_WEIGHTS(f) if (fmt == f) \
283     { typed_zero_pad_weights<dt, f>(mpd, data); return success; }
284     MAYBE_WEIGHTS(OIdhw4i4o);
285     MAYBE_WEIGHTS(OIdhw8i8o);
286     MAYBE_WEIGHTS(OIdhw8o8i);
287     MAYBE_WEIGHTS(OIdhw16i16o);
288     MAYBE_WEIGHTS(OIdhw16o16i);
289     MAYBE_WEIGHTS(Oidhw4o);
290     MAYBE_WEIGHTS(Oidhw16o);
291     MAYBE_WEIGHTS(Odhwi16o);
292     MAYBE_WEIGHTS(Odhwi8o);
293     MAYBE_WEIGHTS(oIhw8i);
294     MAYBE_WEIGHTS(oIhw16i);
295     MAYBE_WEIGHTS(oIdhw8i);
296     MAYBE_WEIGHTS(oIdhw16i);
297     MAYBE_WEIGHTS(OIhw4i4o);
298     MAYBE_WEIGHTS(OIhw8i8o);
299     MAYBE_WEIGHTS(OIhw16i16o);
300     MAYBE_WEIGHTS(OIhw4i16o4i);
301     MAYBE_WEIGHTS(OIhw4i16o4i_s8s8);
302     MAYBE_WEIGHTS(OIw4i4o);
303     MAYBE_WEIGHTS(Owi8o);
304     MAYBE_WEIGHTS(OIw8i8o);
305     MAYBE_WEIGHTS(OIw8o8i);
306     MAYBE_WEIGHTS(OIw16i16o);
307     MAYBE_WEIGHTS(OIw16o16i);
308     MAYBE_WEIGHTS(Oiw4o);
309     MAYBE_WEIGHTS(Oiw16o);
310     MAYBE_WEIGHTS(Owi16o);
311     MAYBE_WEIGHTS(OIw8i16o2i);
312     MAYBE_WEIGHTS(OIw8o16i2o);
313     MAYBE_WEIGHTS(IOw16o16i);
314     MAYBE_WEIGHTS(OIhw8i16o2i);
315     MAYBE_WEIGHTS(OIdhw8i16o2i);
316     MAYBE_WEIGHTS(OIhw8o16i2o);
317     MAYBE_WEIGHTS(OIhw8o8i);
318     MAYBE_WEIGHTS(OIhw16o16i);
319     MAYBE_WEIGHTS(IOhw16o16i);
320     MAYBE_WEIGHTS(Oihw4o);
321     MAYBE_WEIGHTS(Oihw16o);
322     MAYBE_WEIGHTS(Ohwi8o);
323     MAYBE_WEIGHTS(Ohwi4o);
324     MAYBE_WEIGHTS(Ohwi16o);
325     MAYBE_WEIGHTS(gOIhw4o4i_s8s8);
326     MAYBE_WEIGHTS(gOIhw4o4i_s8s8);
327     MAYBE_WEIGHTS(gOIhw4i4o);
328     MAYBE_WEIGHTS(gOIhw8i8o);
329     MAYBE_WEIGHTS(gOIhw16i16o);
330     MAYBE_WEIGHTS(gOIhw4i16o4i);
331     MAYBE_WEIGHTS(gOIhw4i16o4i_s8s8);
332     MAYBE_WEIGHTS(gOIhw2i8o4i);
333     MAYBE_WEIGHTS(gOIhw2i8o4i_s8s8);
334     MAYBE_WEIGHTS(gOIw4i4o);
335     MAYBE_WEIGHTS(gOwi8o);
336     MAYBE_WEIGHTS(gOIw8i8o);
337     MAYBE_WEIGHTS(gOIw8o8i);
338     MAYBE_WEIGHTS(gOIw16i16o);
339     MAYBE_WEIGHTS(gOIw16o16i);
340     MAYBE_WEIGHTS(gOiw4o);
341     MAYBE_WEIGHTS(gOiw16o);
342     MAYBE_WEIGHTS(gOwi16o);
343     MAYBE_WEIGHTS(gOIw8i16o2i);
344     MAYBE_WEIGHTS(gOIw8o16i2o);
345     MAYBE_WEIGHTS(gIOw16o16i);
346     MAYBE_WEIGHTS(gOIhw8i16o2i);
347     MAYBE_WEIGHTS(gOIdhw8i16o2i);
348     MAYBE_WEIGHTS(gOIhw8o16i2o);
349     MAYBE_WEIGHTS(gOIhw8o8i);
350     MAYBE_WEIGHTS(gOIhw16o16i);
351     MAYBE_WEIGHTS(gIOhw16o16i);
352     MAYBE_WEIGHTS(gOihw4o);
353     MAYBE_WEIGHTS(gOihw16o);
354     MAYBE_WEIGHTS(gOhwi8o);
355     MAYBE_WEIGHTS(gOhwi4o);
356     MAYBE_WEIGHTS(gOhwi16o);
357     MAYBE_WEIGHTS(gOIdhw4i4o);
358     MAYBE_WEIGHTS(gOIdhw8i8o);
359     MAYBE_WEIGHTS(gOIdhw8o8i);
360     MAYBE_WEIGHTS(gOIdhw16i16o);
361     MAYBE_WEIGHTS(gOIdhw16o16i);
362     MAYBE_WEIGHTS(gOidhw4o);
363     MAYBE_WEIGHTS(gOidhw16o);
364     MAYBE_WEIGHTS(gOdhwi16o);
365     MAYBE_WEIGHTS(gOdhwi8o);
366     MAYBE_WEIGHTS(Goihw8g);
367     MAYBE_WEIGHTS(Goihw16g);
368 #   undef MAYBE_WEIGHTS
369
370     // the last line of defence
371     if (types::format_normalize(fmt) == blocked) {
372         typed_zero_pad_generic_blocked<dt>(mpd, data);
373         return success;
374     }
375
376     return unimplemented;
377 }
378
379 status_t cpu_memory_t::zero_pad() const {
380     memory_desc_wrapper md(pd());
381     const bool skip_zeroing = false
382         || data_ == nullptr
383         || md.is_zero()
384         || !md.is_blocking_desc();
385     if (skip_zeroing) return success;
386
387     switch (md.data_type()) {
388         case f32: return typed_zero_pad<f32>();
389         case s32: return typed_zero_pad<s32>();
390         case s16: return typed_zero_pad<s16>();
391         case s8: return typed_zero_pad<s8>();
392         case u8: return typed_zero_pad<u8>();
393         case bin: return typed_zero_pad<u8>();
394         default: assert(!"memory is undefined"); return unimplemented;
395     }
396     return unimplemented;
397 }
398
399 }
400 }
401 }