updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_reorder.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "cpu_engine.hpp"
20 #include "cpu_memory.hpp"
21 #include "type_helpers.hpp"
22
23 #include "cpu/jit_uni_reorder.hpp"
24 #include "cpu/simple_reorder.hpp"
25 #include "cpu/wino_reorder.hpp"
26 #include "cpu/rnn/rnn_reorders.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f;
33
34 namespace {
35 using namespace mkldnn::impl::data_type;
36 using namespace mkldnn::impl::memory_format;
37
38 #define REG_SR(idt, ifmt, odt, ofmt, ...) \
39     simple_reorder_t<idt, ifmt, odt, ofmt, __VA_ARGS__>::pd_t::create
40
41 #define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \
42     REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \
43     REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse)
44
45 #define REG_SR_DIRECT_COPY(idt, odt) \
46     REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \
47     REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0)
48
49 static const rpd_create_f cpu_reorder_impl_list[] = {
50     /* winograd */
51     wino_reorder_t<f32, f32>::pd_t::create,
52     wino_reorder_t<f32, s8>::pd_t::create,
53
54     /* rnn reorders */
55     rnn_data_reorder_t<f32, u8>::pd_t::create,
56     rnn_weights_reorder_t<f32, f32>::pd_t::create,
57     rnn_weights_reorder_t<f32, s8>::pd_t::create,
58
59 #if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__))
60     /* Direct copy for icc which is faster than jitted code;
61      * Direct copy for gcc which might or might not be faster than jitted
62      * code, but still worth it because doesn't require jitting, i.e. much
63      * faster creation time. This is tentative solution and should be removed
64      * later (when we will cache jitted code?...). */
65     REG_SR_DIRECT_COPY(f32, f32),
66 #endif
67
68 #ifdef __INTEL_COMPILER
69     /* direct copy for icc, which is faster than jitted code */
70     REG_SR_DIRECT_COPY(f32, s32),
71     REG_SR_DIRECT_COPY(f32, s8),
72 //    REG_SR_DIRECT_COPY(f32, u8), FIXME: Disabled due to accuracy failure on int8 network
73     REG_SR_DIRECT_COPY(s32, f32),
74     REG_SR_DIRECT_COPY(s32, s32),
75     REG_SR_DIRECT_COPY(s32, s8),
76     REG_SR_DIRECT_COPY(s32, u8),
77     REG_SR_DIRECT_COPY(s8, f32),
78     REG_SR_DIRECT_COPY(s8, s32),
79     REG_SR_DIRECT_COPY(s8, s8),
80     REG_SR_DIRECT_COPY(s8, u8),
81     REG_SR_DIRECT_COPY(u8, f32),
82     REG_SR_DIRECT_COPY(u8, s32),
83     REG_SR_DIRECT_COPY(u8, s8),
84     REG_SR_DIRECT_COPY(u8, u8),
85 #endif
86
87     /* jit */
88     jit_uni_reorder_create,
89
90     /* fp32: flat <-> blocked with tail */
91     REG_SR_BIDIR(f32, any, f32, nCw4c),
92
93     REG_SR_BIDIR(f32, nchw, bin, nhwc),
94     REG_SR_BIDIR(f32, nhwc, bin, nhwc),
95     REG_SR_DIRECT_COPY(bin, bin),
96
97     REG_SR_BIDIR(f32, any, f32, nCw8c),
98     REG_SR_BIDIR(f32, any, f32, OIw4i4o),
99     REG_SR_BIDIR(f32, any, f32, OIw8i8o),
100     REG_SR_BIDIR(f32, any, f32, OIw8o8i),
101     REG_SR_BIDIR(f32, any, f32, gOIw4i4o),
102     REG_SR_BIDIR(f32, any, f32, gOIw8i8o),
103     REG_SR_BIDIR(f32, any, f32, gOIw8o8i),
104
105     REG_SR_BIDIR(f32, any, f32, nCw16c),
106     REG_SR_BIDIR(f32, any, f32, OIw16o16i),
107     REG_SR_BIDIR(f32, any, f32, OIw16i16o),
108     REG_SR_BIDIR(f32, any, f32, IOw16o16i),
109     REG_SR_BIDIR(f32, any, f32, gOIw16o16i),
110     REG_SR_BIDIR(f32, any, f32, gOIw16i16o),
111     REG_SR_BIDIR(f32, any, f32, gIOw16o16i),
112
113     REG_SR_BIDIR(f32, any, f32, nChw4c),
114     REG_SR_BIDIR(f32, any, f32, nChw8c),
115     REG_SR_BIDIR(f32, any, f32, OIhw4i4o),
116     REG_SR_BIDIR(f32, any, f32, Ohwi8o),
117     REG_SR_BIDIR(f32, any, f32, OIhw8i8o),
118     REG_SR_BIDIR(f32, any, f32, OIhw8o8i),
119     REG_SR_BIDIR(f32, any, f32, gOIhw4i4o),
120     REG_SR_BIDIR(f32, any, f32, gOIhw4o4i),
121     REG_SR_BIDIR(f32, any, f32, gOhwi8o),
122     REG_SR_BIDIR(f32, any, f32, gOIhw8i8o),
123     REG_SR_BIDIR(f32, any, f32, gOIhw8o8i),
124
125     REG_SR_BIDIR(f32, any, f32, nChw16c),
126     REG_SR_BIDIR(f32, any, f32, Oihw4o),
127     REG_SR_BIDIR(f32, any, f32, Oihw16o),
128     REG_SR_BIDIR(f32, any, f32, Ohwi4o),
129     REG_SR_BIDIR(f32, any, f32, Ohwi16o),
130     REG_SR_BIDIR(f32, any, f32, OIhw16o16i),
131     REG_SR_BIDIR(f32, any, f32, OIhw16i16o),
132     REG_SR_BIDIR(f32, any, f32, IOhw16o16i),
133     REG_SR_BIDIR(f32, any, f32, gOihw4o),
134     REG_SR_BIDIR(f32, any, f32, gOihw16o),
135     REG_SR_BIDIR(f32, any, f32, gOhwi4o),
136     REG_SR_BIDIR(f32, any, f32, gOhwi16o),
137     REG_SR_BIDIR(f32, any, f32, gOIhw16o16i),
138     REG_SR_BIDIR(f32, any, f32, gOIhw16i16o),
139     REG_SR_BIDIR(f32, any, f32, gIOhw16o16i),
140
141     REG_SR_BIDIR(f32, any, f32, nCdhw4c),
142     REG_SR_BIDIR(f32, any, f32, nCdhw8c),
143     REG_SR_BIDIR(f32, any, f32, OIdhw4i4o),
144     REG_SR_BIDIR(f32, any, f32, Odhwi8o),
145     REG_SR_BIDIR(f32, any, f32, OIdhw8i8o),
146     REG_SR_BIDIR(f32, any, f32, OIdhw8o8i),
147     REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o),
148     REG_SR_BIDIR(f32, any, f32, gOdhwi8o),
149     REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o),
150     REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i),
151
152     REG_SR_BIDIR(f32, any, f32, nCdhw16c),
153     REG_SR_BIDIR(f32, any, f32, Oidhw4o),
154     REG_SR_BIDIR(f32, any, f32, Oidhw16o),
155     REG_SR_BIDIR(f32, any, f32, Odhwi16o),
156     REG_SR_BIDIR(f32, any, f32, OIdhw16o16i),
157     REG_SR_BIDIR(f32, any, f32, OIdhw16i16o),
158     REG_SR_BIDIR(f32, any, f32, gOidhw4o),
159     REG_SR_BIDIR(f32, any, f32, gOidhw16o),
160     REG_SR_BIDIR(f32, any, f32, gOdhwi16o),
161     REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i),
162     REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o),
163
164     /* WA to prevent fallback on reference implementations */
165     REG_SR_DIRECT_COPY(u8, f32),
166     REG_SR_DIRECT_COPY(u8, s8),
167     REG_SR_DIRECT_COPY(s8, u8),
168     REG_SR_DIRECT_COPY(u8, u8),
169     REG_SR_DIRECT_COPY(s8, s8),
170
171  /* fp32: blocked <-> blocked with tail */
172     REG_SR_BIDIR(f32, nCw8c, f32, nCw16c),
173     REG_SR_BIDIR(f32, nChw4c, f32, nChw16c),
174     REG_SR_BIDIR(f32, nChw8c, f32, nChw16c),
175     REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c),
176
177     /* int: flat <-> blocked with tail */
178     REG_SR(f32, nChw8c, u8, nhwc, fmt_order::keep),
179     REG_SR(f32, nChw8c, s8, nhwc, fmt_order::keep),
180     REG_SR(u8, nhwc, f32, nChw8c, fmt_order::keep),
181     REG_SR(s8, nhwc, f32, nChw8c, fmt_order::keep),
182     REG_SR(f32, nhwc, u8, nhwc, fmt_order::keep),
183     REG_SR(f32, nhwc, s8, nhwc, fmt_order::keep),
184     REG_SR(u8, nhwc, f32, nhwc, fmt_order::keep),
185     REG_SR(s8, nhwc, f32, nhwc, fmt_order::keep),
186     REG_SR(s8, nhwc, u8, nhwc, fmt_order::keep),
187     REG_SR(u8, nhwc, s8, nhwc, fmt_order::keep),
188     REG_SR(u8, nhwc, s8, nhwc, fmt_order::keep),
189     REG_SR(f32, nchw, u8, nhwc, fmt_order::keep),
190     REG_SR(f32, nchw, s8, nhwc, fmt_order::keep),
191     REG_SR(u8, nchw, u8, nhwc, fmt_order::keep),
192     REG_SR(s8, nchw, s8, nhwc, fmt_order::keep),
193     REG_SR(u8, nhwc, f32, nchw, fmt_order::keep),
194
195     REG_SR_BIDIR(f32, any, s32, nChw8c),
196     REG_SR_BIDIR(f32, any, s8, nChw8c),
197     REG_SR_BIDIR(f32, any, u8, nChw8c),
198     REG_SR_BIDIR(s32, any, f32, nChw8c),
199     REG_SR_BIDIR(s32, any, s32, nChw8c),
200     REG_SR_BIDIR(s32, any, s8, nChw8c),
201     REG_SR_BIDIR(s32, any, u8, nChw8c),
202     REG_SR_BIDIR(s8, any, f32, nChw8c),
203     REG_SR_BIDIR(s8, any, s32, nChw8c),
204     REG_SR_BIDIR(s8, any, s8, nChw8c),
205     REG_SR_BIDIR(s8, any, u8, nChw8c),
206     REG_SR_BIDIR(u8, any, f32, nChw8c),
207     REG_SR_BIDIR(u8, any, s32, nChw8c),
208     REG_SR_BIDIR(u8, any, s8, nChw8c),
209     REG_SR_BIDIR(u8, any, u8, nChw8c),
210
211     REG_SR_BIDIR(f32, any, s32, nChw16c),
212     REG_SR_BIDIR(f32, any, s8, nChw16c),
213     REG_SR_BIDIR(f32, any, u8, nChw16c),
214     REG_SR_BIDIR(s32, any, f32, nChw16c),
215     REG_SR_BIDIR(s32, any, s32, nChw16c),
216     REG_SR_BIDIR(s32, any, s8, nChw16c),
217     REG_SR_BIDIR(s32, any, u8, nChw16c),
218     REG_SR_BIDIR(s8, any, f32, nChw16c),
219     REG_SR_BIDIR(s8, any, s32, nChw16c),
220     REG_SR_BIDIR(s8, any, s8, nChw16c),
221     REG_SR_BIDIR(s8, any, u8, nChw16c),
222     REG_SR_BIDIR(u8, any, f32, nChw16c),
223     REG_SR_BIDIR(u8, any, s32, nChw16c),
224     REG_SR_BIDIR(u8, any, s8, nChw16c),
225     REG_SR_BIDIR(u8, any, u8, nChw16c),
226
227     REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i),
228     REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i),
229     REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i),
230     REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i),
231     REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i),
232     REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i),
233     REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i),
234     REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i),
235
236     REG_SR(f32, any, f32, OhIw8o4i, fmt_order::keep),
237     REG_SR(f32, any, s8, OhIw8o4i, fmt_order::keep),
238     REG_SR(s8, any, f32, OhIw8o4i, fmt_order::keep),
239     REG_SR(s8, any, s8, OhIw8o4i, fmt_order::keep),
240     REG_SR(f32, any, s8, gOhIw8o4i, fmt_order::keep),
241     REG_SR(s8, any, f32, gOhIw8o4i, fmt_order::keep),
242     REG_SR(f32, any, f32, gOhIw8o4i, fmt_order::keep),
243     REG_SR(s8, any, s8, gOhIw8o4i, fmt_order::keep),
244     REG_SR(f32, oihw, s8, OhIw8o4i_s8s8, fmt_order::keep),
245     REG_SR(s8, oihw, s8, OhIw8o4i_s8s8, fmt_order::keep),
246     REG_SR(f32, goihw, s8, gOhIw8o4i_s8s8, fmt_order::keep),
247     REG_SR(s8, goihw, s8, gOhIw8o4i_s8s8, fmt_order::keep),
248     REG_SR(f32, oidhw, s8, OdhIw8o4i_s8s8, fmt_order::keep),
249     REG_SR(s8, oidhw, s8, OdhIw8o4i_s8s8, fmt_order::keep),
250     REG_SR(f32, goidhw, s8, gOdhIw8o4i_s8s8, fmt_order::keep),
251     REG_SR(s8, goidhw, s8, gOdhIw8o4i_s8s8, fmt_order::keep),
252
253     REG_SR(bin, any, bin, OhIw8o32i, fmt_order::keep),
254     REG_SR(bin, any, bin, OhIw16o32i, fmt_order::keep),
255
256     REG_SR(f32, any, s8, hwio_s8s8, fmt_order::keep),
257     REG_SR(f32, any, s8, hwigo_s8s8, fmt_order::keep),
258     REG_SR(s8, any, s8, hwio_s8s8, fmt_order::keep),
259     REG_SR(s8, any, s8, hwigo_s8s8, fmt_order::keep),
260
261     REG_SR(f32, any, s8, dhwio_s8s8, fmt_order::keep),
262     REG_SR(f32, any, s8, dhwigo_s8s8, fmt_order::keep),
263     REG_SR(s8, any, s8, dhwio_s8s8, fmt_order::keep),
264     REG_SR(s8, any, s8, dhwigo_s8s8, fmt_order::keep),
265
266     REG_SR(f32, goihw, s8, gOIhw4o4i_s8s8, fmt_order::keep),
267     REG_SR(f32, hwigo, s8, gOIhw4o4i_s8s8, fmt_order::keep),
268     REG_SR(s8, goihw, s8, gOIhw4o4i_s8s8, fmt_order::keep),
269     REG_SR(s8, hwigo, s8, gOIhw4o4i_s8s8, fmt_order::keep),
270
271     REG_SR(f32, oiw, s8, OIw4i16o4i_s8s8, fmt_order::keep),
272     REG_SR(f32, goiw, s8, gOIw4i16o4i_s8s8, fmt_order::keep),
273     REG_SR(f32, oihw, s8, OIhw4i16o4i_s8s8, fmt_order::keep),
274     REG_SR(f32, goihw, s8, gOIhw4i16o4i_s8s8, fmt_order::keep),
275     REG_SR(f32, hwio, s8, OIhw4i16o4i_s8s8, fmt_order::keep),
276     REG_SR(f32, hwigo, s8, gOIhw4i16o4i_s8s8, fmt_order::keep),
277     REG_SR(s8, oiw, s8, OIw4i16o4i_s8s8, fmt_order::keep),
278     REG_SR(s8, goiw, s8, gOIw4i16o4i_s8s8, fmt_order::keep),
279     REG_SR(s8, oihw, s8, OIhw4i16o4i_s8s8, fmt_order::keep),
280     REG_SR(s8, goihw, s8, gOIhw4i16o4i_s8s8, fmt_order::keep),
281     REG_SR(s8, hwio, s8, OIhw4i16o4i_s8s8, fmt_order::keep),
282     REG_SR(s8, hwigo, s8, gOIhw4i16o4i_s8s8, fmt_order::keep),
283
284     REG_SR(f32, goihw, s8, gOIhw2i8o4i_s8s8, fmt_order::keep),
285     REG_SR(f32, hwigo, s8, gOIhw2i8o4i_s8s8, fmt_order::keep),
286     REG_SR(s8, goihw, s8, gOIhw2i8o4i_s8s8, fmt_order::keep),
287     REG_SR(s8, hwigo, s8, gOIhw2i8o4i_s8s8, fmt_order::keep),
288
289     REG_SR(f32, goiw, s8, Goiw16g_s8s8, fmt_order::keep),
290     REG_SR(f32, goihw, s8, Goihw16g_s8s8, fmt_order::keep),
291     REG_SR(f32, hwigo, s8, Goihw16g_s8s8, fmt_order::keep),
292     REG_SR(s8, goiw, s8, Goiw16g_s8s8, fmt_order::keep),
293     REG_SR(s8, goihw, s8, Goihw16g_s8s8, fmt_order::keep),
294     REG_SR(s8, hwigo, s8, Goihw16g_s8s8, fmt_order::keep),
295
296     /* bf16 */
297     REG_SR_BIDIR(bf16, any, bf16, nChw16c),
298
299     REG_SR(f32, nchw, bf16, nChw16c, fmt_order::keep),
300     REG_SR(bf16, nChw16c, f32, nchw, fmt_order::keep),
301
302     REG_SR(f32, oihw, bf16, OIhw8i16o2i, fmt_order::keep),
303     REG_SR(f32, oihw, bf16, IOhw8i16o2i, fmt_order::keep),
304     REG_SR(f32, goihw, bf16, gOIhw8i16o2i, fmt_order::keep),
305     REG_SR(f32, goihw, bf16, gIOhw8i16o2i, fmt_order::keep),
306     REG_SR(f32, oihw, bf16, OIhw8o16i2o, fmt_order::keep),
307     REG_SR(f32, goihw, bf16, gOIhw8o16i2o, fmt_order::keep),
308     REG_SR(f32, oihw, bf16, IOhw8o16i2o, fmt_order::keep),
309     REG_SR(f32, goihw, bf16, gIOhw8o16i2o, fmt_order::keep),
310     REG_SR(f32, oihw, bf16, OIhw16i16o, fmt_order::keep),
311     REG_SR(f32, goihw, bf16, gOIhw16i16o, fmt_order::keep),
312
313     REG_SR(bf16, OIhw16i16o, f32, oihw, fmt_order::keep),
314     REG_SR(bf16, gOIhw16i16o, f32, goihw, fmt_order::keep),
315
316     REG_SR(bf16, any, bf16, any, fmt_order::any, spec::reference),
317     REG_SR(bf16, any, f32, any, fmt_order::any, spec::reference),
318     REG_SR(f32, any, bf16, any, fmt_order::any, spec::reference),
319
320     /* s16 <-> s16 */
321     REG_SR_DIRECT_COPY(s16, s16),
322
323     REG_SR_BIDIR(s16, any, s16, OIhw8i16o2i),
324     REG_SR_BIDIR(s16, any, s16, gOIhw8i16o2i),
325     REG_SR_BIDIR(s16, OIhw8i16o2i, s16, OIhw8o16i2o),
326     REG_SR_BIDIR(s16, gOIhw8i16o2i, s16, gOIhw8o16i2o),
327
328     /* reference: the last line of defence */
329     REG_SR(f32, any, f32, any, fmt_order::any, spec::reference),
330     REG_SR(f32, any, s32, any, fmt_order::any, spec::reference),
331     REG_SR(f32, any, s16, any, fmt_order::any, spec::reference),
332     REG_SR(f32, any, s8, any, fmt_order::any, spec::reference),
333     REG_SR(f32, any, u8, any, fmt_order::any, spec::reference),
334
335     REG_SR(s32, any, f32, any, fmt_order::any, spec::reference),
336     REG_SR(s32, any, s32, any, fmt_order::any, spec::reference),
337     REG_SR(s32, any, s16, any, fmt_order::any, spec::reference),
338     REG_SR(s32, any, s8, any, fmt_order::any, spec::reference),
339     REG_SR(s32, any, u8, any, fmt_order::any, spec::reference),
340
341     REG_SR(s16, any, f32, any, fmt_order::any, spec::reference),
342     REG_SR(s16, any, s32, any, fmt_order::any, spec::reference),
343     REG_SR(s16, any, s16, any, fmt_order::any, spec::reference),
344
345     REG_SR(s8, any, f32, any, fmt_order::any, spec::reference),
346     REG_SR(s8, any, s32, any, fmt_order::any, spec::reference),
347     REG_SR(s8, any, s8, any, fmt_order::any, spec::reference),
348     REG_SR(s8, any, u8, any, fmt_order::any, spec::reference),
349
350     REG_SR(u8, any, f32, any, fmt_order::any, spec::reference),
351     REG_SR(u8, any, s32, any, fmt_order::any, spec::reference),
352     REG_SR(u8, any, u8, any, fmt_order::any, spec::reference),
353     REG_SR(u8, any, s8, any, fmt_order::any, spec::reference),
354
355     /* eol */
356     nullptr,
357 };
358 }
359
360 const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const {
361     return cpu_reorder_impl_list;
362 }
363
364 }
365 }
366 }