1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "c_types_map.hpp"
20 #include "memory_desc_wrapper.hpp"
21 #include "mkldnn_debug.h"
23 #include "type_helpers.hpp"
26 #include "cpu_primitive.hpp"
27 #include "cpu_reorder_pd.hpp"
28 #include "jit_uni_reorder.hpp"
30 using namespace mkldnn::impl::types;
31 using namespace mkldnn::impl::status;
39 /** ad-hoc structure to describe blocked memory layout */
40 struct layout_desc_t {
// NOTE(review): the struct body is not visible in this chunk. Usage in
// cvt_mem_desc_to_layout_desc() below implies members: dt (data type),
// ndims (node count), and per-node arrays id[], dims[], strides[] --
// confirm against the full source file.
// Converts an mkldnn memory descriptor into the ad-hoc layout_desc_t form:
// one (logical-dim id, size, stride) entry per logical dimension and per
// blocking level, so input and output layouts can later be matched
// node-by-node in prb_init().
// NOTE(review): this chunk is missing many original lines (case labels,
// return statements, closing braces), so the control flow below is only
// partially visible -- do not infer the full format coverage from it.
48 status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
50 using namespace mkldnn::impl::memory_format;
51 using namespace mkldnn::impl::data_type;
53 auto md = memory_desc_wrapper(md_);
54 auto bd = md.blocking_desc();
57 ld.dt = md.data_type();
// P appends one node (logical dim `id`, extent `dim`, stride) to `ld`.
// NOTE(review): lines storing `id` and incrementing ld.ndims are not
// visible here -- presumably between the visible statements; confirm.
59 auto P = [&ld](int id, int dim, ptrdiff_t stride) {
60 assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
62 ld.dims[ld.ndims] = dim;
63 ld.strides[ld.ndims] = stride;
// Dispatch on the concrete memory format. Undefined/any formats and the
// s8s8-compensated weight formats listed below are rejected as
// invalid_arguments (they cannot be described by this layout scheme).
68 switch (md.format()) {
69 case memory_format::undef:
70 case memory_format::any:
74 case gOIhw2i8o4i_s8s8:
75 case gOIw4i16o4i_s8s8:
77 case gOIhw4i16o4i_s8s8:
78 case OIhw4i16o4i_s8s8:
82 return invalid_arguments;
// The groups of P() calls below hand-build layouts for specific blocked
// weight formats: padded dims of blocked axes are divided by the block
// size (16 or 8), and the inner block levels are appended afterwards.
// NOTE(review): the case labels introducing each group are among the
// missing lines, so each group's exact format set cannot be read here.
85 P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
87 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
90 P(2, bd.padding_dims[2], bd.strides[0][2]);
91 if (md.format() == OIhw4i16o4i)
92 P(3, bd.padding_dims[3], bd.strides[0][3]);
98 P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
100 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
103 P(2, bd.padding_dims[2], bd.strides[0][2]);
// 5D (spatial d/h/w) variants append extra spatial nodes conditionally.
104 if (utils::one_of(md.format(), OIhw8i16o2i, IOhw8i16o2i)
105 || md.format() == OIdhw8i16o2i)
106 P(3, bd.padding_dims[3], bd.strides[0][3]);
107 if (md.format() == OIdhw8i16o2i)
108 P(4, bd.padding_dims[4], bd.strides[0][4]);
116 P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
119 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
121 P(2, bd.padding_dims[2], bd.strides[0][2]);
122 if (utils::one_of(md.format(), OIhw8o16i2o, IOhw8o16i2o,
123 OIdhw8o16i2o, IOdhw8o16i2o))
124 P(3, bd.padding_dims[3], bd.strides[0][3]);
125 if (utils::one_of(md.format(), OIdhw8o16i2o, IOdhw8o16i2o))
126 P(4, bd.padding_dims[4], bd.strides[0][4]);
// Grouped (g-prefixed) formats: dim 0 is the groups dimension and is not
// blocked; the O/I dims that follow are divided by their block size.
129 P(0, bd.padding_dims[0], bd.strides[0][0]);
130 P(1, bd.padding_dims[1] / 8, bd.strides[0][1]);
132 P(2, bd.padding_dims[2] / 8, bd.strides[0][2]);
135 P(3, bd.padding_dims[3], bd.strides[0][3]);
136 P(4, bd.padding_dims[4], bd.strides[0][4]);
140 P(0, bd.padding_dims[0], bd.strides[0][0]);
141 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
143 P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
146 P(3, bd.padding_dims[3], bd.strides[0][3]);
147 if (md.format() == gOIhw4i16o4i)
148 P(4, bd.padding_dims[4], bd.strides[0][4]);
154 P(0, bd.padding_dims[0], bd.strides[0][0]);
155 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
157 P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
160 P(3, bd.padding_dims[3], bd.strides[0][3]);
161 if (utils::one_of(md.format(), gOIhw8i16o2i, gIOhw8i16o2i)
162 || md.format() == gOIdhw8i16o2i)
163 P(4, bd.padding_dims[4], bd.strides[0][4]);
164 if (md.format() == gOIdhw8i16o2i)
165 P(5, bd.padding_dims[5], bd.strides[0][5]);
173 P(0, bd.padding_dims[0], bd.strides[0][0]);
174 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
177 P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
179 P(3, bd.padding_dims[3], bd.strides[0][3]);
180 if (utils::one_of(md.format(), gOIhw8o16i2o, gIOhw8o16i2o,
181 gOIdhw8o16i2o, gIOdhw8o16i2o))
182 P(4, bd.padding_dims[4], bd.strides[0][4]);
183 if (utils::one_of(md.format(), gOIdhw8o16i2o, gIOdhw8o16i2o))
184 P(5, bd.padding_dims[5], bd.strides[0][5]);
// Default path: any regular blocked format is described generically --
// one outer node per dim (padded extent / block size, outer stride) and,
// for blocked dims, one inner node (block extent, inner stride).
189 /* regular blocked format */
190 for (int d = 0; d < md.ndims(); ++d) {
191 P(d, bd.padding_dims[d] / bd.block_dims[d], bd.strides[0][d]);
192 if (bd.block_dims[d] != 1)
193 P(d, bd.block_dims[d], bd.strides[1][d]);
// Builds the reorder problem descriptor `p` (prb_t) from the input/output
// memory descriptors and primitive attributes. Returns unimplemented for
// cases this jit reorder cannot handle, runtime_error on internal
// inconsistencies, and propagates cvt_mem_desc_to_layout_desc() failures.
// NOTE(review): several declarations (e.g. `ok`, `ndims`, loop increments)
// fall on lines missing from this chunk.
199 status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
200 const primitive_attr_t *attr) {
201 auto im_d = memory_desc_wrapper(imd);
202 auto om_d = memory_desc_wrapper(omd);
// Both sides must be plain blocking descriptors with no zero dims.
205 && im_d.is_blocking_desc()
206 && om_d.is_blocking_desc()
207 && !im_d.has_zero_dim()
208 && !om_d.has_zero_dim();
210 return unimplemented;
212 /* padding_dim consistency check */
// Padded extents must agree between input and output and be divisible by
// both sides' block sizes, otherwise the node-matching below cannot work.
213 for (int d = 0; d < im_d.ndims(); ++d) {
214 const auto pdim = im_d.blocking_desc().padding_dims[d];
216 && pdim == om_d.blocking_desc().padding_dims[d]
217 && pdim % im_d.blocking_desc().block_dims[d] == 0
218 && pdim % om_d.blocking_desc().block_dims[d] == 0;
219 if (!ok) return unimplemented;
// Convert both memory descriptors to the flat node-list layout form.
222 layout_desc_t ild, old;
223 status_t status = cvt_mem_desc_to_layout_desc(imd, ild);
224 if (status != success) return status;
225 status = cvt_mem_desc_to_layout_desc(omd, old);
226 if (status != success) return status;
// Output-scales attribute: NONE / COMMON (mask 0, one scale) / MANY
// (per-dimension scales selected by the mask bits).
231 p.scale_type = attr->output_scales_.has_default_values()
233 : (attr->output_scales_.mask_ == 0
234 ? scale_type_t::COMMON
235 : scale_type_t::MANY);
// ss[] holds the stride of the scales array along each output node; only
// dims selected by the mask advance the scale index (walked innermost
// dim first, accumulating last_ss).
237 ptrdiff_t ss[max_ndims] = {0};
238 if (p.scale_type == scale_type_t::MANY) {
239 ptrdiff_t last_ss = 1;
240 for (int d = old.ndims - 1; d >=0; --d) {
241 assert((d == 0 || old.id[d - 1] <= old.id[d])
242 && "logical dimensions should be in ascending order");
243 if (attr->output_scales_.mask_ & (1 << old.id[d])) {
245 last_ss *= old.dims[d];
// Walk both node lists in lockstep and emit one prb_t node per matched
// piece. When extents differ, the larger side is split by the exact
// integer factor so both sides advance by the same amount.
252 int i_pos = 0; /* state for input -- current dimension */
253 int o_pos = 0; /* state for output -- current dimension */
255 while (i_pos < ild.ndims && o_pos < old.ndims) {
256 assert(ild.id[i_pos] == old.id[o_pos]);
257 if (ild.id[i_pos] != old.id[o_pos])
258 return runtime_error;
260 assert(ndims < max_ndims);
261 if (ndims == max_ndims)
262 return runtime_error;
// Case 1: extents equal -- consume one node from each side.
264 if (ild.dims[i_pos] == old.dims[o_pos]) {
265 p.nodes[ndims].n = ild.dims[i_pos];
266 p.nodes[ndims].is = ild.strides[i_pos];
267 p.nodes[ndims].os = old.strides[o_pos];
268 p.nodes[ndims].ss = ss[o_pos];
// Case 2: output node is larger -- split it; the remaining factor stays
// in old.dims[o_pos] for the next iteration.
272 } else if (ild.dims[i_pos] < old.dims[o_pos]) {
273 assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
274 int factor = old.dims[o_pos] / ild.dims[i_pos];
275 p.nodes[ndims].n = ild.dims[i_pos];
276 p.nodes[ndims].is = ild.strides[i_pos];
277 p.nodes[ndims].os = old.strides[o_pos] * factor;
278 p.nodes[ndims].ss = ss[o_pos] * factor;
281 old.dims[o_pos] = factor;
// Case 3: input node is larger -- symmetric split of the input side.
282 } else if (ild.dims[i_pos] > old.dims[o_pos]) {
283 assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
284 int factor = ild.dims[i_pos] / old.dims[o_pos];
285 p.nodes[ndims].n = old.dims[o_pos];
286 p.nodes[ndims].is = ild.strides[i_pos] * factor;
287 p.nodes[ndims].os = old.strides[o_pos];
288 p.nodes[ndims].ss = ss[o_pos];
291 ild.dims[i_pos] = factor;
// Base offsets: physical offset of logical element {0,...,0} in each md
// (nonzero when the descriptor carries an offset/padding).
296 dims_t zero_pos = {0};
297 p.ioff = memory_desc_wrapper(imd).off_v(zero_pos);
298 p.ooff = memory_desc_wrapper(omd).off_v(zero_pos);
// beta comes from a sum post-op if present (0 means plain copy).
300 const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
301 p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;
// Sorts the problem nodes by ascending output stride (selection sort);
// equal strides are tie-broken by smaller extent n. This puts the
// innermost (unit-stride) output dimension first.
// NOTE(review): the lines initializing min_pos and opening the new_min
// condition are missing from this chunk.
306 void prb_normalize(prb_t &p) {
307 for (int d = 0; d < p.ndims; ++d) {
309 for (int j = d + 1; j < p.ndims; ++j) {
311 || p.nodes[j].os < p.nodes[min_pos].os
313 && p.nodes[j].os == p.nodes[min_pos].os
314 && p.nodes[j].n < p.nodes[min_pos].n);
315 if (new_min) min_pos = j;
// Place the minimum found in the inner scan at position d.
318 nstl::swap(p.nodes[d], p.nodes[min_pos]);
// Merges adjacent nodes when they describe one contiguous run: the next
// node is folded into the current one if it is trivial (n == 1) or if all
// three of its strides (is/os/ss) equal the current node's extent times
// the current node's strides. After a fold the remaining nodes shift left
// and the same position is retried (`--d`).
// NOTE(review): the lines applying the fold condition (if (fold) {...})
// and decrementing p.ndims are missing from this chunk.
322 void prb_simplify(prb_t &p) {
323 #if defined(__GNUC__) && __GNUC__ >= 4
324 /* GCC produces bogus array subscript is above array bounds warning for
325 * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
326 #pragma GCC diagnostic push
327 #pragma GCC diagnostic ignored "-Warray-bounds"
329 for (int d = 0; d < p.ndims - 1; ++d) {
330 auto &this_node = p.nodes[d + 0];
331 auto &next_node = p.nodes[d + 1];
332 const bool fold = false
333 || next_node.n == (size_t)1 // trivial case, just drop next node
334 || (true // or real folding if possible
335 && next_node.is == (ptrdiff_t)this_node.n * this_node.is
336 && next_node.os == (ptrdiff_t)this_node.n * this_node.os
337 && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss);
339 this_node.n *= next_node.n;
340 for (int j = d + 2; j < p.ndims; ++j)
341 p.nodes[j - 1] = p.nodes[j];
343 --d; // make another try
346 #if defined(__GNUC__) && __GNUC__ >= 4
347 #pragma GCC diagnostic pop
// Splits node `dim` (extent n, n1 | n) into two adjacent nodes: an inner
// node of extent n1 at `dim` and an outer node of extent n / n1 at
// dim + 1 whose strides are the inner node's strides scaled by n1.
// Nodes after `dim` shift right by one.
// NOTE(review): the lines incrementing p.ndims and assigning
// p.nodes[dim].n = n1 are missing from this chunk.
351 void prb_node_split(prb_t &p, int dim, size_t n1) {
352 assert(dim < p.ndims);
353 assert(p.ndims < max_ndims);
354 assert(p.nodes[dim].n % n1 == 0);
// Make room for the new node by shifting the tail right.
358 for (int d = p.ndims; d > dim + 1; --d)
359 p.nodes[d] = p.nodes[d - 1];
361 p.nodes[dim + 1].n = p.nodes[dim].n / n1;
362 p.nodes[dim + 1].is = p.nodes[dim].is * n1;
363 p.nodes[dim + 1].os = p.nodes[dim].os * n1;
364 p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
// Exchanges the nodes at positions d0 and d1; no-op when they coincide.
369 void prb_node_swap(prb_t &p, int d0, int d1) {
370 assert(d0 < p.ndims);
371 assert(d1 < p.ndims);
372 assert(p.ndims < max_ndims);
374 if (d0 == d1) return;
376 nstl::swap(p.nodes[d0], p.nodes[d1]);
// Moves the node at position d0 to position d1, shifting the nodes in
// between by one (left shift when d0 < d1, right shift when d0 > d1).
// NOTE(review): the final reinsertion (presumably p.nodes[d1] = node;)
// falls on a line missing from this chunk -- confirm.
379 void prb_node_move(prb_t &p, int d0, int d1) {
380 assert(d0 < p.ndims);
381 assert(d1 < p.ndims);
382 assert(p.ndims < max_ndims);
384 if (d0 == d1) return;
// Save the moved node, then close the gap in the appropriate direction.
386 node_t node = p.nodes[d0];
389 for (int d = d0; d < d1; ++d)
390 p.nodes[d] = p.nodes[d + 1];
392 for (int d = d0; d > d1; --d)
393 p.nodes[d] = p.nodes[d - 1];
// Debug helper: prints the problem descriptor to stdout -- input/output
// data types, node count, one [n:is:os:ss] group per node, and the base
// input/output offsets.
398 void prb_dump(const prb_t &p) {
399 printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype),
400 mkldnn_dt2str(p.otype), p.ndims);
401 for (int d = 0; d < p.ndims; ++d)
402 printf("[%zu:%td:%td:%td]",
403 p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss);
404 printf(" off:%zu:%zu\n", p.ioff, p.ooff);