Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_memory.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "memory_pd.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "type_helpers.hpp"
23 #include "utils.hpp"
24
25 #include "cpu_memory.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace mkldnn::impl;
32 using namespace mkldnn::impl::data_type;
33 using namespace mkldnn::impl::status;
34 using namespace mkldnn::impl::memory_format;
35
36 template <data_type_t dt, memory_format_t fmt>
37 typename utils::enable_if<fmt == nChw8c || fmt == nChw16c || fmt == nCdhw16c
38 >::type typed_zero_pad_data(const memory_desc_wrapper &m_d,
39         typename prec_traits<dt>::type *data) {
40     constexpr int blksize = fmt == nChw8c ? 8 : 16;
41
42     const auto &dims = m_d.dims();
43     const auto &pdims = m_d.blocking_desc().padding_dims;
44
45     const int C = pdims[1] / blksize - 1;
46     const int c_tail_start = dims[1] % blksize;
47     assert(c_tail_start != 0);
48     const size_t sp_rest = utils::array_product(dims + 3, m_d.ndims() - 3);
49
50 #   pragma omp parallel for collapse(2) schedule(static)
51     for (int n = 0; n < dims[0]; ++n)
52     for (int sp0 = 0; sp0 < dims[2]; ++sp0) {
53         auto *d = &data[m_d.blk_off(n, C, sp0)];
54         for (size_t sp = 0; sp < sp_rest; ++sp) {
55             for (int c = c_tail_start; c < blksize; ++c)
56                 d[sp * blksize + c] = 0;
57         }
58     }
59 }
60
61 template <data_type_t dt, memory_format_t fmt>
62 typename utils::enable_if<false
63 || fmt == Ohwi8o || fmt == Oihw16o || fmt == Ohwi16o || fmt == Oidhw16o
64 || fmt == Odhwi16o || fmt == gOhwi8o || fmt == gOihw16o || fmt == gOhwi16o
65 || fmt == gOidhw16o || fmt == gOdhwi16o
66 >::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
67         typename prec_traits<dt>::type *data) {
68     constexpr int w_groups = false
69         || fmt == gOhwi8o || fmt == gOihw16o || fmt == gOhwi16o
70         || fmt == gOidhw16o || fmt == gOdhwi16o;
71
72     constexpr int is_3d = false
73         || fmt == Oidhw16o || fmt == Odhwi16o
74         || fmt == gOidhw16o || fmt == gOdhwi16o;
75
76     constexpr int blksize = fmt == Ohwi8o || fmt == gOhwi8o ? 8 : 16;
77
78     const auto &dims = m_d.dims();
79     const auto &pdims = m_d.blocking_desc().padding_dims;
80
81     const int G = w_groups ? dims[0] : 1;
82     const int NB_OC = pdims[w_groups + 0] / blksize;
83     const int IC = dims[w_groups + 1];
84     const int D = is_3d ? dims[w_groups + 2] : 1;
85     const int H = dims[w_groups + 2 + is_3d];
86     const int W = dims[w_groups + 3 + is_3d];
87
88     const int oc_tail = pdims[w_groups + 0] - dims[w_groups + 0];
89
90 #   pragma omp parallel for collapse(5)
91     for (int g = 0; g < G; ++g)
92     for (int ic = 0; ic < IC; ++ic)
93     for (int d = 0; d < D; ++d)
94     for (int h = 0; h < H; ++h)
95     for (int w = 0; w < W; ++w) {
96         auto x = &data[is_3d
97             ? m_d.blk_off<!w_groups>(g, NB_OC - 1, ic, d, h, w)
98             : m_d.blk_off<!w_groups>(g, NB_OC - 1, ic, h, w) ];
99         for (int oc = blksize - oc_tail; oc < blksize; ++oc)
100             x[oc] = 0;
101     }
102 }
103
104 template <data_type_t dt, memory_format_t fmt>
105 typename utils::enable_if<fmt == oIhw8i || fmt == oIhw16i>::type
106 typed_zero_pad_weights(const memory_desc_wrapper &m_d,
107         typename prec_traits<dt>::type *data) {
108     constexpr int blksize = fmt == oIhw8i ? 8 : 16;
109
110     constexpr int w_groups = 0;
111     constexpr int is_3d = 0;
112
113     const auto &dims = m_d.dims();
114     const auto &pdims = m_d.blocking_desc().padding_dims;
115
116     const int G = w_groups ? dims[0] : 1;
117     const int OC = dims[w_groups + 0];
118     const int NB_IC = pdims[w_groups + 1] / blksize;
119     const int D = is_3d ? dims[w_groups + 2] : 1;
120     const int H = dims[w_groups + 2 + is_3d];
121     const int W = dims[w_groups + 3 + is_3d];
122
123     const int ic_tail = pdims[w_groups + 1] - dims[w_groups + 1];
124
125 #   pragma omp parallel for collapse(5)
126     for (int g = 0; g < G; ++g)
127     for (int oc = 0; oc < OC; ++oc)
128     for (int d = 0; d < D; ++d)
129     for (int h = 0; h < H; ++h)
130     for (int w = 0; w < W; ++w) {
131         auto x = &data[is_3d
132             ? m_d.blk_off<!w_groups>(g, oc, NB_IC - 1, d, h, w)
133             : m_d.blk_off<!w_groups>(g, oc, NB_IC - 1, h, w) ];
134         for (int ic = blksize - ic_tail; ic < blksize; ++ic)
135             x[ic] = 0;
136     }
137 }
138
139 template <data_type_t dt, memory_format_t fmt>
140 typename utils::enable_if<false
141 || fmt == IOhw16o16i || fmt == gIOhw16o16i
142 || fmt == OIdhw16i16o || fmt == OIdhw16o16i || fmt == OIhw8i8o
143 || fmt == OIhw16i16o || fmt == OIhw4i16o4i || fmt == OIhw8i16o2i
144 || fmt == OIdhw8i16o2i || fmt == OIhw8o16i2o || fmt == OIhw8o8i
145 || fmt == OIhw16o16i || fmt == gOIhw8i8o
146 || fmt == gOIhw16i16o || fmt == gOIhw4i16o4i || fmt == gOIhw8i16o2i
147 || fmt == gOIdhw8i16o2i || fmt == gOIhw8o16i2o || fmt == gOIhw8o8i
148 || fmt == gOIhw16o16i || fmt == gOIdhw16i16o || fmt == gOIdhw16o16i
149 >::type typed_zero_pad_weights(const memory_desc_wrapper &m_d,
150         typename prec_traits<dt>::type *data) {
151     using data_t = typename prec_traits<dt>::type;
152     constexpr int w_groups = false
153         || fmt == gOIhw8i8o || fmt == gOIhw16i16o || fmt == gOIhw4i16o4i
154         || fmt == gOIhw8i16o2i || fmt == gOIdhw8i16o2i || fmt == gOIhw8o16i2o
155         || fmt == gOIhw8o8i || fmt == gOIhw16o16i || fmt == gIOhw16o16i
156         || fmt == gOIdhw16i16o || fmt == gOIdhw16o16i || fmt == gOidhw16o
157         || fmt == gOdhwi16o;
158
159     constexpr int is_3d = false
160         || fmt == OIdhw16i16o || fmt == OIdhw16o16i || fmt == OIdhw8i16o2i
161         || fmt == gOIdhw8i16o2i || fmt == gOIdhw16i16o || fmt == gOIdhw16o16i;
162
163     constexpr int blksize = 16;
164
165     const auto &dims = m_d.dims();
166     const auto &pdims = m_d.blocking_desc().padding_dims;
167
168     const int G = w_groups ? dims[0] : 1;
169     const int NB_OC = pdims[w_groups + 0] / blksize;
170     const int NB_IC = pdims[w_groups + 1] / blksize;
171     const int D = is_3d ? dims[w_groups + 2] : 1;
172     const int H = dims[w_groups + 2 + is_3d];
173     const int W = dims[w_groups + 3 + is_3d];
174
175     auto index = [&](const int ic, const int oc) {
176         if (utils::one_of(fmt,
177                     OIhw8i16o2i, gOIhw8i16o2i,
178                     OIdhw8i16o2i, gOIdhw8i16o2i))
179             return ((ic / 2) * blksize * 2 + 2 * oc + ic % 2);
180         else if (utils::one_of(fmt, OIhw4i16o4i, gOIhw4i16o4i))
181             return ((ic / 4) * blksize * 4 + oc * 4 + ic % 4);
182         else if (utils::one_of(fmt, OIhw8o16i2o, gOIhw8o16i2o))
183             return ((oc / 2) * blksize * 2 + 2 * ic + oc % 2);
184         else if (utils::one_of(fmt,
185                     OIhw16i16o, gOIhw16i16o,
186                     OIdhw16i16o, gOIdhw16i16o))
187             return (ic * blksize + oc);
188         else
189             return (oc * blksize + ic);
190     };
191
192     auto ker = [&](data_t *d, const int oc_tail, const int ic_tail) {
193         int oc = 0;
194         for (; oc < blksize - oc_tail; ++oc) {
195             for (int ic = blksize - ic_tail; ic < blksize; ++ic)
196                 d[index(ic, oc)] = 0;
197         }
198         for (; oc < blksize; ++oc)
199             for (int ic = 0; ic < blksize; ++ic)
200                 d[index(ic, oc)] = 0;
201     };
202
203     const int oc_tail = pdims[w_groups + 0] - dims[w_groups + 0];
204     const int ic_tail = pdims[w_groups + 1] - dims[w_groups + 1];
205
206     if (ic_tail) {
207 #       pragma omp parallel for collapse(5)
208         for (int g = 0; g < G; ++g)
209         for (int nb_oc = 0; nb_oc < NB_OC; ++nb_oc)
210         for (int d = 0; d < D; ++d)
211         for (int h = 0; h < H; ++h)
212         for (int w = 0; w < W; ++w) {
213             auto x = &data[is_3d
214                 ? m_d.blk_off<!w_groups>(g, nb_oc, NB_IC - 1, d, h, w)
215                 : m_d.blk_off<!w_groups>(g, nb_oc, NB_IC - 1, h, w) ];
216             ker(x, 0, ic_tail);
217         }
218     }
219
220     if (oc_tail) {
221 #       pragma omp parallel for collapse(5)
222         for (int g = 0; g < G; ++g)
223         for (int nb_ic = 0; nb_ic < NB_IC; ++nb_ic)
224         for (int d = 0; d < D; ++d)
225         for (int h = 0; h < H; ++h)
226         for (int w = 0; w < W; ++w) {
227             auto x = &data[is_3d
228                 ? m_d.blk_off<!w_groups>(g, NB_OC - 1, nb_ic, d, h, w)
229                 : m_d.blk_off<!w_groups>(g, NB_OC - 1, nb_ic, h, w) ];
230             ker(x, oc_tail, 0);
231         }
232     }
233 }
234
235 template <data_type_t dt, memory_format_t fmt>
236 typename utils::enable_if<fmt == Goihw8g || fmt == Goihw16g>::type
237 typed_zero_pad_weights(const memory_desc_wrapper &m_d,
238         typename prec_traits<dt>::type *data) {
239     constexpr int blksize = fmt == Goihw8g ? 8 : 16;
240
241     const auto &dims = m_d.dims();
242     const auto &pdims = m_d.blocking_desc().padding_dims;
243
244     const int G = pdims[0] / blksize - 1;
245     const int g_tail_start = dims[0] % blksize;
246     assert(g_tail_start != 0);
247     const ptrdiff_t sz_rest
248         = (ptrdiff_t)utils::array_product(dims + 1, m_d.ndims() - 1);
249
250     auto *d = &data[m_d.blk_off(G)];
251
252 #   pragma omp parallel for schedule(static)
253     for (ptrdiff_t s = 0; s < sz_rest; ++s) {
254         for (int g = g_tail_start; g < blksize; ++g)
255             d[s * blksize + g] = 0;
256     }
257 }
258
259 template <data_type_t dt>
260 status_t cpu_memory_t::typed_zero_pad() {
261     const memory_desc_wrapper mpd(&conf_);
262
263     // FIXME: guard this check for non-blocked layout
264     if (mpd.nelems(false) == mpd.nelems(true))
265         return success;
266
267     auto *data = (typename prec_traits<dt>::type *)data_;
268     const auto fmt = mpd.format();
269
270     /* data */
271 #   define MAYBE_DATA(f) if (fmt == f) \
272     { typed_zero_pad_data<dt, f>(mpd, data); return success; }
273     MAYBE_DATA(nChw8c);
274     MAYBE_DATA(nChw16c);
275     MAYBE_DATA(nCdhw16c);
276
277     /* weights */
278 #   define MAYBE_WEIGHTS(f) if (fmt == f) \
279     { typed_zero_pad_weights<dt, f>(mpd, data); return success; }
280     MAYBE_WEIGHTS(OIdhw16i16o);
281     MAYBE_WEIGHTS(OIdhw16o16i);
282     MAYBE_WEIGHTS(Oidhw16o);
283     MAYBE_WEIGHTS(Odhwi16o);
284     MAYBE_WEIGHTS(oIhw8i);
285     MAYBE_WEIGHTS(oIhw16i);
286     MAYBE_WEIGHTS(OIhw8i8o);
287     MAYBE_WEIGHTS(OIhw16i16o);
288     MAYBE_WEIGHTS(OIhw4i16o4i);
289     MAYBE_WEIGHTS(OIhw8i16o2i);
290     MAYBE_WEIGHTS(OIdhw8i16o2i);
291     MAYBE_WEIGHTS(OIhw8o16i2o);
292     MAYBE_WEIGHTS(OIhw8o8i);
293     MAYBE_WEIGHTS(OIhw16o16i);
294     MAYBE_WEIGHTS(IOhw16o16i);
295     MAYBE_WEIGHTS(Oihw16o);
296     MAYBE_WEIGHTS(Ohwi8o);
297     MAYBE_WEIGHTS(Ohwi16o);
298     MAYBE_WEIGHTS(gOIhw8i8o);
299     MAYBE_WEIGHTS(gOIhw16i16o);
300     MAYBE_WEIGHTS(gOIhw4i16o4i);
301     MAYBE_WEIGHTS(gOIhw8i16o2i);
302     MAYBE_WEIGHTS(gOIdhw8i16o2i);
303     MAYBE_WEIGHTS(gOIhw8o16i2o);
304     MAYBE_WEIGHTS(gOIhw8o8i);
305     MAYBE_WEIGHTS(gOIhw16o16i);
306     MAYBE_WEIGHTS(gIOhw16o16i);
307     MAYBE_WEIGHTS(gOihw16o);
308     MAYBE_WEIGHTS(gOhwi8o);
309     MAYBE_WEIGHTS(gOhwi16o);
310     MAYBE_WEIGHTS(gOIdhw16i16o);
311     MAYBE_WEIGHTS(gOIdhw16o16i);
312     MAYBE_WEIGHTS(gOidhw16o);
313     MAYBE_WEIGHTS(gOdhwi16o);
314     MAYBE_WEIGHTS(Goihw8g);
315     MAYBE_WEIGHTS(Goihw16g);
316 #   undef MAYBE_WEIGHTS
317
318     return unimplemented;
319 }
320
321 status_t cpu_memory_t::zero_pad() {
322     memory_desc_wrapper md(&conf_);
323     const bool skip_zeroing = false
324         || data_ == nullptr
325         || md.is_zero()
326         || !md.is_blocking_desc();
327     if (skip_zeroing) return success;
328
329     switch (md.data_type()) {
330         case f32: return typed_zero_pad<f32>();
331         case s32: return typed_zero_pad<s32>();
332         case s16: return typed_zero_pad<s16>();
333         case s8: return typed_zero_pad<s8>();
334         case u8: return typed_zero_pad<u8>();
335         default: assert(!"memory is undefined"); return unimplemented;
336     }
337     return unimplemented;
338 }
339
340 }
341 }
342 }