ad791e6ff72ea51ceb158d1648d77b8171573057
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_reorder.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "cpu_engine.hpp"
20 #include "cpu_memory.hpp"
21 #include "type_helpers.hpp"
22
23 #include "cpu/jit_uni_reorder.hpp"
24 #include "cpu/simple_reorder.hpp"
25 #include "cpu/wino_reorder.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f;
32
33 namespace {
34 using namespace mkldnn::impl::data_type;
35 using namespace mkldnn::impl::memory_format;
36
37 #define REG_SR(idt, ifmt, odt, ofmt, ...) \
38     simple_reorder_t<idt, ifmt, odt, ofmt, __VA_ARGS__>::pd_t::create
39
40 #define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \
41     REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \
42     REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse)
43
44 #define REG_SR_DIRECT_COPY(idt, odt) \
45     REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \
46     REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0)
47
48 static const rpd_create_f cpu_reorder_impl_list[] = {
49     /* winograd */
50     wino_reorder_t<f32, f32>::pd_t::create,
51     wino_reorder_t<f32, s8>::pd_t::create,
52
53 #ifdef __INTEL_COMPILER
54     /* direct copy for icc, which is faster than jitted code */
55     REG_SR_DIRECT_COPY(f32, f32),
56     REG_SR_DIRECT_COPY(f32, s32),
57     REG_SR_DIRECT_COPY(f32, s8),
58 //    REG_SR_DIRECT_COPY(f32, u8), FIXME: Disabled due to accuracy failure on int8 network
59     REG_SR_DIRECT_COPY(s32, f32),
60     REG_SR_DIRECT_COPY(s32, s32),
61     REG_SR_DIRECT_COPY(s32, s8),
62     REG_SR_DIRECT_COPY(s32, u8),
63     REG_SR_DIRECT_COPY(s8, f32),
64     REG_SR_DIRECT_COPY(s8, s32),
65     REG_SR_DIRECT_COPY(s8, s8),
66     REG_SR_DIRECT_COPY(s8, u8),
67     REG_SR_DIRECT_COPY(u8, f32),
68     REG_SR_DIRECT_COPY(u8, s32),
69     REG_SR_DIRECT_COPY(u8, s8),
70     REG_SR_DIRECT_COPY(u8, u8),
71 #endif
72
73     /* jit */
74     jit_uni_reorder_create,
75
76     /* fp32: flat <-> blocked with tail */
77     REG_SR_BIDIR(f32, any, f32, nChw8c),
78     REG_SR_BIDIR(f32, any, f32, nChw16c),
79     REG_SR_BIDIR(f32, any, f32, nCdhw16c),
80     REG_SR_BIDIR(f32, nChw8c, f32, nChw16c),
81
82     REG_SR_BIDIR(f32, any, f32, Oihw16o),
83     REG_SR_BIDIR(f32, any, f32, Ohwi16o),
84     REG_SR_BIDIR(f32, any, f32, Oidhw16o),
85     REG_SR_BIDIR(f32, any, f32, Odhwi16o),
86     REG_SR_BIDIR(f32, any, f32, OIhw16o16i),
87     REG_SR_BIDIR(f32, any, f32, OIhw16i16o),
88     REG_SR_BIDIR(f32, any, f32, OIdhw16o16i),
89     REG_SR_BIDIR(f32, any, f32, OIdhw16i16o),
90     REG_SR_BIDIR(f32, any, f32, IOhw16o16i),
91     REG_SR_BIDIR(f32, any, f32, gOihw16o),
92     REG_SR_BIDIR(f32, any, f32, gOhwi16o),
93     REG_SR_BIDIR(f32, any, f32, gOidhw16o),
94     REG_SR_BIDIR(f32, any, f32, gOdhwi16o),
95     REG_SR_BIDIR(f32, any, f32, gOIhw16o16i),
96     REG_SR_BIDIR(f32, any, f32, gOIhw16i16o),
97     REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i),
98     REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o),
99     REG_SR_BIDIR(f32, any, f32, gIOhw16o16i),
100
101     /* int: flat <-> blocked with tail */
102     REG_SR_BIDIR(f32, nhwc, s32, nChw16c),
103     REG_SR_BIDIR(f32, nhwc, s8, nChw16c),
104     REG_SR_BIDIR(f32, nhwc, u8, nChw16c),
105     REG_SR_BIDIR(s32, nhwc, f32, nChw16c),
106     REG_SR_BIDIR(s32, nhwc, s32, nChw16c),
107     REG_SR_BIDIR(s32, nhwc, s8, nChw16c),
108     REG_SR_BIDIR(s32, nhwc, u8, nChw16c),
109     REG_SR_BIDIR(s8, nhwc, f32, nChw16c),
110     REG_SR_BIDIR(s8, nhwc, s32, nChw16c),
111     REG_SR_BIDIR(s8, nhwc, s8, nChw16c),
112     REG_SR_BIDIR(s8, nhwc, u8, nChw16c),
113     REG_SR_BIDIR(u8, nhwc, f32, nChw16c),
114     REG_SR_BIDIR(u8, nhwc, s32, nChw16c),
115     REG_SR_BIDIR(u8, nhwc, s8, nChw16c),
116     REG_SR_BIDIR(u8, nhwc, u8, nChw16c),
117
118     REG_SR_BIDIR(f32, oihw, f32, OIhw4i16o4i),
119     REG_SR_BIDIR(f32, oihw, s8, OIhw4i16o4i),
120     REG_SR_BIDIR(s8, oihw, f32, OIhw4i16o4i),
121     REG_SR_BIDIR(s8, oihw, s8, OIhw4i16o4i),
122     REG_SR_BIDIR(f32, goihw, s8, gOIhw4i16o4i),
123     REG_SR_BIDIR(s8, goihw, f32, gOIhw4i16o4i),
124     REG_SR_BIDIR(f32, goihw, f32, gOIhw4i16o4i),
125     REG_SR_BIDIR(s8, goihw, s8, gOIhw4i16o4i),
126
127     /* s16 <-> s16 */
128     REG_SR_DIRECT_COPY(s16, s16),
129     REG_SR_BIDIR(s16, oihw, s16, OIhw8i16o2i),
130     REG_SR_BIDIR(s16, goihw, s16, gOIhw8i16o2i),
131     REG_SR_BIDIR(s16, OIhw8i16o2i, s16, OIhw8o16i2o),
132     REG_SR_BIDIR(s16, gOIhw8i16o2i, s16, gOIhw8o16i2o),
133
134     /* WA to prevent fallback on reference implementations */
135     REG_SR_DIRECT_COPY(u8, f32),
136     REG_SR_BIDIR(u8, nchw, f32, nChw8c),
137     REG_SR_BIDIR(u8, nchw, f32, nChw16c),
138
139     /* reference: the last line of defence */
140     REG_SR(f32, any, f32, any, fmt_order::any, spec::reference),
141     REG_SR(f32, any, s32, any, fmt_order::any, spec::reference),
142     REG_SR(f32, any, s16, any, fmt_order::any, spec::reference),
143     REG_SR(f32, any, s8, any, fmt_order::any, spec::reference),
144     REG_SR(f32, any, u8, any, fmt_order::any, spec::reference),
145
146     REG_SR(s32, any, f32, any, fmt_order::any, spec::reference),
147     REG_SR(s32, any, s32, any, fmt_order::any, spec::reference),
148     REG_SR(s32, any, s16, any, fmt_order::any, spec::reference),
149     REG_SR(s32, any, s8, any, fmt_order::any, spec::reference),
150     REG_SR(s32, any, u8, any, fmt_order::any, spec::reference),
151
152     REG_SR(s16, any, f32, any, fmt_order::any, spec::reference),
153     REG_SR(s16, any, s32, any, fmt_order::any, spec::reference),
154     REG_SR(s16, any, s16, any, fmt_order::any, spec::reference),
155
156     REG_SR(s8, any, f32, any, fmt_order::any, spec::reference),
157     REG_SR(s8, any, s32, any, fmt_order::any, spec::reference),
158     REG_SR(s8, any, s8, any, fmt_order::any, spec::reference),
159     REG_SR(s8, any, u8, any, fmt_order::any, spec::reference),
160
161     REG_SR(u8, any, f32, any, fmt_order::any, spec::reference),
162     REG_SR(u8, any, s32, any, fmt_order::any, spec::reference),
163     REG_SR(u8, any, u8, any, fmt_order::any, spec::reference),
164     REG_SR(u8, any, s8, any, fmt_order::any, spec::reference),
165
166     /* eol */
167     nullptr,
168 };
169 }
170
171 const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const {
172     return cpu_reorder_impl_list;
173 }
174
175 }
176 }
177 }