Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_reorder_utils.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "c_types_map.hpp"
20 #include "memory_desc_wrapper.hpp"
21 #include "mkldnn_debug.h"
22 #include "nstl.hpp"
23 #include "type_helpers.hpp"
24 #include "utils.hpp"
25
26 #include "cpu_primitive.hpp"
27 #include "cpu_reorder_pd.hpp"
28 #include "jit_uni_reorder.hpp"
29
30 using namespace mkldnn::impl::types;
31 using namespace mkldnn::impl::status;
32
33 namespace mkldnn {
34 namespace impl {
35 namespace cpu {
36
37 namespace tr {
38
39 /** ad-hoc structure to describe blocked memory layout */
40 struct layout_desc_t {
41     data_type_t dt;
42     int ndims;
43     dims_t id;
44     dims_t dims;
45     strides_t strides;
46 };
47
48 status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
49         layout_desc_t &ld) {
50     using namespace mkldnn::impl::memory_format;
51
52     auto md = memory_desc_wrapper(md_);
53     auto bd = md.blocking_desc();
54
55     ld.ndims = 0;
56     ld.dt = md.data_type();
57
58     auto P = [&ld](int id, int dim, ptrdiff_t stride) {
59         assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
60         ld.id[ld.ndims] = id;
61         ld.dims[ld.ndims] = dim;
62         ld.strides[ld.ndims] = stride;
63         ++ld.ndims;
64     };
65
66     /* special cases */
67     switch (md.format()) {
68     case memory_format::undef:
69     case memory_format::any:
70     case hwio_s8s8:
71     case hwigo_s8s8:
72     case gOIhw4o4i_s8s8:
73     case gOIhw2i8o4i_s8s8:
74     case gOIhw4i16o4i_s8s8:
75     case OIhw4i16o4i_s8s8:
76     case Goihw16g_s8s8:
77     case wino_fmt:
78         return invalid_arguments;
79     case OIhw4i16o4i:
80         P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
81         P(0, 16, 4);
82         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
83         P(1, 4, 16*4);
84         P(1, 4, 1);
85         P(2, bd.padding_dims[2], bd.strides[0][2]);
86         P(3, bd.padding_dims[3], bd.strides[0][3]);
87         return success;
88     case OIw8i16o2i:
89     case OIhw8i16o2i:
90     case OIdhw8i16o2i:
91         P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
92         P(0, 16, 2);
93         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
94         P(1, 8, 16*2);
95         P(1, 2, 1);
96         P(2, bd.padding_dims[2], bd.strides[0][2]);
97         if (md.format() == OIhw8i16o2i || md.format() == OIdhw8i16o2i)
98             P(3, bd.padding_dims[3], bd.strides[0][3]);
99         if (md.format() == OIdhw8i16o2i)
100             P(4, bd.padding_dims[4], bd.strides[0][4]);
101         return success;
102     case OIw8o16i2o:
103     case OIhw8o16i2o:
104         P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
105         P(0, 8, 16*2);
106         P(0, 2, 1);
107         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
108         P(1, 16, 2);
109         P(2, bd.padding_dims[2], bd.strides[0][2]);
110         if (md.format() == OIhw8o16i2o)
111             P(3, bd.padding_dims[3], bd.strides[0][3]);
112         return success;
113     case gOIhw2i8o4i:
114         P(0, bd.padding_dims[0], bd.strides[0][0]);
115         P(1, bd.padding_dims[1] / 8, bd.strides[0][1]);
116         P(1, 8, 4);
117         P(2, bd.padding_dims[2] / 8, bd.strides[0][2]);
118         P(2, 2, 8*4);
119         P(2, 4, 1);
120         P(3, bd.padding_dims[3], bd.strides[0][3]);
121         P(4, bd.padding_dims[4], bd.strides[0][4]);
122         return success;
123     case gOIhw4i16o4i:
124         P(0, bd.padding_dims[0], bd.strides[0][0]);
125         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
126         P(1, 16, 4);
127         P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
128         P(2, 4, 16*4);
129         P(2, 4, 1);
130         P(3, bd.padding_dims[3], bd.strides[0][3]);
131         P(4, bd.padding_dims[4], bd.strides[0][4]);
132         return success;
133     case gOIw8i16o2i:
134     case gOIhw8i16o2i:
135     case gOIdhw8i16o2i:
136         P(0, bd.padding_dims[0], bd.strides[0][0]);
137         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
138         P(1, 16, 2);
139         P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
140         P(2, 8, 16*2);
141         P(2, 2, 1);
142         P(3, bd.padding_dims[3], bd.strides[0][3]);
143         if (md.format() == gOIhw8i16o2i || md.format() == gOIdhw8i16o2i)
144             P(4, bd.padding_dims[4], bd.strides[0][4]);
145         if (md.format() == gOIdhw8i16o2i)
146             P(5, bd.padding_dims[5], bd.strides[0][5]);
147         return success;
148     case gOIw8o16i2o:
149     case gOIhw8o16i2o:
150         P(0, bd.padding_dims[0], bd.strides[0][0]);
151         P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
152         P(1, 8, 16*2);
153         P(1, 2, 1);
154         P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
155         P(2, 16, 2);
156         P(3, bd.padding_dims[3], bd.strides[0][3]);
157         if (md.format() == gOIhw8o16i2o)
158             P(4, bd.padding_dims[4], bd.strides[0][4]);
159         return success;
160     default: break;
161     }
162
163     /* regular blocked format */
164     for (int d = 0; d < md.ndims(); ++d) {
165         P(d, bd.padding_dims[d] / bd.block_dims[d], bd.strides[0][d]);
166         if (bd.block_dims[d] != 1)
167             P(d, bd.block_dims[d], bd.strides[1][d]);
168     }
169
170     return success;
171 }
172
173 status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
174         const primitive_attr_t *attr) {
175     auto im_d = memory_desc_wrapper(imd);
176     auto om_d = memory_desc_wrapper(omd);
177
178     bool ok = true
179         && im_d.is_blocking_desc()
180         && om_d.is_blocking_desc()
181         && !im_d.has_zero_dim()
182         && !om_d.has_zero_dim();
183     if (!ok)
184         return unimplemented;
185
186     /* padding_dim consistency check */
187     for (int d = 0; d < im_d.ndims(); ++d) {
188         const auto pdim = im_d.blocking_desc().padding_dims[d];
189         bool ok = true
190             && pdim == om_d.blocking_desc().padding_dims[d]
191             && pdim % im_d.blocking_desc().block_dims[d] == 0
192             && pdim % om_d.blocking_desc().block_dims[d] == 0;
193             if (!ok) return unimplemented;
194     }
195
196     layout_desc_t ild, old;
197     status_t status = cvt_mem_desc_to_layout_desc(imd, ild);
198     if (status != success) return status;
199     status = cvt_mem_desc_to_layout_desc(omd, old);
200     if (status != success) return status;
201
202     p.itype = ild.dt;
203     p.otype = old.dt;
204
205     p.scale_type = attr->output_scales_.has_default_values()
206         ? scale_type_t::NONE
207         : (attr->output_scales_.mask_ == 0
208                 ? scale_type_t::COMMON
209                 : scale_type_t::MANY);
210
211     ptrdiff_t ss[max_ndims] = {0};
212     if (p.scale_type == scale_type_t::MANY) {
213         ptrdiff_t last_ss = 1;
214         for (int d = old.ndims - 1; d >=0; --d) {
215             assert((d == 0 || old.id[d - 1] <= old.id[d])
216                     && "logical dimensions should be in ascending order");
217             if (attr->output_scales_.mask_ & (1 << old.id[d])) {
218                 ss[d] = last_ss;
219                 last_ss *= old.dims[d];
220             }
221         }
222     }
223
224     int ndims = 0;
225
226     int i_pos = 0; /* state for input  -- current dimension */
227     int o_pos = 0; /* state for output -- current dimension */
228
229     while (i_pos < ild.ndims && o_pos < old.ndims) {
230         assert(ild.id[i_pos] == old.id[o_pos]);
231         if (ild.id[i_pos] != old.id[o_pos])
232             return runtime_error;
233
234         assert(ndims < max_ndims);
235         if (ndims == max_ndims)
236             return runtime_error;
237
238         if (ild.dims[i_pos] == old.dims[o_pos]) {
239             p.nodes[ndims].n = ild.dims[i_pos];
240             p.nodes[ndims].is = ild.strides[i_pos];
241             p.nodes[ndims].os = old.strides[o_pos];
242             p.nodes[ndims].ss = ss[o_pos];
243             ++ndims;
244             ++i_pos;
245             ++o_pos;
246         } else if (ild.dims[i_pos] < old.dims[o_pos]) {
247             assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
248             int factor = old.dims[o_pos] / ild.dims[i_pos];
249             p.nodes[ndims].n = ild.dims[i_pos];
250             p.nodes[ndims].is = ild.strides[i_pos];
251             p.nodes[ndims].os = old.strides[o_pos] * factor;
252             p.nodes[ndims].ss = ss[o_pos] * factor;
253             ++ndims;
254             ++i_pos;
255             old.dims[o_pos] = factor;
256         } else if (ild.dims[i_pos] > old.dims[o_pos]) {
257             assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
258             int factor = ild.dims[i_pos] / old.dims[o_pos];
259             p.nodes[ndims].n = old.dims[o_pos];
260             p.nodes[ndims].is = ild.strides[i_pos] * factor;
261             p.nodes[ndims].os = old.strides[o_pos];
262             p.nodes[ndims].ss = ss[o_pos];
263             ++ndims;
264             ++o_pos;
265             ild.dims[i_pos] = factor;
266         }
267     }
268     p.ndims = ndims;
269
270     dims_t zero_pos = {0};
271     p.ioff = memory_desc_wrapper(imd).off_v(zero_pos);
272     p.ooff = memory_desc_wrapper(omd).off_v(zero_pos);
273
274     const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
275     p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;
276
277     return success;
278 }
279
280 void prb_normalize(prb_t &p) {
281     for (int d = 0; d < p.ndims; ++d) {
282         int min_pos = d;
283         for (int j = d + 1; j < p.ndims; ++j) {
284             bool new_min = false
285                 || p.nodes[j].os < p.nodes[min_pos].os
286                 || (true
287                         && p.nodes[j].os == p.nodes[min_pos].os
288                         && p.nodes[j].n < p.nodes[min_pos].n);
289             if (new_min) min_pos = j;
290         }
291         if (min_pos != d)
292             nstl::swap(p.nodes[d], p.nodes[min_pos]);
293     }
294 }
295
296 void prb_simplify(prb_t &p) {
297 #if defined(__GNUC__) && __GNUC__ >= 4
298 /* GCC produces bogus array subscript is above array bounds warning for
299  * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
300 #pragma GCC diagnostic push
301 #pragma GCC diagnostic ignored "-Warray-bounds"
302 #endif
303     for (int d = 0; d < p.ndims - 1; ++d) {
304         auto &this_node = p.nodes[d + 0];
305         auto &next_node = p.nodes[d + 1];
306         const bool fold = false
307             || next_node.n == (size_t)1 // trivial case, just drop next node
308             || (true // or real folding if possible
309                     && next_node.is == (ptrdiff_t)this_node.n * this_node.is
310                     && next_node.os == (ptrdiff_t)this_node.n * this_node.os
311                     && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss);
312         if (fold) {
313             this_node.n *= next_node.n;
314             for (int j = d + 2; j < p.ndims; ++j)
315                 p.nodes[j - 1] = p.nodes[j];
316             --p.ndims;
317             --d; // make another try
318         }
319     }
320 #if defined(__GNUC__) && __GNUC__ >= 4
321 #pragma GCC diagnostic pop
322 #endif
323 }
324
325 void prb_node_split(prb_t &p, int dim, size_t n1) {
326     assert(dim < p.ndims);
327     assert(p.ndims < max_ndims);
328     assert(p.nodes[dim].n % n1 == 0);
329
330     p.ndims += 1;
331
332     for (int d = p.ndims; d > dim + 1; --d)
333         p.nodes[d] = p.nodes[d - 1];
334
335     p.nodes[dim + 1].n = p.nodes[dim].n / n1;
336     p.nodes[dim + 1].is = p.nodes[dim].is * n1;
337     p.nodes[dim + 1].os = p.nodes[dim].os * n1;
338     p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
339
340     p.nodes[dim].n = n1;
341 }
342
343 void prb_node_swap(prb_t &p, int d0, int d1) {
344     assert(d0 < p.ndims);
345     assert(d1 < p.ndims);
346     assert(p.ndims < max_ndims);
347
348     if (d0 == d1) return;
349
350     nstl::swap(p.nodes[d0], p.nodes[d1]);
351 }
352
353 void prb_node_move(prb_t &p, int d0, int d1) {
354     assert(d0 < p.ndims);
355     assert(d1 < p.ndims);
356     assert(p.ndims < max_ndims);
357
358     if (d0 == d1) return;
359
360     node_t node = p.nodes[d0];
361
362     if (d0 < d1)
363         for (int d = d0; d < d1; ++d)
364             p.nodes[d] = p.nodes[d + 1];
365     else
366         for (int d = d0; d > d1; --d)
367             p.nodes[d] = p.nodes[d - 1];
368
369     p.nodes[d1] = node;
370 }
371
372 void prb_dump(const prb_t &p) {
373     printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype),
374             mkldnn_dt2str(p.otype), p.ndims);
375     for (int d = 0; d < p.ndims; ++d)
376         printf("[%zu:%td:%td:%td]",
377                 p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss);
378     printf(" off:%zu:%zu\n", p.ioff, p.ooff);
379 }
380
381 }
382
383 }
384 }
385 }