1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef TYPE_MAPPING_HPP
18 #define TYPE_MAPPING_HPP
20 #include "mkldnn_types.h"
// TODO: autogenerate this
// C++ names for the C API shape and stride types.
using dims_t = mkldnn_dims_t;
using strides_t = mkldnn_strides_t;

/* FIXME: infer stride_t from the corresponding C API types (presumably the
 * element type of strides_t) instead of hard-coding ptrdiff_t. */
using stride_t = ptrdiff_t;
// C++ names for the C API status values (mkldnn_status_t). Each constant is
// named after the C value it is initialized from, minus the mkldnn_ prefix.
using status_t = mkldnn_status_t;
// NOTE(review): other groups in this header are wrapped in a namespace named
// after their type (e.g. namespace memory_format); no enclosing namespace line
// is visible for these constants in this copy -- confirm it exists
// (presumably `namespace status {` ... `}`).
const status_t success = mkldnn_success;
const status_t out_of_memory = mkldnn_out_of_memory;
const status_t try_again = mkldnn_try_again;
const status_t invalid_arguments = mkldnn_invalid_arguments;
const status_t not_ready = mkldnn_not_ready;
const status_t unimplemented = mkldnn_unimplemented;
const status_t iterator_ends = mkldnn_iterator_ends;
const status_t runtime_error = mkldnn_runtime_error;
const status_t not_required = mkldnn_not_required;
// C++ names for the C API propagation kinds (mkldnn_prop_kind_t). Names drop
// the mkldnn_ prefix; `undef` maps to mkldnn_prop_kind_undef.
using prop_kind_t = mkldnn_prop_kind_t;
// NOTE(review): `undef` is also defined for alg_kind_t, data_type_t and
// query_t in this header, so these definitions only compile if each group sits
// in its own namespace. No such namespace line (presumably
// `namespace prop_kind {`) is visible in this copy -- confirm it exists.
const prop_kind_t undef = mkldnn_prop_kind_undef;
const prop_kind_t forward_training = mkldnn_forward_training;
const prop_kind_t forward_inference = mkldnn_forward_inference;
const prop_kind_t forward_scoring = mkldnn_forward_scoring;
// Generic directions; whether they share values with the specific kinds above
// is decided by the C API, not here.
const prop_kind_t forward = mkldnn_forward;
const prop_kind_t backward = mkldnn_backward;
const prop_kind_t backward_data = mkldnn_backward_data;
const prop_kind_t backward_weights = mkldnn_backward_weights;
const prop_kind_t backward_bias = mkldnn_backward_bias;
// C++ names for the C API algorithm kinds (mkldnn_alg_kind_t). Names drop the
// mkldnn_ prefix; `undef` maps to mkldnn_alg_kind_undef.
using alg_kind_t = mkldnn_alg_kind_t;
// NOTE(review): `undef` is defined for several types in this header; the
// enclosing namespace line (presumably `namespace alg_kind {`) is not visible
// in this copy -- confirm it exists.
const alg_kind_t undef = mkldnn_alg_kind_undef;
// convolution_* / deconvolution_* algorithms
const alg_kind_t convolution_direct = mkldnn_convolution_direct;
const alg_kind_t convolution_winograd = mkldnn_convolution_winograd;
const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct;
const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd;
// eltwise_* algorithms
const alg_kind_t eltwise_relu = mkldnn_eltwise_relu;
const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh;
const alg_kind_t eltwise_elu = mkldnn_eltwise_elu;
const alg_kind_t eltwise_square = mkldnn_eltwise_square;
const alg_kind_t eltwise_abs = mkldnn_eltwise_abs;
const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt;
const alg_kind_t eltwise_linear = mkldnn_eltwise_linear;
const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu;
const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu;
const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic;
const alg_kind_t eltwise_clamp = mkldnn_eltwise_clamp;
// depthwise_* algorithms
const alg_kind_t depthwise_scale_shift = mkldnn_depthwise_scale_shift;
const alg_kind_t depthwise_prelu = mkldnn_depthwise_prelu;
// pooling_* algorithms
const alg_kind_t pooling_max = mkldnn_pooling_max;
const alg_kind_t pooling_avg = mkldnn_pooling_avg;
const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding;
const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding;
// lrn_* algorithms
const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels;
const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel;
// recurrent cell algorithms (vanilla_*, gru_*)
const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn;
const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm;
const alg_kind_t vanilla_gru = mkldnn_vanilla_gru;
const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset;
// roi_pooling_* algorithms
const alg_kind_t roi_pooling_max = mkldnn_roi_pooling_max;
const alg_kind_t roi_pooling_bilinear = mkldnn_roi_pooling_bilinear;
// C++ names for the C API data types (mkldnn_data_type_t). Names drop the
// mkldnn_ prefix; `undef` maps to mkldnn_data_type_undef.
using data_type_t = mkldnn_data_type_t;
// NOTE(review): `undef` is defined for several types in this header; the
// enclosing namespace line (presumably `namespace data_type {`) is not visible
// in this copy -- confirm it exists.
const data_type_t undef = mkldnn_data_type_undef;
const data_type_t f32 = mkldnn_f32;
const data_type_t s32 = mkldnn_s32;
const data_type_t s16 = mkldnn_s16;
const data_type_t s8 = mkldnn_s8;
const data_type_t u8 = mkldnn_u8;
// C++ names for the C API rounding modes (mkldnn_round_mode_t); each name drops
// the mkldnn_round_ prefix.
// NOTE(review): the closing brace of namespace round_mode is not visible in
// this copy before the next alias -- confirm it exists.
using round_mode_t = mkldnn_round_mode_t;
namespace round_mode {
const round_mode_t nearest = mkldnn_round_nearest;
const round_mode_t down = mkldnn_round_down;
// C++ names for the C API memory format tags (mkldnn_memory_format_t). Names
// drop the mkldnn_ prefix; `undef` maps to mkldnn_format_undef.
// Presumably (per MKL-DNN naming) lower-case letters name dimensions in
// memory order, and an upper-case letter plus a number after it (e.g. the
// "C" and "16c" in nChw16c) denotes a blocked dimension -- see the C API
// documentation for the authoritative definitions.
// NOTE(review): the closing brace of namespace memory_format is not visible in
// this copy before the next alias -- confirm it exists.
using memory_format_t = mkldnn_memory_format_t;
namespace memory_format {
const memory_format_t undef = mkldnn_format_undef;
const memory_format_t any = mkldnn_any;
const memory_format_t blocked = mkldnn_blocked;
// Data tags (presumably n = minibatch, c = channels, d/h/w = spatial).
const memory_format_t x = mkldnn_x;
const memory_format_t nc = mkldnn_nc;
const memory_format_t nchw = mkldnn_nchw;
const memory_format_t nhwc = mkldnn_nhwc;
const memory_format_t chwn = mkldnn_chwn;
const memory_format_t nChw8c = mkldnn_nChw8c;
const memory_format_t nChw16c = mkldnn_nChw16c;
const memory_format_t ncdhw = mkldnn_ncdhw;
const memory_format_t ndhwc = mkldnn_ndhwc;
const memory_format_t nCdhw16c = mkldnn_nCdhw16c;
// Weights tags (presumably o = output channels, i = input channels).
const memory_format_t oi = mkldnn_oi;
const memory_format_t io = mkldnn_io;
const memory_format_t oihw = mkldnn_oihw;
const memory_format_t ihwo = mkldnn_ihwo;
const memory_format_t hwio = mkldnn_hwio;
const memory_format_t dhwio = mkldnn_dhwio;
const memory_format_t oidhw = mkldnn_oidhw;
const memory_format_t OIdhw16i16o = mkldnn_OIdhw16i16o;
const memory_format_t OIdhw16o16i = mkldnn_OIdhw16o16i;
const memory_format_t Oidhw16o = mkldnn_Oidhw16o;
const memory_format_t Odhwi16o = mkldnn_Odhwi16o;
const memory_format_t oIhw8i = mkldnn_oIhw8i;
const memory_format_t oIhw16i = mkldnn_oIhw16i;
const memory_format_t OIhw8i8o = mkldnn_OIhw8i8o;
const memory_format_t OIhw16i16o = mkldnn_OIhw16i16o;
const memory_format_t OIhw4i16o4i = mkldnn_OIhw4i16o4i;
const memory_format_t OIhw8i16o2i = mkldnn_OIhw8i16o2i;
const memory_format_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i;
const memory_format_t OIhw8o16i2o = mkldnn_OIhw8o16i2o;
const memory_format_t OIhw8o8i = mkldnn_OIhw8o8i;
const memory_format_t OIhw16o16i = mkldnn_OIhw16o16i;
const memory_format_t IOhw16o16i = mkldnn_IOhw16o16i;
const memory_format_t Oihw16o = mkldnn_Oihw16o;
const memory_format_t Ohwi8o = mkldnn_Ohwi8o;
const memory_format_t Ohwi16o = mkldnn_Ohwi16o;
// Weights tags with a leading g / G (presumably the groups dimension).
const memory_format_t goihw = mkldnn_goihw;
const memory_format_t hwigo = mkldnn_hwigo;
const memory_format_t gOIhw8i8o = mkldnn_gOIhw8i8o;
const memory_format_t gOIhw16i16o = mkldnn_gOIhw16i16o;
const memory_format_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i;
const memory_format_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i;
const memory_format_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i;
const memory_format_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o;
const memory_format_t gOIhw8o8i = mkldnn_gOIhw8o8i;
const memory_format_t gOIhw16o16i = mkldnn_gOIhw16o16i;
const memory_format_t gIOhw16o16i = mkldnn_gIOhw16o16i;
const memory_format_t gOihw16o = mkldnn_gOihw16o;
const memory_format_t gOhwi8o = mkldnn_gOhwi8o;
const memory_format_t gOhwi16o = mkldnn_gOhwi16o;
const memory_format_t Goihw8g = mkldnn_Goihw8g;
const memory_format_t Goihw16g = mkldnn_Goihw16g;
const memory_format_t goidhw = mkldnn_goidhw;
const memory_format_t gOIdhw16i16o = mkldnn_gOIdhw16i16o;
const memory_format_t gOIdhw16o16i = mkldnn_gOIdhw16o16i;
const memory_format_t gOidhw16o = mkldnn_gOidhw16o;
const memory_format_t gOdhwi16o = mkldnn_gOdhwi16o;
// Recurrent tags (presumably t = time, l = layer, d = direction, s = state,
// g = gate; the _p variants are a separate C API tag).
const memory_format_t ntc = mkldnn_ntc;
const memory_format_t tnc = mkldnn_tnc;
const memory_format_t ldsnc = mkldnn_ldsnc;
const memory_format_t ldigo = mkldnn_ldigo;
const memory_format_t ldigo_p = mkldnn_ldigo_p;
const memory_format_t ldgoi = mkldnn_ldgoi;
const memory_format_t ldgoi_p = mkldnn_ldgoi_p;
const memory_format_t ldgo = mkldnn_ldgo;
// Winograd-specific tag (presumably paired with convolution_winograd).
const memory_format_t wino_fmt = mkldnn_wino_fmt;
// C++ name for the C API padding kind (mkldnn_padding_kind_t); only
// mkldnn_padding_zero is mapped here.
// NOTE(review): the closing brace of namespace padding_kind is not visible in
// this copy before the next alias -- confirm it exists.
using padding_kind_t = mkldnn_padding_kind_t;
namespace padding_kind {
const padding_kind_t padding_zero = mkldnn_padding_zero;
// C++ names for the C API engine kinds (mkldnn_engine_kind_t); each name drops
// the mkldnn_ prefix.
// NOTE(review): the closing brace of namespace engine_kind is not visible in
// this copy before the next alias -- confirm it exists.
using engine_kind_t = mkldnn_engine_kind_t;
namespace engine_kind {
const engine_kind_t any_engine = mkldnn_any_engine;
const engine_kind_t cpu = mkldnn_cpu;
// C++ names for the C API primitive kinds (mkldnn_primitive_kind_t). Names drop
// the mkldnn_ prefix; `undefined` maps to mkldnn_undefined_primitive.
// NOTE(review): the closing brace of namespace primitive_kind is not visible in
// this copy before the next alias -- confirm it exists.
using primitive_kind_t = mkldnn_primitive_kind_t;
namespace primitive_kind {
const primitive_kind_t undefined = mkldnn_undefined_primitive;
const primitive_kind_t memory = mkldnn_memory;
const primitive_kind_t view = mkldnn_view;
const primitive_kind_t reorder = mkldnn_reorder;
const primitive_kind_t concat = mkldnn_concat;
const primitive_kind_t concat_inplace = mkldnn_concat_inplace;
const primitive_kind_t sum = mkldnn_sum;
const primitive_kind_t convolution = mkldnn_convolution;
const primitive_kind_t deconvolution = mkldnn_deconvolution;
const primitive_kind_t eltwise = mkldnn_eltwise;
const primitive_kind_t depthwise = mkldnn_depthwise;
const primitive_kind_t softmax = mkldnn_softmax;
const primitive_kind_t pooling = mkldnn_pooling;
const primitive_kind_t lrn = mkldnn_lrn;
const primitive_kind_t batch_normalization = mkldnn_batch_normalization;
const primitive_kind_t inner_product = mkldnn_inner_product;
const primitive_kind_t convolution_relu = mkldnn_convolution_relu;
const primitive_kind_t rnn = mkldnn_rnn;
const primitive_kind_t roi_pooling = mkldnn_roi_pooling;
// C++ names for the C API query identifiers (mkldnn_query_t); each name drops
// the mkldnn_query_ prefix.
using query_t = mkldnn_query_t;
// NOTE(review): `undef` is defined for several types in this header, and
// `primitive_kind` below would clash with `namespace primitive_kind`, so these
// constants only compile inside their own namespace. That namespace line
// (presumably `namespace query {`) is not visible in this copy -- confirm.
const query_t undef = mkldnn_query_undef;

const query_t engine = mkldnn_query_engine;
const query_t primitive_kind = mkldnn_query_primitive_kind;

// The _s32 / _f64 / _s64 / _str suffixes presumably name the type of the
// queried value -- confirm against the C API documentation.
const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32;
const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32;

const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64;
const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64;

const query_t impl_info_str = mkldnn_query_impl_info_str;

// *_d queries (presumably yield operation descriptors).
const query_t some_d = mkldnn_query_some_d;
const query_t memory_d = mkldnn_query_memory_d;
const query_t convolution_d = mkldnn_query_convolution_d;
const query_t deconvolution_d = mkldnn_query_deconvolution_d;
const query_t eltwise_d = mkldnn_query_eltwise_d;
const query_t depthwise_d = mkldnn_query_depthwise_d;
const query_t softmax_d = mkldnn_query_softmax_d;
const query_t pooling_d = mkldnn_query_pooling_d;
const query_t lrn_d = mkldnn_query_lrn_d;
const query_t batch_normalization_d = mkldnn_query_batch_normalization_d;
const query_t inner_product_d = mkldnn_query_inner_product_d;
const query_t convolution_relu_d = mkldnn_query_convolution_relu_d;
const query_t rnn_d = mkldnn_query_rnn_d;
const query_t roi_pooling_d = mkldnn_query_roi_pooling_d;

// *_pd queries (presumably yield primitive descriptors).
const query_t some_pd = mkldnn_query_some_pd;
const query_t input_pd = mkldnn_query_input_pd;
const query_t output_pd = mkldnn_query_output_pd;
const query_t src_pd = mkldnn_query_src_pd;
const query_t diff_src_pd = mkldnn_query_diff_src_pd;
const query_t weights_pd = mkldnn_query_weights_pd;
const query_t diff_weights_pd = mkldnn_query_diff_weights_pd;
const query_t dst_pd = mkldnn_query_dst_pd;
const query_t diff_dst_pd = mkldnn_query_diff_dst_pd;

const query_t workspace_pd = mkldnn_query_workspace_pd;
259 using blocking_desc_t = mkldnn_blocking_desc_t;
260 using wino_data_t = mkldnn_wino_desc_t;
261 using memory_desc_t = mkldnn_memory_desc_t;
262 using convolution_desc_t = mkldnn_convolution_desc_t;
263 using deconvolution_desc_t = mkldnn_deconvolution_desc_t;
264 using pooling_desc_t = mkldnn_pooling_desc_t;
265 using eltwise_desc_t = mkldnn_eltwise_desc_t;
266 using softmax_desc_t = mkldnn_softmax_desc_t;
267 using lrn_desc_t = mkldnn_lrn_desc_t;
268 using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t;
269 using inner_product_desc_t = mkldnn_inner_product_desc_t;
270 using convolution_relu_desc_t = mkldnn_convolution_relu_desc_t;
271 using roi_pooling_desc_t = mkldnn_roi_pooling_desc_t;
272 using depthwise_desc_t = mkldnn_depthwise_desc_t;
274 using rnn_direction_t = mkldnn_rnn_direction_t;
275 using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t;
276 using rnn_desc_t = mkldnn_rnn_desc_t;
278 /* C op_desc_t, which eventually are just (void*) */
279 using c_op_desc_t = mkldnn_op_desc_t;
280 using const_c_op_desc_t = const_mkldnn_op_desc_t;
// Body of op_desc_t: a `kind` member plus one member per C operation-descriptor
// type. NOTE(review): the opening line of this definition (presumably
// `union op_desc_t {`) and its closing `};` are not visible in this copy --
// confirm. The convert_from_c casts below only make sense if every member
// starts at offset 0, as in a union.
primitive_kind_t kind;
memory_desc_t memory;
convolution_desc_t convolution;
deconvolution_desc_t deconvolution;
pooling_desc_t pooling;
eltwise_desc_t eltwise;
softmax_desc_t softmax;
batch_normalization_desc_t batch_normalization;
inner_product_desc_t inner_product;
convolution_relu_desc_t convolution_relu;
roi_pooling_desc_t roi_pooling;
depthwise_desc_t depthwise;
// NOTE(review): DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn) below initializes a
// member `lrn` that is not in the member list visible here, and rnn_desc_t is
// aliased in this header but has no member or converter here -- confirm both
// against the full file.

// Constructs an op_desc_t by initializing only its `kind` member.
op_desc_t(const primitive_kind_t &_): kind(_) {}

// For a C descriptor type `c_type` stored in member `name`, declares:
//  - a non-explicit constructor that copy-initializes `name` from a c_type;
//  - convert_from_c overloads that reinterpret a (const) c_type pointer as a
//    (const) op_desc_t pointer without copying. Callers presumably access only
//    the member matching c_type through the result.
# define DECL_CTOR_AND_CONVERTERS(c_type, name) \
op_desc_t(const c_type &_): name(_) {} \
static op_desc_t *convert_from_c(c_type *_) \
{ return reinterpret_cast<op_desc_t*>(_); } \
static const op_desc_t *convert_from_c(const c_type *_) \
{ return reinterpret_cast<const op_desc_t*>(_); }

// NOTE(review): there is a `deconvolution` member but no converter for
// deconvolution_desc_t. If that type is the same as convolution_desc_t in the
// C API, a second overload would be a redefinition, which may explain the
// omission -- confirm.
DECL_CTOR_AND_CONVERTERS(memory_desc_t, memory);
DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution);
DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling);
DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise);
DECL_CTOR_AND_CONVERTERS(depthwise_desc_t, depthwise);
DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax);
DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn);
DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization);
DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product);
DECL_CTOR_AND_CONVERTERS(convolution_relu_desc_t, convolution_relu);
DECL_CTOR_AND_CONVERTERS(roi_pooling_desc_t, roi_pooling);

// The helper macro is local to this definition.
# undef DECL_CTOR_AND_CONVERTERS
323 using engine_t = mkldnn_engine;
324 using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator;
325 using primitive_desc_t = mkldnn_primitive_desc;
326 using primitive_attr_t = mkldnn_primitive_attr;
327 using post_ops_t = mkldnn_post_ops;
328 using primitive_t = mkldnn_primitive;
329 using primitive_at_t = mkldnn_primitive_at_t;
// C++ names for the C API stream kinds (mkldnn_stream_kind_t), plus the stream
// object type. Each name drops the mkldnn_ prefix.
using stream_kind_t = mkldnn_stream_kind_t;
namespace stream_kind {
const stream_kind_t any_stream = mkldnn_any_stream;
const stream_kind_t eager = mkldnn_eager;
const stream_kind_t lazy = mkldnn_lazy;
// NOTE(review): no closing brace for namespace stream_kind is visible before
// stream_t in this copy -- confirm stream_t is meant to live outside it.
using stream_t = mkldnn_stream;
339 /* forward declaration of internal primitive_desc types */
351 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s