1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "cpu_engine.hpp"
20 #include "cpu_memory.hpp"
21 #include "type_helpers.hpp"
23 #include "cpu/jit_uni_reorder.hpp"
24 #include "cpu/simple_reorder.hpp"
25 #include "cpu/wino_reorder.hpp"
26 #include "cpu/rnn/rnn_reorders.hpp"
32 using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f;
35 using namespace mkldnn::impl::data_type;
36 using namespace mkldnn::impl::memory_format;
/* REG_SR(src_dt, src_fmt, dst_dt, dst_fmt, order[, spec]) names the
 * pd_t::create factory of one simple_reorder_t specialization. Everything
 * after the fourth argument is forwarded verbatim as the trailing template
 * arguments (the fmt_order value and, optionally, a spec tag). */
#define REG_SR(src_dt, src_fmt, dst_dt, dst_fmt, ...) \
    simple_reorder_t<src_dt, src_fmt, dst_dt, dst_fmt, __VA_ARGS__>::pd_t::create

/* REG_SR_BIDIR expands to two comma-separated table entries for the same
 * data-type/format pair: first with fmt_order::keep, then with
 * fmt_order::reverse. */
#define REG_SR_BIDIR(src_dt, src_fmt, dst_dt, dst_fmt) \
    REG_SR(src_dt, src_fmt, dst_dt, dst_fmt, fmt_order::keep), \
    REG_SR(src_dt, src_fmt, dst_dt, dst_fmt, fmt_order::reverse)

/* REG_SR_DIRECT_COPY expands to two comma-separated table entries for an
 * any -> any reorder between the given data types with fmt_order::any:
 * first the spec::direct_copy specialization, then
 * spec::direct_copy_except_dim_0. */
#define REG_SR_DIRECT_COPY(src_dt, dst_dt) \
    REG_SR(src_dt, any, dst_dt, any, fmt_order::any, spec::direct_copy), \
    REG_SR(src_dt, any, dst_dt, any, fmt_order::any, \
            spec::direct_copy_except_dim_0)
/* Reorder primitive-descriptor factories handed out by
 * cpu_engine_t::get_reorder_implementation_list().
 * The order of the entries matters: specialized implementations (winograd,
 * rnn, direct copies, jit, simple_reorder specializations) come first, and
 * the generic spec::reference implementations at the end act as the
 * fallback ("last line of defence"). Presumably the engine tries each
 * create() in turn and keeps the first one that succeeds -- TODO confirm
 * against the engine's reorder creation code before inserting new entries. */
static const rpd_create_f cpu_reorder_impl_list[] = {
/* winograd reorders (wino_reorder_t) */
wino_reorder_t<f32, f32>::pd_t::create,
wino_reorder_t<f32, s8>::pd_t::create,

/* rnn data / weights reorders */
rnn_data_reorder_t<f32, u8>::pd_t::create,
rnn_weights_reorder_t<f32, f32>::pd_t::create,
rnn_weights_reorder_t<f32, s8>::pd_t::create,
#if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__))
/* Direct copy for icc, which is faster than the jitted code.
 * Direct copy for gcc as well: it may or may not be faster than the jitted
 * code, but it is still worthwhile because it requires no jitting and thus
 * has a much shorter creation time. This is a tentative solution and should
 * be removed later (e.g. once jitted code is cached). Note that it is listed
 * ahead of jit_uni_reorder_create in this table. */
REG_SR_DIRECT_COPY(f32, f32),
#ifdef __INTEL_COMPILER
/* icc only: direct copies for the other f32/s32/s8/u8 data-type pairs,
 * which are faster than the jitted code with this compiler.
 * (f32 -> f32 is registered just above for both gcc and icc; f32 -> u8 is
 * deliberately disabled, see the FIXME.) */
REG_SR_DIRECT_COPY(f32, s32),
REG_SR_DIRECT_COPY(f32, s8),
// REG_SR_DIRECT_COPY(f32, u8), FIXME: Disabled due to accuracy failure on int8 network
REG_SR_DIRECT_COPY(s32, f32),
REG_SR_DIRECT_COPY(s32, s32),
REG_SR_DIRECT_COPY(s32, s8),
REG_SR_DIRECT_COPY(s32, u8),
REG_SR_DIRECT_COPY(s8, f32),
REG_SR_DIRECT_COPY(s8, s32),
REG_SR_DIRECT_COPY(s8, s8),
REG_SR_DIRECT_COPY(s8, u8),
REG_SR_DIRECT_COPY(u8, f32),
REG_SR_DIRECT_COPY(u8, s32),
REG_SR_DIRECT_COPY(u8, s8),
REG_SR_DIRECT_COPY(u8, u8),
/* jitted reorder (jit_uni_reorder) */
jit_uni_reorder_create,

/* fp32: flat <-> blocked with tail */
REG_SR_BIDIR(f32, any, f32, nCw4c),

/* binary data: f32 <-> bin (nchw/nhwc -> nhwc, both directions via
 * REG_SR_BIDIR) and a bin -> bin direct copy */
REG_SR_BIDIR(f32, nchw, bin, nhwc),
REG_SR_BIDIR(f32, nhwc, bin, nhwc),
REG_SR_DIRECT_COPY(bin, bin),
/* fp32: flat <-> blocked with tail (continued), both directions.
 * *w layouts with 4/8-wide blocks */
REG_SR_BIDIR(f32, any, f32, nCw8c),
REG_SR_BIDIR(f32, any, f32, OIw4i4o),
REG_SR_BIDIR(f32, any, f32, OIw8i8o),
REG_SR_BIDIR(f32, any, f32, OIw8o8i),
REG_SR_BIDIR(f32, any, f32, gOIw4i4o),
REG_SR_BIDIR(f32, any, f32, gOIw8i8o),
REG_SR_BIDIR(f32, any, f32, gOIw8o8i),

/* *w layouts with 16-wide blocks */
REG_SR_BIDIR(f32, any, f32, nCw16c),
REG_SR_BIDIR(f32, any, f32, OIw16o16i),
REG_SR_BIDIR(f32, any, f32, OIw16i16o),
REG_SR_BIDIR(f32, any, f32, IOw16o16i),
REG_SR_BIDIR(f32, any, f32, gOIw16o16i),
REG_SR_BIDIR(f32, any, f32, gOIw16i16o),
REG_SR_BIDIR(f32, any, f32, gIOw16o16i),

/* *hw layouts with 4/8-wide blocks */
REG_SR_BIDIR(f32, any, f32, nChw4c),
REG_SR_BIDIR(f32, any, f32, nChw8c),
REG_SR_BIDIR(f32, any, f32, OIhw4i4o),
REG_SR_BIDIR(f32, any, f32, Ohwi8o),
REG_SR_BIDIR(f32, any, f32, OIhw8i8o),
REG_SR_BIDIR(f32, any, f32, OIhw8o8i),
REG_SR_BIDIR(f32, any, f32, gOIhw4i4o),
REG_SR_BIDIR(f32, any, f32, gOIhw4o4i),
REG_SR_BIDIR(f32, any, f32, gOhwi8o),
REG_SR_BIDIR(f32, any, f32, gOIhw8i8o),
REG_SR_BIDIR(f32, any, f32, gOIhw8o8i),

/* *hw layouts with 16-wide blocks (plus the 4o output-blocked variants) */
REG_SR_BIDIR(f32, any, f32, nChw16c),
REG_SR_BIDIR(f32, any, f32, Oihw4o),
REG_SR_BIDIR(f32, any, f32, Oihw16o),
REG_SR_BIDIR(f32, any, f32, Ohwi4o),
REG_SR_BIDIR(f32, any, f32, Ohwi16o),
REG_SR_BIDIR(f32, any, f32, OIhw16o16i),
REG_SR_BIDIR(f32, any, f32, OIhw16i16o),
REG_SR_BIDIR(f32, any, f32, IOhw16o16i),
REG_SR_BIDIR(f32, any, f32, gOihw4o),
REG_SR_BIDIR(f32, any, f32, gOihw16o),
REG_SR_BIDIR(f32, any, f32, gOhwi4o),
REG_SR_BIDIR(f32, any, f32, gOhwi16o),
REG_SR_BIDIR(f32, any, f32, gOIhw16o16i),
REG_SR_BIDIR(f32, any, f32, gOIhw16i16o),
REG_SR_BIDIR(f32, any, f32, gIOhw16o16i),

/* *dhw layouts with 4/8-wide blocks */
REG_SR_BIDIR(f32, any, f32, nCdhw4c),
REG_SR_BIDIR(f32, any, f32, nCdhw8c),
REG_SR_BIDIR(f32, any, f32, OIdhw4i4o),
REG_SR_BIDIR(f32, any, f32, Odhwi8o),
REG_SR_BIDIR(f32, any, f32, OIdhw8i8o),
REG_SR_BIDIR(f32, any, f32, OIdhw8o8i),
REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o),
REG_SR_BIDIR(f32, any, f32, gOdhwi8o),
REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o),
REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i),

/* *dhw layouts with 16-wide blocks (plus the 4o output-blocked variants) */
REG_SR_BIDIR(f32, any, f32, nCdhw16c),
REG_SR_BIDIR(f32, any, f32, Oidhw4o),
REG_SR_BIDIR(f32, any, f32, Oidhw16o),
REG_SR_BIDIR(f32, any, f32, Odhwi16o),
REG_SR_BIDIR(f32, any, f32, OIdhw16o16i),
REG_SR_BIDIR(f32, any, f32, OIdhw16i16o),
REG_SR_BIDIR(f32, any, f32, gOidhw4o),
REG_SR_BIDIR(f32, any, f32, gOidhw16o),
REG_SR_BIDIR(f32, any, f32, gOdhwi16o),
REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i),
REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o),
/* WA to prevent fallback on reference implementations:
 * direct copies for these u8/s8-source pairs are registered here, ahead of
 * the spec::reference entries at the end of the table.
 * NOTE: with icc the same five pairs are already registered in the
 * __INTEL_COMPILER direct-copy group near the top of the table. */
REG_SR_DIRECT_COPY(u8, f32),
REG_SR_DIRECT_COPY(u8, s8),
REG_SR_DIRECT_COPY(s8, u8),
REG_SR_DIRECT_COPY(u8, u8),
REG_SR_DIRECT_COPY(s8, s8),

/* fp32: blocked <-> blocked with tail (8c <-> 16c, both directions) */
REG_SR_BIDIR(f32, nCw8c, f32, nCw16c),
REG_SR_BIDIR(f32, nChw8c, f32, nChw16c),
REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c),
176 /* int: flat <-> blocked with tail */
177 REG_SR(f32, nChw8c, u8, nhwc, fmt_order::keep),
178 REG_SR(f32, nChw8c, s8, nhwc, fmt_order::keep),
179 REG_SR(u8, nhwc, f32, nChw8c, fmt_order::keep),
180 REG_SR(s8, nhwc, f32, nChw8c, fmt_order::keep),
181 REG_SR(f32, nhwc, u8, nhwc, fmt_order::keep),
182 REG_SR(f32, nhwc, s8, nhwc, fmt_order::keep),
183 REG_SR(u8, nhwc, f32, nhwc, fmt_order::keep),
184 REG_SR(s8, nhwc, f32, nhwc, fmt_order::keep),
185 REG_SR(s8, nhwc, u8, nhwc, fmt_order::keep),
186 REG_SR(u8, nhwc, s8, nhwc, fmt_order::keep),
187 REG_SR(u8, nhwc, s8, nhwc, fmt_order::keep),
188 REG_SR(f32, nchw, u8, nhwc, fmt_order::keep),
189 REG_SR(f32, nchw, s8, nhwc, fmt_order::keep),
190 REG_SR(u8, nchw, u8, nhwc, fmt_order::keep),
191 REG_SR(s8, nchw, s8, nhwc, fmt_order::keep),
192 REG_SR(u8, nhwc, f32, nchw, fmt_order::keep),
/* f32/s32/s8/u8 data-type conversions into and out of nChw8c
 * (REG_SR_BIDIR registers both directions) */
REG_SR_BIDIR(f32, any, s32, nChw8c),
REG_SR_BIDIR(f32, any, s8, nChw8c),
REG_SR_BIDIR(f32, any, u8, nChw8c),
REG_SR_BIDIR(s32, any, f32, nChw8c),
REG_SR_BIDIR(s32, any, s32, nChw8c),
REG_SR_BIDIR(s32, any, s8, nChw8c),
REG_SR_BIDIR(s32, any, u8, nChw8c),
REG_SR_BIDIR(s8, any, f32, nChw8c),
REG_SR_BIDIR(s8, any, s32, nChw8c),
REG_SR_BIDIR(s8, any, s8, nChw8c),
REG_SR_BIDIR(s8, any, u8, nChw8c),
REG_SR_BIDIR(u8, any, f32, nChw8c),
REG_SR_BIDIR(u8, any, s32, nChw8c),
REG_SR_BIDIR(u8, any, s8, nChw8c),
REG_SR_BIDIR(u8, any, u8, nChw8c),

/* the same data-type pairs for nChw16c */
REG_SR_BIDIR(f32, any, s32, nChw16c),
REG_SR_BIDIR(f32, any, s8, nChw16c),
REG_SR_BIDIR(f32, any, u8, nChw16c),
REG_SR_BIDIR(s32, any, f32, nChw16c),
REG_SR_BIDIR(s32, any, s32, nChw16c),
REG_SR_BIDIR(s32, any, s8, nChw16c),
REG_SR_BIDIR(s32, any, u8, nChw16c),
REG_SR_BIDIR(s8, any, f32, nChw16c),
REG_SR_BIDIR(s8, any, s32, nChw16c),
REG_SR_BIDIR(s8, any, s8, nChw16c),
REG_SR_BIDIR(s8, any, u8, nChw16c),
REG_SR_BIDIR(u8, any, f32, nChw16c),
REG_SR_BIDIR(u8, any, s32, nChw16c),
REG_SR_BIDIR(u8, any, s8, nChw16c),
REG_SR_BIDIR(u8, any, u8, nChw16c),
/* OIhw4i16o4i / gOIhw4i16o4i with f32 and s8, both directions */
REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i),
REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i),
REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i),
REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i),
REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i),
REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i),
REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i),
REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i),

/* OhIw8o4i / gOhIw8o4i with f32 and s8, src -> dst only (fmt_order::keep),
 * followed by the *_s8s8 variants reachable from plain oihw / goihw */
REG_SR(f32, any, f32, OhIw8o4i, fmt_order::keep),
REG_SR(f32, any, s8, OhIw8o4i, fmt_order::keep),
REG_SR(s8, any, f32, OhIw8o4i, fmt_order::keep),
REG_SR(s8, any, s8, OhIw8o4i, fmt_order::keep),
REG_SR(f32, any, s8, gOhIw8o4i, fmt_order::keep),
REG_SR(s8, any, f32, gOhIw8o4i, fmt_order::keep),
REG_SR(f32, any, f32, gOhIw8o4i, fmt_order::keep),
REG_SR(s8, any, s8, gOhIw8o4i, fmt_order::keep),
REG_SR(f32, oihw, s8, OhIw8o4i_s8s8, fmt_order::keep),
REG_SR(s8, oihw, s8, OhIw8o4i_s8s8, fmt_order::keep),
REG_SR(f32, goihw, s8, gOhIw8o4i_s8s8, fmt_order::keep),
REG_SR(s8, goihw, s8, gOhIw8o4i_s8s8, fmt_order::keep),
/* bin -> bin into the OhIw8o32i / OhIw16o32i blocked layouts (src -> dst only) */
REG_SR(bin, any, bin, OhIw8o32i, fmt_order::keep),
REG_SR(bin, any, bin, OhIw16o32i, fmt_order::keep),

/* f32/s8 -> s8 into the *_s8s8 layouts (src -> dst only) */
REG_SR(f32, any, s8, hwio_s8s8, fmt_order::keep),
REG_SR(f32, any, s8, hwigo_s8s8, fmt_order::keep),
REG_SR(s8, any, s8, hwio_s8s8, fmt_order::keep),
REG_SR(s8, any, s8, hwigo_s8s8, fmt_order::keep),

REG_SR(f32, goihw, s8, gOIhw4o4i_s8s8, fmt_order::keep),
REG_SR(s8, goihw, s8, gOIhw4o4i_s8s8, fmt_order::keep),

REG_SR(f32, oihw, s8, OIhw4i16o4i_s8s8, fmt_order::keep),
REG_SR(f32, goihw, s8, gOIhw4i16o4i_s8s8, fmt_order::keep),
REG_SR(s8, oihw, s8, OIhw4i16o4i_s8s8, fmt_order::keep),
REG_SR(s8, goihw, s8, gOIhw4i16o4i_s8s8, fmt_order::keep),

REG_SR(f32, goihw, s8, gOIhw2i8o4i_s8s8, fmt_order::keep),
REG_SR(s8, goihw, s8, gOIhw2i8o4i_s8s8, fmt_order::keep),

REG_SR(f32, goihw, s8, Goihw16g_s8s8, fmt_order::keep),
REG_SR(s8, goihw, s8, Goihw16g_s8s8, fmt_order::keep),

/* s16: direct copy, plus (g)OIhw8i16o2i and 8i16o2i <-> 8o16i2o, both directions */
REG_SR_DIRECT_COPY(s16, s16),

REG_SR_BIDIR(s16, any, s16, OIhw8i16o2i),
REG_SR_BIDIR(s16, any, s16, gOIhw8i16o2i),
REG_SR_BIDIR(s16, OIhw8i16o2i, s16, OIhw8o16i2o),
REG_SR_BIDIR(s16, gOIhw8i16o2i, s16, gOIhw8o16i2o),
/* reference: the last line of defence.
 * Generic any -> any implementations (fmt_order::any, spec::reference) for
 * each supported data-type pair; placed at the end of the table so they act
 * as the fallback for reorders no earlier entry covers.
 * NOTE(review): these entries have no s16 <-> s8/u8 and no bin pairs, so such
 * reorders have no reference fallback here -- confirm this is intended. */
REG_SR(f32, any, f32, any, fmt_order::any, spec::reference),
REG_SR(f32, any, s32, any, fmt_order::any, spec::reference),
REG_SR(f32, any, s16, any, fmt_order::any, spec::reference),
REG_SR(f32, any, s8, any, fmt_order::any, spec::reference),
REG_SR(f32, any, u8, any, fmt_order::any, spec::reference),

REG_SR(s32, any, f32, any, fmt_order::any, spec::reference),
REG_SR(s32, any, s32, any, fmt_order::any, spec::reference),
REG_SR(s32, any, s16, any, fmt_order::any, spec::reference),
REG_SR(s32, any, s8, any, fmt_order::any, spec::reference),
REG_SR(s32, any, u8, any, fmt_order::any, spec::reference),

REG_SR(s16, any, f32, any, fmt_order::any, spec::reference),
REG_SR(s16, any, s32, any, fmt_order::any, spec::reference),
REG_SR(s16, any, s16, any, fmt_order::any, spec::reference),

REG_SR(s8, any, f32, any, fmt_order::any, spec::reference),
REG_SR(s8, any, s32, any, fmt_order::any, spec::reference),
REG_SR(s8, any, s8, any, fmt_order::any, spec::reference),
REG_SR(s8, any, u8, any, fmt_order::any, spec::reference),

REG_SR(u8, any, f32, any, fmt_order::any, spec::reference),
REG_SR(u8, any, s32, any, fmt_order::any, spec::reference),
REG_SR(u8, any, u8, any, fmt_order::any, spec::reference),
REG_SR(u8, any, s8, any, fmt_order::any, spec::reference),
309 const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const {
310 return cpu_reorder_impl_list;