1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "c_types_map.hpp"
20 #include "memory_desc_wrapper.hpp"
21 #include "mkldnn_debug.h"
23 #include "type_helpers.hpp"
26 #include "cpu_primitive.hpp"
27 #include "cpu_reorder_pd.hpp"
28 #include "jit_uni_reorder.hpp"
30 using namespace mkldnn::impl::types;
31 using namespace mkldnn::impl::status;
39 /** ad-hoc structure to describe blocked memory layout */
40 struct layout_desc_t {
// Converts a blocked memory descriptor into the ad-hoc layout_desc_t form:
// a flat list of (logical dim id, dim size, stride) triples appended via P().
// Hard-codes the decomposition for a set of special weight formats and falls
// back to a generic padding_dims/block_dims expansion for plain blocked
// formats. Returns invalid_arguments for formats it refuses to describe.
// NOTE(review): this excerpt is elided -- the switch's per-format case labels
// and several closing braces between the P(...) runs are not visible here.
48 status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
50 using namespace mkldnn::impl::memory_format;
52 auto md = memory_desc_wrapper(md_);
53 auto bd = md.blocking_desc();
56 ld.dt = md.data_type();
// P appends one (id, dim, stride) entry to ld; the assert guards against
// overflowing the fixed-size ld.dims array.
58 auto P = [&ld](int id, int dim, ptrdiff_t stride) {
59 assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
61 ld.dims[ld.ndims] = dim;
62 ld.strides[ld.ndims] = stride;
67 switch (md.format()) {
// Unresolved / compensated s8s8 weight formats cannot be described as a
// plain blocked layout -- reject them.
68 case memory_format::undef:
69 case memory_format::any:
73 case gOIhw2i8o4i_s8s8:
74 case gOIhw4i16o4i_s8s8:
75 case OIhw4i16o4i_s8s8:
78 return invalid_arguments;
// Special format: dims 0 and 1 carry a 16-wide block (padded dim / 16).
// NOTE(review): the case label this branch belongs to is elided.
80 P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
82 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
85 P(2, bd.padding_dims[2], bd.strides[0][2]);
86 P(3, bd.padding_dims[3], bd.strides[0][3]);
// 8i16o2i-family (no groups): trailing spatial dims depend on 4D vs 5D format.
91 P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
93 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
96 P(2, bd.padding_dims[2], bd.strides[0][2]);
97 if (md.format() == OIhw8i16o2i || md.format() == OIdhw8i16o2i)
98 P(3, bd.padding_dims[3], bd.strides[0][3]);
99 if (md.format() == OIdhw8i16o2i)
100 P(4, bd.padding_dims[4], bd.strides[0][4]);
// 8o16i2o-family (no groups).
104 P(0, bd.padding_dims[0] / 16, bd.strides[0][0]);
107 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
109 P(2, bd.padding_dims[2], bd.strides[0][2]);
110 if (md.format() == OIhw8o16i2o)
111 P(3, bd.padding_dims[3], bd.strides[0][3]);
// Grouped 8-wide blocked variant: dim 0 is the group dim (unblocked),
// dims 1 and 2 use an 8-wide block.
114 P(0, bd.padding_dims[0], bd.strides[0][0]);
115 P(1, bd.padding_dims[1] / 8, bd.strides[0][1]);
117 P(2, bd.padding_dims[2] / 8, bd.strides[0][2]);
120 P(3, bd.padding_dims[3], bd.strides[0][3]);
121 P(4, bd.padding_dims[4], bd.strides[0][4]);
// Grouped 16-wide blocked variant.
124 P(0, bd.padding_dims[0], bd.strides[0][0]);
125 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
127 P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
130 P(3, bd.padding_dims[3], bd.strides[0][3]);
131 P(4, bd.padding_dims[4], bd.strides[0][4]);
// Grouped 8i16o2i-family: extra spatial dims for the 5D/6D formats.
136 P(0, bd.padding_dims[0], bd.strides[0][0]);
137 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
139 P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
142 P(3, bd.padding_dims[3], bd.strides[0][3]);
143 if (md.format() == gOIhw8i16o2i || md.format() == gOIdhw8i16o2i)
144 P(4, bd.padding_dims[4], bd.strides[0][4]);
145 if (md.format() == gOIdhw8i16o2i)
146 P(5, bd.padding_dims[5], bd.strides[0][5]);
// Grouped 8o16i2o-family.
150 P(0, bd.padding_dims[0], bd.strides[0][0]);
151 P(1, bd.padding_dims[1] / 16, bd.strides[0][1]);
154 P(2, bd.padding_dims[2] / 16, bd.strides[0][2]);
156 P(3, bd.padding_dims[3], bd.strides[0][3]);
157 if (md.format() == gOIhw8o16i2o)
158 P(4, bd.padding_dims[4], bd.strides[0][4]);
163 /* regular blocked format */
// Generic path: for every logical dim emit the outer (blocked) part,
// then the inner block itself when the block size is non-trivial.
164 for (int d = 0; d < md.ndims(); ++d) {
165 P(d, bd.padding_dims[d] / bd.block_dims[d], bd.strides[0][d]);
166 if (bd.block_dims[d] != 1)
167 P(d, bd.block_dims[d], bd.strides[1][d]);
// Builds the reorder problem descriptor `p` from input/output memory
// descriptors and attributes: validates the pair, converts both sides to
// layout_desc_t, derives the output-scale mode and per-dimension scale
// strides, then merges the two dimension lists into p.nodes[] (splitting a
// dimension on one side whenever the other side blocks it more finely).
// Returns unimplemented/runtime_error on unsupported or inconsistent inputs.
// NOTE(review): elided excerpt -- the `bool ok = ...` heads, loop position
// advances (i_pos/o_pos/ndims increments) and closing braces are not visible.
173 status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
174 const primitive_attr_t *attr) {
175 auto im_d = memory_desc_wrapper(imd);
176 auto om_d = memory_desc_wrapper(omd);
// Both sides must be plain blocking descriptors with no zero-sized dims.
179 && im_d.is_blocking_desc()
180 && om_d.is_blocking_desc()
181 && !im_d.has_zero_dim()
182 && !om_d.has_zero_dim();
184 return unimplemented;
186 /* padding_dim consistency check */
// Padded sizes must match across sides and be divisible by both block sizes,
// otherwise the node-merging below cannot line the layouts up.
187 for (int d = 0; d < im_d.ndims(); ++d) {
188 const auto pdim = im_d.blocking_desc().padding_dims[d];
190 && pdim == om_d.blocking_desc().padding_dims[d]
191 && pdim % im_d.blocking_desc().block_dims[d] == 0
192 && pdim % om_d.blocking_desc().block_dims[d] == 0;
193 if (!ok) return unimplemented;
// Flatten both memory descriptors into (id, dim, stride) lists.
196 layout_desc_t ild, old;
197 status_t status = cvt_mem_desc_to_layout_desc(imd, ild);
198 if (status != success) return status;
199 status = cvt_mem_desc_to_layout_desc(omd, old);
200 if (status != success) return status;
// Scale mode: NONE (default values) / COMMON (mask == 0) / MANY (per-dim).
205 p.scale_type = attr->output_scales_.has_default_values()
207 : (attr->output_scales_.mask_ == 0
208 ? scale_type_t::COMMON
209 : scale_type_t::MANY);
// For MANY scales, compute a stride into the scales array for every output
// dimension covered by the mask (innermost masked dim gets stride 1).
211 ptrdiff_t ss[max_ndims] = {0};
212 if (p.scale_type == scale_type_t::MANY) {
213 ptrdiff_t last_ss = 1;
214 for (int d = old.ndims - 1; d >=0; --d) {
215 assert((d == 0 || old.id[d - 1] <= old.id[d])
216 && "logical dimensions should be in ascending order");
217 if (attr->output_scales_.mask_ & (1 << old.id[d])) {
219 last_ss *= old.dims[d];
// Merge input and output dimension lists into p.nodes[]; each node records
// size (n), input stride (is), output stride (os) and scale stride (ss).
226 int i_pos = 0; /* state for input -- current dimension */
227 int o_pos = 0; /* state for output -- current dimension */
229 while (i_pos < ild.ndims && o_pos < old.ndims) {
230 assert(ild.id[i_pos] == old.id[o_pos]);
231 if (ild.id[i_pos] != old.id[o_pos])
232 return runtime_error;
234 assert(ndims < max_ndims);
235 if (ndims == max_ndims)
236 return runtime_error;
// Case 1: sizes match -- emit one node and advance both sides (elided).
238 if (ild.dims[i_pos] == old.dims[o_pos]) {
239 p.nodes[ndims].n = ild.dims[i_pos];
240 p.nodes[ndims].is = ild.strides[i_pos];
241 p.nodes[ndims].os = old.strides[o_pos];
242 p.nodes[ndims].ss = ss[o_pos];
// Case 2: output dim is coarser -- emit the input-sized piece and leave the
// remaining `factor` on the output side for the next iteration.
246 } else if (ild.dims[i_pos] < old.dims[o_pos]) {
247 assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
248 int factor = old.dims[o_pos] / ild.dims[i_pos];
249 p.nodes[ndims].n = ild.dims[i_pos];
250 p.nodes[ndims].is = ild.strides[i_pos];
251 p.nodes[ndims].os = old.strides[o_pos] * factor;
252 p.nodes[ndims].ss = ss[o_pos] * factor;
255 old.dims[o_pos] = factor;
// Case 3: symmetric -- input dim is coarser; split the input side instead.
256 } else if (ild.dims[i_pos] > old.dims[o_pos]) {
257 assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
258 int factor = ild.dims[i_pos] / old.dims[o_pos];
259 p.nodes[ndims].n = old.dims[o_pos];
260 p.nodes[ndims].is = ild.strides[i_pos] * factor;
261 p.nodes[ndims].os = old.strides[o_pos];
262 p.nodes[ndims].ss = ss[o_pos];
265 ild.dims[i_pos] = factor;
// Base offsets: element offset of the logical origin in each memory
// (non-zero when the descriptor has an offset/padding before the data).
270 dims_t zero_pos = {0};
271 p.ioff = memory_desc_wrapper(imd).off_v(zero_pos);
272 p.ooff = memory_desc_wrapper(omd).off_v(zero_pos);
// beta != 0 means accumulate into the destination (sum post-op scale).
274 const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
275 p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;
// Canonicalizes node order via selection sort: for each position d, find the
// remaining node with the smallest output stride (ties broken by smaller n,
// per the visible comparison) and swap it into place. Normalized order makes
// structurally-equal problems compare equal.
// NOTE(review): elided excerpt -- the `min_pos` initialization and the head
// of the `new_min` condition are not visible here.
280 void prb_normalize(prb_t &p) {
281 for (int d = 0; d < p.ndims; ++d) {
283 for (int j = d + 1; j < p.ndims; ++j) {
285 || p.nodes[j].os < p.nodes[min_pos].os
287 && p.nodes[j].os == p.nodes[min_pos].os
288 && p.nodes[j].n < p.nodes[min_pos].n);
289 if (new_min) min_pos = j;
// Swap the minimum into position d (no-op when min_pos == d is presumably
// handled by swap being self-safe).
292 nstl::swap(p.nodes[d], p.nodes[min_pos]);
// Folds adjacent nodes whenever node d+1 is either trivial (n == 1) or is
// exactly the continuation of node d in input, output, and scale strides;
// folding multiplies the sizes and shifts the remaining nodes down, reducing
// the loop-nest depth of the reorder.
// NOTE(review): elided excerpt -- the `if (fold) {` guard and the p.ndims
// decrement between L176 and L179 are not visible here.
296 void prb_simplify(prb_t &p) {
297 #if defined(__GNUC__) && __GNUC__ >= 4
298 /* GCC produces bogus array subscript is above array bounds warning for
299 * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
300 #pragma GCC diagnostic push
301 #pragma GCC diagnostic ignored "-Warray-bounds"
303 for (int d = 0; d < p.ndims - 1; ++d) {
304 auto &this_node = p.nodes[d + 0];
305 auto &next_node = p.nodes[d + 1];
// next node folds into this one if it is size-1, or if its strides are
// exactly this node's strides scaled by this node's size (i.e. contiguous
// continuation in input, output, and scale spaces simultaneously).
306 const bool fold = false
307 || next_node.n == (size_t)1 // trivial case, just drop next node
308 || (true // or real folding if possible
309 && next_node.is == (ptrdiff_t)this_node.n * this_node.is
310 && next_node.os == (ptrdiff_t)this_node.n * this_node.os
311 && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss);
// Merge sizes and compact the node array over the absorbed entry.
313 this_node.n *= next_node.n;
314 for (int j = d + 2; j < p.ndims; ++j)
315 p.nodes[j - 1] = p.nodes[j];
317 --d; // make another try
320 #if defined(__GNUC__) && __GNUC__ >= 4
321 #pragma GCC diagnostic pop
// Splits node `dim` into two nodes: an inner one of size n1 (kept at `dim`)
// and an outer one of size n/n1 at `dim + 1` whose strides are the inner
// node's strides scaled by n1. Existing nodes above `dim` are shifted up to
// make room. n1 must evenly divide the node's size.
// NOTE(review): elided excerpt -- the statements that set p.nodes[dim].n to
// n1 and bump p.ndims are not visible here.
325 void prb_node_split(prb_t &p, int dim, size_t n1) {
326 assert(dim < p.ndims);
327 assert(p.ndims < max_ndims);
328 assert(p.nodes[dim].n % n1 == 0);
// Shift nodes [dim+1, ndims) up by one slot.
332 for (int d = p.ndims; d > dim + 1; --d)
333 p.nodes[d] = p.nodes[d - 1];
// Outer node: remaining factor, with strides advanced past n1 inner elements.
335 p.nodes[dim + 1].n = p.nodes[dim].n / n1;
336 p.nodes[dim + 1].is = p.nodes[dim].is * n1;
337 p.nodes[dim + 1].os = p.nodes[dim].os * n1;
338 p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
// Swaps the nodes at positions d0 and d1 (no-op when they coincide).
// NOTE(review): closing brace is elided from this excerpt.
343 void prb_node_swap(prb_t &p, int d0, int d1) {
344 assert(d0 < p.ndims);
345 assert(d1 < p.ndims);
346 assert(p.ndims < max_ndims);
348 if (d0 == d1) return;
350 nstl::swap(p.nodes[d0], p.nodes[d1]);
// Moves the node at position d0 to position d1, shifting the nodes in
// between by one slot (left when d0 < d1, right when d0 > d1). No-op when
// d0 == d1.
// NOTE(review): elided excerpt -- the final `p.nodes[d1] = node;` store and
// the closing brace are not visible here.
353 void prb_node_move(prb_t &p, int d0, int d1) {
354 assert(d0 < p.ndims);
355 assert(d1 < p.ndims);
356 assert(p.ndims < max_ndims);
358 if (d0 == d1) return;
// Save the node being moved before overwriting its slot.
360 node_t node = p.nodes[d0];
// Moving toward the end: shift intermediates one slot left.
363 for (int d = d0; d < d1; ++d)
364 p.nodes[d] = p.nodes[d + 1];
// Moving toward the front: shift intermediates one slot right.
366 for (int d = d0; d > d1; --d)
367 p.nodes[d] = p.nodes[d - 1];
// Debug helper: prints the problem to stdout as
//   @@@ type:<itype>:<otype> ndims:<n> [n:is:os:ss]... off:<ioff>:<ooff>
// where each bracketed group is one node's size and input/output/scale
// strides.
// NOTE(review): closing brace is elided from this excerpt.
372 void prb_dump(const prb_t &p) {
373 printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype),
374 mkldnn_dt2str(p.otype), p.ndims);
375 for (int d = 0; d < p.ndims; ++d)
376 printf("[%zu:%td:%td:%td]",
377 p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss);
378 printf(" off:%zu:%zu\n", p.ioff, p.ooff);