1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef _JIT_UNI_REORDER_HPP
18 #define _JIT_UNI_REORDER_HPP
22 #include "c_types_map.hpp"
23 #include "type_helpers.hpp"
25 #include "cpu_primitive.hpp"
26 #include "cpu_reorder_pd.hpp"
/** compile-time upper bound on the number of nodes; it is the length of the
 * fixed-size `nodes` array declared below and is taken from TENSOR_MAX_DIMS */
34 constexpr int max_ndims = TENSOR_MAX_DIMS;
38 ptrdiff_t is; // input stride; NOTE(review): units (elements vs bytes) are not visible in this header -- confirm
39 ptrdiff_t os; // output stride; presumably the key prb_normalize() orders nodes by (ascending)
40 ptrdiff_t ss; // scale stride; presumably only meaningful when scales are in use -- confirm
/** kinds of scaling that can be recorded in a `scale_type` field.
 * NOTE(review): the enumerator meanings are not spelled out in this header;
 * presumably NONE = no scaling, COMMON = one scale shared by everything,
 * MANY = several scales addressed through the scale stride -- confirm
 * against the prb_init() implementation. */
43 enum class scale_type_t { NONE, COMMON, MANY };
49 node_t nodes[max_ndims]; // fixed-capacity node storage (at most max_ndims entries)
52 scale_type_t scale_type; // scaling mode, one of scale_type_t::{NONE, COMMON, MANY}
/** fills in the problem @p prb (output, taken by non-const reference) and
 * reports the outcome as a status_t.
 * NOTE(review): parameter roles are inferred from names and types only --
 * imd / omd are presumably the input / output memory descriptors and
 * attr the primitive attributes (possibly the source of the scales);
 * confirm against the implementation, which is not part of this header. */
56 status_t prb_init(prb_t &prb, const memory_desc_t &imd,
57 const memory_desc_t &omd, const primitive_attr_t *attr);
59 /** sorts the nodes of problem @p p so that output strides come in ascending
 * order; @p p is modified in place (non-const reference, no return value) */
60 void prb_normalize(prb_t &p);
62 /** folds nodes of problem @p p together where possible; @p p is modified in
 * place. NOTE(review): the exact condition under which two nodes can be
 * folded is defined by the implementation and is not visible here. */
63 void prb_simplify(prb_t &p);
65 /** splits node @p dim of problem @p p into two nodes of sizes @p n1 and
 * n / n1, where n is presumably the current size of the node being split
 * (n is not a parameter of this function -- confirm).
 * @warning n must be a multiple of n1 */
67 void prb_node_split(prb_t &p, int dim, size_t n1);
69 /** swaps the nodes at positions @p d0 and @p d1 of problem @p p (in place) */
70 void prb_node_swap(prb_t &p, int d0, int d1);
72 /** moves node @p d0 of problem @p p to the @p d1 position.
 * nodes (d0, d1] are shifted to the left if d0 < d1 or
 * to the right if d0 > d1.
 * NOTE(review): for d0 > d1 the interval written as (d0, d1] is presumably
 * meant as [d1, d0) -- confirm against the implementation. */
75 void prb_node_move(prb_t &p, int d0, int d1);
77 /** dumps the problem @p p to stdout; @p p is taken by const reference and is
 * therefore not modified */
78 void prb_dump(const prb_t &p);
/** constructs the kernel object from its descriptor: desc_ is initialized
 * from @p desc and the kernel function pointer starts out null, so ker_
 * has to be assigned before operator() is used. */
92 kernel_t(const desc_t &desc): desc_(desc), ker_(nullptr) {}
/** runs the kernel on the call parameters @p c by calling through ker_.
 * A null ker_ is caught only by the assert; assuming this is the standard
 * assert macro, the check disappears when NDEBUG is defined. */
93 void operator()(const call_param_t *c) const { assert(ker_); ker_(c); }
/** virtual destructor: create() hands objects out as kernel_t pointers,
 * presumably to derived kernel types, so deletion through the base pointer
 * must be well-defined. */
94 virtual ~kernel_t() {}
96 /** initializes a kernel descriptor:
 * desc -- kernel descriptor (output)
 * prb -- transposition problem (input)
 * ndims_ker_max -- limit on the maximum number of dimensions the kernel
 * will process (optional, 0 -- no limitation)
 * returns a status_t; NOTE(review): the specific status values are chosen
 * by the implementation, which is not visible in this header */
101 static status_t desc_init(desc_t &desc, const prb_t &prb,
102 int ndims_ker_max = 0);
104 /** creates a kernel for the problem described in @p desc and returns it as a
 * kernel_t pointer.
 * NOTE(review): ownership of the returned object (presumably the caller's)
 * and the value returned on failure (presumably nullptr) are not visible
 * in this header -- confirm against the implementation. */
105 static kernel_t *create(const desc_t &desc);
/* shorthand reference to the problem stored inside this object's desc_,
 * bound through the default member initializer.
 * NOTE(review): an implicitly generated copy constructor copies the
 * reference binding instead of re-running this initializer, so a copied
 * object would refer to the source object's desc_ -- avoid copying. */
109 const prb_t &prb_ = desc_.prb;
/* kernel entry point invoked by operator(); set to nullptr by the
 * constructor, so it must be assigned elsewhere before the first call */
110 void (*ker_)(const call_param_t *);
113 /* TODO: add trans_t class */
117 /* creation entry point intended for the cpu reorder list.
 * reorder_pd is a pointer-to-pointer and is presumably where the created
 * reorder primitive descriptor is returned; input_pd / output_pd are
 * presumably the source / destination memory primitive descriptors and
 * attr the primitive attributes. The result is reported as a status_t.
 * NOTE(review): roles are inferred from names and types -- confirm against
 * the implementation. */
118 status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd,
119 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
120 const primitive_attr_t *attr);