updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_reorder_utils.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "c_types_map.hpp"
20 #include "memory_desc_wrapper.hpp"
21 #include "mkldnn_debug.h"
22 #include "nstl.hpp"
23 #include "type_helpers.hpp"
24 #include "utils.hpp"
25
26 #include "cpu_primitive.hpp"
27 #include "cpu_reorder_pd.hpp"
28 #include "jit_uni_reorder.hpp"
29
30 using namespace mkldnn::impl::types;
31 using namespace mkldnn::impl::status;
32
33 namespace mkldnn {
34 namespace impl {
35 namespace cpu {
36
37 namespace tr {
38
39 /** ad-hoc structure to describe blocked memory layout */
40 struct layout_desc_t {
41     data_type_t dt;
42     int ndims;
43     dims_t id;
44     dims_t dims;
45     strides_t strides;
46 };
47
48 status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
49         layout_desc_t &ld) {
50     using namespace mkldnn::impl::memory_format;
51     using namespace mkldnn::impl::data_type;
52
53     auto md = memory_desc_wrapper(md_);
54     auto bd = md.blocking_desc();
55
56     ld.ndims = 0;
57     ld.dt = md.data_type();
58
59     auto P = [&ld](int id, int dim, ptrdiff_t stride) {
60         assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
61         ld.id[ld.ndims] = id;
62         ld.dims[ld.ndims] = dim;
63         ld.strides[ld.ndims] = stride;
64         ++ld.ndims;
65     };
66
67     /* special cases */
68     switch (md.format()) {
69     case memory_format::undef:
70     case memory_format::any:
71     case hwio_s8s8:
72     case hwigo_s8s8:
73     case gOIhw4o4i_s8s8:
74     case gOIhw2i8o4i_s8s8:
75     case gOIw4i16o4i_s8s8:
76     case OIw4i16o4i_s8s8:
77     case gOIhw4i16o4i_s8s8:
78     case OIhw4i16o4i_s8s8:
79     case Goihw16g_s8s8:
80     case Goiw16g_s8s8:
81     case wino_fmt:
82         return invalid_arguments;
83     case OIw4i16o4i:
84     case OIhw4i16o4i:
85         P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
86         P(0, 16, 4);
87         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
88         P(1, 4, 16*4);
89         P(1, 4, 1);
90         P(2, bd.padding_dims[2], bd.strides[0][2]);
91         if (md.format() == OIhw4i16o4i)
92             P(3, bd.padding_dims[3], bd.strides[0][3]);
93         return success;
94     case OIw8i16o2i:
95     case OIhw8i16o2i:
96     case IOhw8i16o2i:
97     case OIdhw8i16o2i:
98         P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
99         P(0, 16, 2);
100         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
101         P(1, 8, 16*2);
102         P(1, 2, 1);
103         P(2, bd.padding_dims[2], bd.strides[0][2]);
104         if (utils::one_of(md.format(), OIhw8i16o2i, IOhw8i16o2i)
105                 || md.format() == OIdhw8i16o2i)
106             P(3, bd.padding_dims[3], bd.strides[0][3]);
107         if (md.format() == OIdhw8i16o2i)
108             P(4, bd.padding_dims[4], bd.strides[0][4]);
109         return success;
110     case OIw8o16i2o:
111     case IOw8o16i2o:
112     case OIhw8o16i2o:
113     case IOhw8o16i2o:
114     case OIdhw8o16i2o:
115     case IOdhw8o16i2o:
116         P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
117         P(0, 8, 16*2);
118         P(0, 2, 1);
119         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
120         P(1, 16, 2);
121         P(2, bd.padding_dims[2], bd.strides[0][2]);
122         if (utils::one_of(md.format(), OIhw8o16i2o, IOhw8o16i2o,
123                                        OIdhw8o16i2o, IOdhw8o16i2o))
124             P(3, bd.padding_dims[3], bd.strides[0][3]);
125         if (utils::one_of(md.format(), OIdhw8o16i2o, IOdhw8o16i2o))
126             P(4, bd.padding_dims[4], bd.strides[0][4]);
127         return success;
128     case gOIhw2i8o4i:
129         P(0, bd.padding_dims[0], bd.strides[0][0]);
130         P(1, bd.padding_dims[1] / 8, bd.strides[0][1]);
131         P(1, 8, 4);
132         P(2, bd.padding_dims[2] / 8, bd.strides[0][2]);
133         P(2, 2, 8*4);
134         P(2, 4, 1);
135         P(3, bd.padding_dims[3], bd.strides[0][3]);
136         P(4, bd.padding_dims[4], bd.strides[0][4]);
137         return success;
138     case gOIw4i16o4i:
139     case gOIhw4i16o4i:
140         P(0, bd.padding_dims[0], bd.strides[0][0]);
141         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
142         P(1, 16, 4);
143         P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
144         P(2, 4, 16*4);
145         P(2, 4, 1);
146         P(3, bd.padding_dims[3], bd.strides[0][3]);
147         if (md.format() == gOIhw4i16o4i)
148             P(4, bd.padding_dims[4], bd.strides[0][4]);
149         return success;
150     case gOIw8i16o2i:
151     case gOIhw8i16o2i:
152     case gIOhw8i16o2i:
153     case gOIdhw8i16o2i:
154         P(0, bd.padding_dims[0], bd.strides[0][0]);
155         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
156         P(1, 16, 2);
157         P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
158         P(2, 8, 16*2);
159         P(2, 2, 1);
160         P(3, bd.padding_dims[3], bd.strides[0][3]);
161         if (utils::one_of(md.format(), gOIhw8i16o2i, gIOhw8i16o2i)
162                 || md.format() == gOIdhw8i16o2i)
163             P(4, bd.padding_dims[4], bd.strides[0][4]);
164         if (md.format() == gOIdhw8i16o2i)
165             P(5, bd.padding_dims[5], bd.strides[0][5]);
166         return success;
167     case gOIw8o16i2o:
168     case gIOw8o16i2o:
169     case gOIhw8o16i2o:
170     case gIOhw8o16i2o:
171     case gOIdhw8o16i2o:
172     case gIOdhw8o16i2o:
173         P(0, bd.padding_dims[0], bd.strides[0][0]);
174         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
175         P(1, 8, 16*2);
176         P(1, 2, 1);
177         P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
178         P(2, 16, 2);
179         P(3, bd.padding_dims[3], bd.strides[0][3]);
180         if (utils::one_of(md.format(), gOIhw8o16i2o, gIOhw8o16i2o,
181                                        gOIdhw8o16i2o, gIOdhw8o16i2o))
182             P(4, bd.padding_dims[4], bd.strides[0][4]);
183         if (utils::one_of(md.format(), gOIdhw8o16i2o, gIOdhw8o16i2o))
184             P(5, bd.padding_dims[5], bd.strides[0][5]);
185         return success;
186     default: break;
187     }
188
189     /* regular blocked format */
190     for (int d = 0; d < md.ndims(); ++d) {
191         P(d, bd.padding_dims[d] / bd.block_dims[d], bd.strides[0][d]);
192         if (bd.block_dims[d] != 1)
193             P(d, bd.block_dims[d], bd.strides[1][d]);
194     }
195
196     return success;
197 }
198
199 status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
200         const primitive_attr_t *attr) {
201     auto im_d = memory_desc_wrapper(imd);
202     auto om_d = memory_desc_wrapper(omd);
203
204     bool ok = true
205         && im_d.is_blocking_desc()
206         && om_d.is_blocking_desc()
207         && !im_d.has_zero_dim()
208         && !om_d.has_zero_dim();
209     if (!ok)
210         return unimplemented;
211
212     /* padding_dim consistency check */
213     for (int d = 0; d < im_d.ndims(); ++d) {
214         const auto pdim = im_d.blocking_desc().padding_dims[d];
215         bool ok = true
216             && pdim == om_d.blocking_desc().padding_dims[d]
217             && pdim % im_d.blocking_desc().block_dims[d] == 0
218             && pdim % om_d.blocking_desc().block_dims[d] == 0;
219             if (!ok) return unimplemented;
220     }
221
222     layout_desc_t ild, old;
223     status_t status = cvt_mem_desc_to_layout_desc(imd, ild);
224     if (status != success) return status;
225     status = cvt_mem_desc_to_layout_desc(omd, old);
226     if (status != success) return status;
227
228     p.itype = ild.dt;
229     p.otype = old.dt;
230
231     p.scale_type = attr->output_scales_.has_default_values()
232         ? scale_type_t::NONE
233         : (attr->output_scales_.mask_ == 0
234                 ? scale_type_t::COMMON
235                 : scale_type_t::MANY);
236
237     ptrdiff_t ss[max_ndims] = {0};
238     if (p.scale_type == scale_type_t::MANY) {
239         ptrdiff_t last_ss = 1;
240         for (int d = old.ndims - 1; d >=0; --d) {
241             assert((d == 0 || old.id[d - 1] <= old.id[d])
242                     && "logical dimensions should be in ascending order");
243             if (attr->output_scales_.mask_ & (1 << old.id[d])) {
244                 ss[d] = last_ss;
245                 last_ss *= old.dims[d];
246             }
247         }
248     }
249
250     int ndims = 0;
251
252     int i_pos = 0; /* state for input  -- current dimension */
253     int o_pos = 0; /* state for output -- current dimension */
254
255     while (i_pos < ild.ndims && o_pos < old.ndims) {
256         assert(ild.id[i_pos] == old.id[o_pos]);
257         if (ild.id[i_pos] != old.id[o_pos])
258             return runtime_error;
259
260         assert(ndims < max_ndims);
261         if (ndims == max_ndims)
262             return runtime_error;
263
264         if (ild.dims[i_pos] == old.dims[o_pos]) {
265             p.nodes[ndims].n = ild.dims[i_pos];
266             p.nodes[ndims].is = ild.strides[i_pos];
267             p.nodes[ndims].os = old.strides[o_pos];
268             p.nodes[ndims].ss = ss[o_pos];
269             ++ndims;
270             ++i_pos;
271             ++o_pos;
272         } else if (ild.dims[i_pos] < old.dims[o_pos]) {
273             assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
274             int factor = old.dims[o_pos] / ild.dims[i_pos];
275             p.nodes[ndims].n = ild.dims[i_pos];
276             p.nodes[ndims].is = ild.strides[i_pos];
277             p.nodes[ndims].os = old.strides[o_pos] * factor;
278             p.nodes[ndims].ss = ss[o_pos] * factor;
279             ++ndims;
280             ++i_pos;
281             old.dims[o_pos] = factor;
282         } else if (ild.dims[i_pos] > old.dims[o_pos]) {
283             assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
284             int factor = ild.dims[i_pos] / old.dims[o_pos];
285             p.nodes[ndims].n = old.dims[o_pos];
286             p.nodes[ndims].is = ild.strides[i_pos] * factor;
287             p.nodes[ndims].os = old.strides[o_pos];
288             p.nodes[ndims].ss = ss[o_pos];
289             ++ndims;
290             ++o_pos;
291             ild.dims[i_pos] = factor;
292         }
293     }
294     p.ndims = ndims;
295
296     dims_t zero_pos = {0};
297     p.ioff = memory_desc_wrapper(imd).off_v(zero_pos);
298     p.ooff = memory_desc_wrapper(omd).off_v(zero_pos);
299
300     const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
301     p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;
302
303     return success;
304 }
305
306 void prb_normalize(prb_t &p) {
307     for (int d = 0; d < p.ndims; ++d) {
308         int min_pos = d;
309         for (int j = d + 1; j < p.ndims; ++j) {
310             bool new_min = false
311                 || p.nodes[j].os < p.nodes[min_pos].os
312                 || (true
313                         && p.nodes[j].os == p.nodes[min_pos].os
314                         && p.nodes[j].n < p.nodes[min_pos].n);
315             if (new_min) min_pos = j;
316         }
317         if (min_pos != d)
318             nstl::swap(p.nodes[d], p.nodes[min_pos]);
319     }
320 }
321
322 void prb_simplify(prb_t &p) {
323 #if defined(__GNUC__) && __GNUC__ >= 4
324 /* GCC produces bogus array subscript is above array bounds warning for
325  * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
326 #pragma GCC diagnostic push
327 #pragma GCC diagnostic ignored "-Warray-bounds"
328 #endif
329     for (int d = 0; d < p.ndims - 1; ++d) {
330         auto &this_node = p.nodes[d + 0];
331         auto &next_node = p.nodes[d + 1];
332         const bool fold = false
333             || next_node.n == (size_t)1 // trivial case, just drop next node
334             || (true // or real folding if possible
335                     && next_node.is == (ptrdiff_t)this_node.n * this_node.is
336                     && next_node.os == (ptrdiff_t)this_node.n * this_node.os
337                     && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss);
338         if (fold) {
339             this_node.n *= next_node.n;
340             for (int j = d + 2; j < p.ndims; ++j)
341                 p.nodes[j - 1] = p.nodes[j];
342             --p.ndims;
343             --d; // make another try
344         }
345     }
346 #if defined(__GNUC__) && __GNUC__ >= 4
347 #pragma GCC diagnostic pop
348 #endif
349 }
350
351 void prb_node_split(prb_t &p, int dim, size_t n1) {
352     assert(dim < p.ndims);
353     assert(p.ndims < max_ndims);
354     assert(p.nodes[dim].n % n1 == 0);
355
356     p.ndims += 1;
357
358     for (int d = p.ndims; d > dim + 1; --d)
359         p.nodes[d] = p.nodes[d - 1];
360
361     p.nodes[dim + 1].n = p.nodes[dim].n / n1;
362     p.nodes[dim + 1].is = p.nodes[dim].is * n1;
363     p.nodes[dim + 1].os = p.nodes[dim].os * n1;
364     p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
365
366     p.nodes[dim].n = n1;
367 }
368
369 void prb_node_swap(prb_t &p, int d0, int d1) {
370     assert(d0 < p.ndims);
371     assert(d1 < p.ndims);
372     assert(p.ndims < max_ndims);
373
374     if (d0 == d1) return;
375
376     nstl::swap(p.nodes[d0], p.nodes[d1]);
377 }
378
379 void prb_node_move(prb_t &p, int d0, int d1) {
380     assert(d0 < p.ndims);
381     assert(d1 < p.ndims);
382     assert(p.ndims < max_ndims);
383
384     if (d0 == d1) return;
385
386     node_t node = p.nodes[d0];
387
388     if (d0 < d1)
389         for (int d = d0; d < d1; ++d)
390             p.nodes[d] = p.nodes[d + 1];
391     else
392         for (int d = d0; d > d1; --d)
393             p.nodes[d] = p.nodes[d - 1];
394
395     p.nodes[d1] = node;
396 }
397
398 void prb_dump(const prb_t &p) {
399     printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype),
400             mkldnn_dt2str(p.otype), p.ndims);
401     for (int d = 0; d < p.ndims; ++d)
402         printf("[%zu:%td:%td:%td]",
403                 p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss);
404     printf(" off:%zu:%zu\n", p.ioff, p.ooff);
405 }
406
407 }
408
409 }
410 }
411 }