1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef MKLDNN_TYPES_H
18 #define MKLDNN_TYPES_H
24 #ifndef DOXYGEN_SHOULD_SKIP_THIS
29 /** @addtogroup c_api C API
32 * @addtogroup c_api_types Types
35 * @addtogroup c_api_types_generic Generic
38 /** Intel(R) MKL-DNN Version type */
46 /** Status values returned by Intel(R) MKL-DNN functions. */
48 /** The operation was successful */
50 /** The operation failed due to an out-of-memory condition */
51 mkldnn_out_of_memory = 1,
52 /** The operation failed and should be retried */
54 /** The operation failed because of incorrect function arguments */
55 mkldnn_invalid_arguments = 3,
56 /** The operation failed because a primitive was not ready for execution */
58 /** The operation failed because requested functionality is not implemented
60 mkldnn_unimplemented = 5,
61 /** Primitive iterator passed over last primitive descriptor */
62 mkldnn_iterator_ends = 6,
63 /** Primitive or engine failed on execution */
64 mkldnn_runtime_error = 7,
65 /** Queried element is not required for given primitive */
66 mkldnn_not_required = 8,
69 /** Data type specification */
71 /** Undefined data type, used for empty memory descriptors. */
72 mkldnn_data_type_undef = 0,
73 /** 32-bit/single-precision floating point. */
75 /** 32-bit signed integer. */
77 /** 16-bit signed integer. */
79 /** 8-bit signed integer. */
81 /** 8-bit unsigned integer. */
90 mkldnn_round_nearest = 1,
92 mkldnn_round_down = 2,
93 } mkldnn_round_mode_t;
95 /** Memory format specification.
97 * Intel MKL-DNN formats describe physical data layout. The physical layout
98 * is described as a sequence of the dimensions as they are laid out in the
99 * memory (from the outer-most to the inner-most). Note that this order
100 * doesn't affect the logical order of the dimensions that is kept in the
101 * `dims` field of the mkldnn_memory_desc_t structure. The logical order of the
102 * dimensions is specified by the type of tensor.
104 * For example, CNN 5D tensor always has its logical dimensions in the order
105 * `(batch, channels, depth, height, width)`, while the physical layout might be
106 * #mkldnn_ncdhw or #mkldnn_ndhwc:
109 * int batch = 2, channels = 16, depth = 13, height = 13, width = 13;
111 * int ndims = 5; // 5D tensor
112 * mkldnn_dims_t dims = {batch, channels, depth, height, width};
114 * mkldnn_memory_desc_t data_in_ncdhw;
115 * mkldnn_memory_desc_init(&data_in_ncdhw, 5, dims, mkldnn_ncdhw);
117 * // note that in both cases dims passed are the same
118 * mkldnn_memory_desc_t data_in_ndhwc;
119 * mkldnn_memory_desc_init(&data_in_ndhwc, 5, dims, mkldnn_ndhwc);
122 * The following notation applies to memory format names:
123 * - @c 'n' denotes the mini-batch dimension
124 * - @c 'c' denotes a channels dimension
125 * - When there are multiple channel dimensions (for example, in convolution
126 * weights tensor), @c 'i' and @c 'o' denote dimensions of input and output
128 * - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
130 * - Upper-case letters indicate that the data is laid out in blocks
131 * for a particular dimension. In such cases, the format name contains both
132 * upper- and lower-case letters for that dimension with a lower-case letter
133 * preceded by the block size. For example: @c 'mkldnn_nChw8c' describes a
134 * format where the outermost dimension is mini-batch, followed by the
135 * channel block number, followed by the spatial height and width, and
136 * finally followed by 8-element channel blocks.
139 * Channel designations can be different. For example, both the @c
140 * 'mkldnn_nc' and @c 'mkldnn_io' formats can be used to describe a 2D
143 * @sa @ref understanding_memory_formats
146 /** Undefined memory format, used for empty memory descriptors. */
147 mkldnn_format_undef = 0,
148 /** Unspecified format. The primitive selects a format
151 /** A tensor in a generic format described by the stride and blocking
152 * values in each dimension. See #mkldnn_blocking_desc_t for more
155 /** 1D data tensor. */
157 /** 2D data tensor. */
159 /** 3D data tensor with the physical layout @c ncw.
160 * Logical dimensions come in the order: (n, c, w) */
162 /** 3D data tensor with the physical layout @c nwc.
163 * Logical dimensions come in the order: (n, c, w) */
165 /** 4D data tensor with the physical layout @c nchw, used in Caffe.
166 * Logical dimensions come in the order: (n, c, h, w) */
168 /** 4D data tensor with the physical layout @c nhwc, used in TensorFlow.
169 * Logical dimensions come in the order: (n, c, h, w) */
171 /** 4D data tensor with the physical layout @c chwn, used in Neon.
172 * Logical dimensions come in the order: (n, c, h, w) */
174 /** 5D data tensor with the physical layout @c ncdhw.
175 * Logical dimensions come in the order: (n, c, d, h, w) */
177 /** 5D data tensor with the physical layout @c ndhwc, used in TensorFlow.
178 * Logical dimensions come in the order: (n, c, d, h, w) */
180 /** 2D weights tensor with physical layout @c oi.
181 * Logical dimensions come in the order: (o, i) */
183 /** 2D weights tensor with physical layout @c io.
184 * Logical dimensions come in the order: (o, i) */
186 /** 3D weights tensor with physical layout @c oiw.
187 * Logical dimensions come in the order: (o, i, w) */
189 /** 3D weights tensor with physical layout @c wio.
190 * Logical dimensions come in the order: (o, i, w) */
192 /** 4D weights tensor with physical layout @c oihw, used in Caffe.
193 * Logical dimensions come in the order: (o, i, h, w) */
195 /** 4D weights tensor with physical layout @c hwio, used in TensorFlow.
196 * Logical dimensions come in the order: (o, i, h, w) */
198 /** 4D weights tensor with physical layout @c ihwo.
199 * Logical dimensions come in the order: (o, i, h, w) */
201 /** 4D weights tensor with physical layout @c iohw.
202 * Logical dimensions come in the order: (o, i, h, w) */
204 /** 5D weights tensor with physical layout @c oidhw, used in Caffe.
205 * Logical dimensions come in the order: (o, i, d, h, w) */
207 /** 5D weights tensor with physical layout @c dhwio, used in TensorFlow.
208 * Logical dimensions come in the order: (o, i, d, h, w) */
210 /** 4D grouped weights tensor with the physical layout @c goiw.
211 * Logical dimensions come in the order: (g, o, i, w) */
213 /** 5D grouped weights tensor with the physical layout @c goihw,
215 * Logical dimensions come in the order: (g, o, i, h, w) */
217 /** 5D grouped weights tensor with the physical layout @c hwigo,
218 * used in TensorFlow.
219 * Logical dimensions come in the order: (g, o, i, h, w) */
221 /** 5D grouped weights tensor with the physical layout @c giohw.
222 * Logical dimensions come in the order: (g, o, i, h, w) */
224 /** 6D grouped weights tensor with the physical layout @c goidhw,
226 * Logical dimensions come in the order: (g, o, i, d, h, w) */
228 /** 3D RNN data tensor in the format (batch, seq_length, input channels). */
230 /** 3D RNN data tensor in the format (seq_length, batch, input channels). */
232 /** 5D RNN states tensor in the format (num_layers, num_directions,
233 * num_states, batch, state channels). */
235 /** 5D RNN weights tensor in the format (num_layers, num_directions,
236 * input_channels, num_gates, output_channels).
238 * - For LSTM cells, the gates order is input, forget, candidate
240 * - For GRU cells, the gates order is update, reset and output gate. */
242 /** 5D RNN weights tensor in the format (num_layers, num_directions,
243 * num_gates, output_channels, input_channels).
245 * - For LSTM cells, the gates order is input, forget, candidate
247 * - For GRU cells, the gates order is update, reset and output gate. */
249 /** 4D RNN bias tensor in the format (num_layers, num_directions,
250 * num_gates, output_channels).
252 * - For LSTM cells, the gates order is input, forget, candidate
254 * - For GRU cells, the gates order is update, reset and output gate. */
257 /* Opaque data types, are not to be used explicitly */
260 mkldnn_nCw4c /** blocked data format */,
261 mkldnn_nCw8c /** blocked data format */,
262 mkldnn_nCw16c /** blocked data format */,
263 mkldnn_nChw4c /** blocked data format */,
264 mkldnn_nChw8c /** blocked data format */,
265 mkldnn_nChw16c /** blocked data format */,
266 mkldnn_nCdhw4c /** blocked data format */,
267 mkldnn_nCdhw8c /** blocked data format */,
268 mkldnn_nCdhw16c /** blocked data format */,
271 mkldnn_Owi4o /** blocked weights format */,
272 mkldnn_OIw4i4o /** blocked weights format */,
273 mkldnn_Owi8o /** blocked weights format */,
274 mkldnn_OIw8i8o /** blocked weights format */,
275 mkldnn_OIw8o8i /** blocked weights format */,
276 mkldnn_OIw16i16o /** blocked weights format */,
277 mkldnn_OIw16o16i /** blocked weights format */,
278 mkldnn_Oiw4o /** blocked weights format */,
279 mkldnn_Oiw16o /** blocked weights format */,
280 mkldnn_Owi16o /** blocked weights format */,
281 mkldnn_OIw8i16o2i /** blocked weights format */,
282 mkldnn_OIw8o16i2o /** blocked weights format */,
283 mkldnn_IOw16o16i /** blocked weights format */,
286 /** weights format with additional buffer
287 * size equal to the number of output channels
288 * and containing the values:
289 * O[i:0,OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
291 mkldnn_oIhw8i /** blocked weights format */,
292 mkldnn_oIhw16i /** blocked weights format */,
293 mkldnn_OIhw4i4o /** blocked weights format */,
294 mkldnn_OIhw8i8o /** blocked weights format */,
295 mkldnn_OIhw16i16o /** blocked weights format */,
296 mkldnn_OIhw4i16o4i /** blocked weights format */,
297 /** blocked weights format with additional buffer
298 * with size equal to the number of output channels
299 * and containing the values:
300 * O[i:0,OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
301 mkldnn_OIhw4i16o4i_s8s8,
302 mkldnn_OIhw8i16o2i /** blocked weights format */,
303 mkldnn_OIhw8o16i2o /** blocked weights format */,
304 mkldnn_OIhw8o8i /** blocked weights format */,
305 mkldnn_OIhw16o16i /** blocked weights format */,
306 mkldnn_IOhw16o16i /** blocked weights format */,
307 mkldnn_Oihw8o /** blocked weights format */,
308 mkldnn_Oihw4o /** blocked weights format */,
309 mkldnn_Oihw16o /** blocked weights format */,
310 mkldnn_Ohwi8o /** blocked weights format */,
311 mkldnn_Ohwi4o /** blocked weights format */,
312 mkldnn_Ohwi16o /** blocked weights format */,
313 mkldnn_OhIw16o4i /** blocked weights format */,
314 mkldnn_OhIw8o4i /** blocked weights format */,
315 /** blocked weights format with additional buffer
316 * with size equal to the number of output channels
317 * and containing the values:
318 * O[i:0,OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
319 mkldnn_OhIw8o4i_s8s8,
320 mkldnn_OhIw8o32i /** blocked weights format */,
321 mkldnn_OhIw16o32i /** blocked weights format */,
324 mkldnn_oIdhw8i /** blocked weights format */,
325 mkldnn_oIdhw16i /** blocked weights format */,
326 mkldnn_OIdhw4i4o /** blocked weights format */,
327 mkldnn_Odhwi4o /** blocked weights format */,
328 mkldnn_OIdhw8i8o /** blocked weights format */,
329 mkldnn_OIdhw8o8i /** blocked weights format */,
330 mkldnn_Odhwi8o /** blocked weights format */,
331 mkldnn_OIdhw16i16o /** blocked weights format */,
332 mkldnn_OIdhw16o16i /** blocked weights format */,
333 mkldnn_Oidhw4o /** blocked weights format */,
334 mkldnn_Oidhw16o /** blocked weights format */,
335 mkldnn_Odhwi16o /** blocked weights format */,
336 mkldnn_OIdhw8i16o2i /** blocked weights format */,
338 /* weights w/ groups, 4D */
339 mkldnn_gOwi4o /** blocked weights format */,
340 mkldnn_gOIw4i4o /** blocked weights format */,
341 mkldnn_gOwi8o /** blocked weights format */,
342 mkldnn_gOIw8o8i /** blocked weights format */,
343 mkldnn_gOIw8i8o /** blocked weights format */,
344 mkldnn_gOIw16i16o /** blocked weights format */,
345 mkldnn_gOIw16o16i /** blocked weights format */,
346 mkldnn_gOiw4o /** blocked weights format */,
347 mkldnn_gOiw16o /** blocked weights format */,
348 mkldnn_gOwi16o /** blocked weights format */,
349 mkldnn_gOIw8i16o2i /** blocked weights format */,
350 mkldnn_gOIw8o16i2o /** blocked weights format */,
351 mkldnn_gIOw16o16i /** blocked weights format */,
353 /* weights w/ groups, 5D */
354 /** weights format with additional buffer
355 * size equal to the number of output channels
356 * multiplied by number of groups and containing the values:
357 * O[i:0,G*OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
359 mkldnn_gOIhw4i4o /** blocked weights format */,
360 mkldnn_gOIhw8i8o /** blocked weights format */,
361 mkldnn_gOIhw16i16o /** blocked weights format */,
362 mkldnn_gOIhw4i16o4i /** blocked weights format */,
363 /** blocked weights format with additional buffer
364 * with size equal to the number of output channels
365 * multiplied by number of groups and containing the values:
366 * O[i:0,G*OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
367 mkldnn_gOIhw4i16o4i_s8s8,
368 mkldnn_gOIhw2i8o4i /** blocked weights format */,
369 /** blocked weights format with additional buffer
370 * with size equal to the number of output channels
371 * multiplied by number of groups and containing the values:
372 * O[i:0,G*OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
373 mkldnn_gOIhw2i8o4i_s8s8,
374 mkldnn_gOIhw8i16o2i /** blocked weights format */,
375 mkldnn_gOIhw8o16i2o /** blocked weights format */,
376 mkldnn_gOIhw4o4i /** blocked weights format */,
377 /** blocked weights format with additional buffer
378 * with size equal to the number of output channels
379 * and containing the values:
380 * O[i:0,OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
381 mkldnn_gOIhw4o4i_s8s8 /** blocked weights format */,
382 mkldnn_gOIhw8o8i /** blocked weights format */,
383 mkldnn_gOIhw16o16i /** blocked weights format */,
384 mkldnn_gIOhw16o16i /** blocked weights format */,
385 mkldnn_gOihw8o /** blocked weights format */,
386 mkldnn_gOihw4o /** blocked weights format */,
387 mkldnn_gOihw16o /** blocked weights format */,
388 mkldnn_gOhwi8o /** blocked weights format */,
389 mkldnn_gOhwi4o /** blocked weights format */,
390 mkldnn_gOhwi16o /** blocked weights format */,
391 mkldnn_Goihw8g /** blocked weights format */,
392 mkldnn_Goihw16g /** blocked weights format */,
393 /** blocked weights format with additional buffer
394 * with size equal to the number of groups and containing the values:
395 * O[i:0,G] = -128 * SUM(h:0,H;w:0,W)(weights(i,i,h,w))*/
396 mkldnn_Goihw16g_s8s8,
397 mkldnn_gOhIw16o4i /** blocked weights format */,
398 mkldnn_gOhIw8o4i /** blocked weights format */,
399 /** blocked weights format with additional buffer
400 * with size equal to the number of output channels
401 * multiplied by number of groups and containing the values:
402 * O[i:0,G*OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
403 mkldnn_gOhIw8o4i_s8s8,
405 /* weights w/ groups, 6D */
406 mkldnn_gOIdhw4i4o /** blocked weights format */,
407 mkldnn_gOdhwi4o /** blocked weights format */,
408 mkldnn_gOIdhw8i8o /** blocked weights format */,
409 mkldnn_gOIdhw8o8i /** blocked weights format */,
410 mkldnn_gOdhwi8o /** blocked weights format */,
411 mkldnn_gOIdhw8i16o2i /** blocked weights format */,
412 mkldnn_gOIdhw16i16o /** blocked weights format */,
413 mkldnn_gOIdhw16o16i /** blocked weights format */,
414 mkldnn_gOidhw4o /** blocked weights format */,
415 mkldnn_gOidhw16o /** blocked weights format */,
416 mkldnn_gOdhwi16o /** blocked weights format */,
418 mkldnn_wino_fmt /** Weights format used in 8bit Winograd convolution */,
420 mkldnn_rnn_packed /** Packed weights format used in RNN */,
422 /** Just a sentinel, not real memory format. Must be changed after new
423 * format is added. */
425 } mkldnn_memory_format_t;
427 /** Kinds of padding. Define how to interpret the data in padding regions. */
429 /** The data in padding regions is zero. */
431 } mkldnn_padding_kind_t;
433 /** Kinds of propagation. */
435 /* TODO: suggest renames */
436 /** Undefined propagation type. */
437 mkldnn_prop_kind_undef = 0,
438 /** Forward data propagation (training mode). In this mode primitives
439 * perform computations necessary for subsequent backward propagation. */
440 mkldnn_forward_training = 64,
441 /** Forward data propagation (inference mode). In this mode primitives
442 * perform only computations that are necessary for inference and omit
443 * computations that are necessary only for backward propagation. */
444 mkldnn_forward_inference = 96,
445 /** Forward data propagation (alias for @c mkldnn_forward_inference) */
446 mkldnn_forward_scoring = mkldnn_forward_inference,
447 /** Forward data propagation (alias for @c mkldnn_forward_training) */
448 mkldnn_forward = mkldnn_forward_training,
449 /** Backward propagation (with respect to all parameters) */
450 mkldnn_backward = 128,
451 /** Backward data propagation */
452 mkldnn_backward_data = 160,
453 /** Backward weights propagation */
454 mkldnn_backward_weights = 192,
455 /** Backward bias propagation */
456 mkldnn_backward_bias = 193,
457 } mkldnn_prop_kind_t;
459 /** Kinds of primitives. Used to implement a way to extend the library with new
460 * primitives without changing the ABI. */
462 /** Undefined primitive (XXX: why do we have it?). */
463 mkldnn_undefined_primitive,
464 /** A memory primitive. */
466 /** A view primitive. */
468 /** A reorder primitive.*/
470 /** A shuffle primitive.*/
472 /** A (out-of-place) concat primitive. */
474 /** A (in-place) concat primitive. */
475 mkldnn_concat_inplace,
476 /** A sum primitive. */
478 /** A convolution primitive. */
480 /** A deconvolution primitive. */
481 mkldnn_deconvolution,
482 /** An element-wise primitive. */
484 /** A Softmax primitive. */
486 /** A pooling primitive. */
488 /** An LRN primitive. */
490 /** A batch normalization primitive. */
491 mkldnn_batch_normalization,
492 /** An inner product primitive. */
493 mkldnn_inner_product,
494 /** An RNN primitive. */
496 /** A ROI pooling primitive. */
498 /** A channel-wise primitive. */
500 /** A binary convolution primitive. */
501 mkldnn_binary_convolution,
502 /** A binarization primitive. */
504 } mkldnn_primitive_kind_t;
506 /** Kinds of algorithms. */
508 mkldnn_alg_kind_undef,
509 /** Direct convolution */
510 mkldnn_convolution_direct = 0x1,
511 /** Winograd convolution */
512 mkldnn_convolution_winograd = 0x2,
513 /** Convolution algorithm (either direct or Winograd) is chosen just in time */
514 mkldnn_convolution_auto = 0x3,
515 /** Direct deconvolution */
516 mkldnn_deconvolution_direct = 0xa,
517 /** Winograd deconvolution */
518 mkldnn_deconvolution_winograd = 0xb,
520 mkldnn_eltwise_relu = 0x1f,
521 /** Eltwise: hyperbolic tangent non-linearity (tanh) */
522 mkldnn_eltwise_tanh = 0x2f,
523 /** Eltwise: parametric exponential linear unit (elu) */
524 mkldnn_eltwise_elu = 0x3f,
525 /** Eltwise: square */
526 mkldnn_eltwise_square = 0x4f,
528 mkldnn_eltwise_abs = 0x5f,
529 /** Eltwise: square root */
530 mkldnn_eltwise_sqrt = 0x6f,
531 /** Eltwise: linear */
532 mkldnn_eltwise_linear = 0x7f,
533 /** Eltwise: bounded_relu */
534 mkldnn_eltwise_bounded_relu = 0x8f,
535 /** Eltwise: soft_relu */
536 mkldnn_eltwise_soft_relu = 0x9f,
537 /** Eltwise: logistic */
538 mkldnn_eltwise_logistic = 0xaf,
539 /** Eltwise: clamp */
540 mkldnn_eltwise_clamp = 0xbf,
542 mkldnn_eltwise_exp = 0xcf,
544 mkldnn_eltwise_not = 0xdf,
546 mkldnn_pooling_max = 0x1ff,
547 /** Average pooling include padding */
548 mkldnn_pooling_avg_include_padding = 0x2ff,
549 /** Average pooling exclude padding */
550 mkldnn_pooling_avg_exclude_padding = 0x3ff,
551 mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding,
552 /** Local response normalization (LRN) across multiple channels */
553 mkldnn_lrn_across_channels = 0xaff,
554 /** LRN within a single channel */
555 mkldnn_lrn_within_channel = 0xbff,
557 mkldnn_vanilla_rnn = 0x1fff,
559 mkldnn_vanilla_lstm = 0x2fff,
561 mkldnn_vanilla_gru = 0x3fff,
562 /** GRU cell with linear before reset
564 * Modification of original GRU cell. Differs from #mkldnn_vanilla_gru
565 * in how the new memory gate is calculated:
566 * \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f]
567 * Primitive expects 4 biases on input:
568 * \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
570 mkldnn_gru_linear_before_reset = 0x4fff,
571 /** ROI max pooling **/
572 mkldnn_roi_pooling_max = 0xafff,
573 /** ROI pooling with bilinear interpolation**/
574 mkldnn_roi_pooling_bilinear = 0xbfff,
575 /** Depthwise: scale_shift */
576 mkldnn_depthwise_scale_shift = 0x1ffff,
577 /** Depthwise: prelu */
578 mkldnn_depthwise_prelu = 0x2ffff,
579 /** Direct binary convolution */
580 mkldnn_binary_convolution_direct = 0x1fffff,
581 /** Depthwise binarization */
582 mkldnn_binarization_depthwise = 0xafffff
585 /** Flags for batch-normalization primitive. */
587 /** Use global statistics
590 * - on forward propagation use mean and variance provided by user (input)
591 * - on backward propagation reduces the amount of computations, since
592 * mean and variance are considered as constants
595 * - on forward propagation mean and variance are computed and stored in
597 * - on backward propagation compute full derivative wrt to data
599 mkldnn_use_global_stats = 0x1U,
600 /** Use scale and shift parameters
603 * - on forward propagation use scale and shift (aka scale and bias) for
604 * the batch normalization results
605 * - on backward propagation (for prop_kind == #mkldnn_backward) compute
606 * diff wrt to scale and shift (hence one extra output used)
609 * - on backward propagation prop_kind == #mkldnn_backward_data has the
610 * same behavior as prop_kind == #mkldnn_backward
612 mkldnn_use_scaleshift = 0x2U,
616 * - on inference this option behaves the same as if the primitive were
617 * fused with ReLU via post ops API
618 * - on training primitive requires workspace (required to be able to
619 * perform backward pass)
621 mkldnn_fuse_bn_relu = 0x4U,
622 } mkldnn_batch_normalization_flag_t;
626 /** @addtogroup c_api_types_memory Auxiliary types for memory description
629 /** Maximum number of dimensions a tensor can have. Only restricts the amount
630 * of space used for the tensor description. Individual computational
631 * primitives may support only tensors of certain dimensions. */
632 #define TENSOR_MAX_DIMS 12
634 /** A type to describe tensor dimensions. Entries are kept in the tensor's
 * logical order (see the `dims` field of mkldnn_memory_desc_t); unused
 * trailing entries are available up to #TENSOR_MAX_DIMS. */
635 typedef ptrdiff_t mkldnn_dims_t[TENSOR_MAX_DIMS];
636 /** A type to describe strides within a tensor. Uses the signed ptrdiff_t
 * element type, matching mkldnn_dims_t. */
637 typedef ptrdiff_t mkldnn_strides_t[TENSOR_MAX_DIMS];
639 /** Generic description of blocked data layout for most memory formats.
641 * @sa @ref understanding_memory_formats */
643 /** Block size for each of the dimensions. */
644 mkldnn_dims_t block_dims;
645 /** strides[0]: stride between the first elements of adjacent blocks.
646 * @n strides[1]: strides between elements in the same block. */
647 mkldnn_strides_t strides[2];
648 /** Size of the data including padding in each dimension. */
649 mkldnn_dims_t padding_dims;
650 /** Per-dimension offset from the padding to actual data, the top-level
651 * tensor with offsets applied must lie within the padding area. */
652 mkldnn_dims_t offset_padding_to_data;
653 /** Offset from memory origin to the current block, non-zero only in
654 * a description of a memory sub-block. */
655 ptrdiff_t offset_padding;
656 } mkldnn_blocking_desc_t;
659 /** Undefined memory format, used for empty memory descriptors. */
660 mkldnn_wino_undef = 0,
661 /** Tensors of weights for 2x3 winograd convolutions. */
662 mkldnn_wino_wei_aaOIoi,
663 mkldnn_wino_wei_aaOio,
664 mkldnn_wino_wei_aaOBiOo,
665 /** Tensor of weights for 4x3 convolution. */
666 mkldnn_wino_wei_OBaaIBOIio
667 } mkldnn_wino_memory_format_t;
669 /** Description of tensor of weights for winograd 2x3 convolution. */
671 mkldnn_wino_memory_format_t wino_format;
682 } mkldnn_wino_desc_t;
685 mkldnn_packed_format_undef = 0,
688 } mkldnn_rnn_packed_memory_format_t;
690 /* Maximum number of parts of RNN weights tensor that require separate
692 #define MKLDNN_RNN_MAX_N_PARTS 4
694 /** Description of tensor of packed weights for rnn. */
696 mkldnn_rnn_packed_memory_format_t format;
699 int parts[MKLDNN_RNN_MAX_N_PARTS];
700 size_t part_pack_size[MKLDNN_RNN_MAX_N_PARTS];
701 size_t offset_compensation;
703 } mkldnn_rnn_packed_desc_t;
705 /** @addtogroup c_api_types_op_descs Operation descriptors
708 /** A pointer to any of the operation descriptors (type-erased handle;
 * the concrete descriptor self-identifies via its leading
 * mkldnn_primitive_kind_t field). */
709 typedef void *mkldnn_op_desc_t;
710 /** A pointer to any of the operation descriptors (constant variant). */
711 typedef const void *const_mkldnn_op_desc_t;
713 /** Memory descriptor. The description is based on a number of dimensions,
714 * dimensions themselves, plus information about elements type and memory
715 * format. Additionally, contains format-specific descriptions of the data
718 /** The kind of primitive. Used for self-identifying the primitive
719 * descriptor. Must be #mkldnn_memory. */
720 mkldnn_primitive_kind_t primitive_kind;
721 /** Number of dimensions */
723 /** Dimensions in the following order:
724 * - CNN data tensors: mini-batch, channel, spatial
725 * (<code>{N, C, [[D,] H,] W}</code>)
726 * - CNN weight tensors: group (optional), output channel, input channel,
727 * spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
728 * - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
729 * or layers, directions, states, mini-batch, channels (<code>{L, D, S, N, C}</code>)
730 * - RNN weight tensor: layers, directions, input channel, gates, output channels
731 * (<code>{L, D, I, G, O}</code>).
734 * The order of dimensions does not depend on the memory format, so
735 * whether the data is laid out in #mkldnn_nchw or #mkldnn_nhwc
736 * the dims for 4D CNN data tensor would be <code>{N, C, H, W}</code>.
739 /** Data type of the tensor elements. */
740 mkldnn_data_type_t data_type;
741 /** Memory format. */
742 mkldnn_memory_format_t format;
744 /** Description of the data layout for memory formats that use
746 mkldnn_blocking_desc_t blocking;
747 /** Tensor of weights for integer 8bit winograd convolution. */
748 mkldnn_wino_desc_t wino_desc;
749 /** Tensor of packed weights for RNN. */
750 mkldnn_rnn_packed_desc_t rnn_packed_desc;
751 /* ... other descriptions possible */
753 } mkldnn_memory_desc_t;
757 /** A descriptor of a convolution operation. */
759 /** The kind of primitive. Used for self-identifying the primitive
760 * descriptor. Must be #mkldnn_convolution. */
761 mkldnn_primitive_kind_t primitive_kind;
762 /** The kind of propagation. Possible values: #mkldnn_forward_training,
763 * #mkldnn_forward_inference, #mkldnn_backward_data,
764 * #mkldnn_backward_weights, and #mkldnn_backward_bias. */
765 mkldnn_prop_kind_t prop_kind;
766 /** The kind of the convolution algorithm. Possible values:
767 * #mkldnn_convolution_direct. */
768 mkldnn_alg_kind_t alg_kind;
769 /** Source memory descriptor. */
770 mkldnn_memory_desc_t src_desc;
771 /** Source gradient memory descriptor. */
772 mkldnn_memory_desc_t diff_src_desc;
773 /** Weights memory descriptor. */
774 mkldnn_memory_desc_t weights_desc;
775 /** Weights gradient memory descriptor. */
776 mkldnn_memory_desc_t diff_weights_desc;
777 /** Bias memory descriptor. */
778 mkldnn_memory_desc_t bias_desc;
779 /** Bias gradient memory descriptor. */
780 mkldnn_memory_desc_t diff_bias_desc;
781 /** Destination memory descriptor. */
782 mkldnn_memory_desc_t dst_desc;
783 /** Destination gradient memory descriptor. */
784 mkldnn_memory_desc_t diff_dst_desc;
785 /** Convolution strides in each spatial dimension. */
786 mkldnn_dims_t strides;
787 /** Convolution dilates in each spatial dimension. */
788 mkldnn_dims_t dilates;
789 /** Padding in each spatial dimension. padding[0] is a padding in the
790 * beginning (@p padding_l), padding[1] is a padding in the end (@p
792 mkldnn_dims_t padding[2];
793 /** The kind of padding to use. */
794 mkldnn_padding_kind_t padding_kind;
795 /** The accumulator data type. Initialized automatically. */
796 mkldnn_data_type_t accum_data_type;
797 } mkldnn_convolution_desc_t;
799 /** A descriptor of a deconvolution operation. An alias of
 * mkldnn_convolution_desc_t: every field carries the same meaning as in the
 * convolution descriptor (primitive_kind is #mkldnn_deconvolution instead --
 * TODO confirm against the initialization code, which is outside this file). */
800 typedef mkldnn_convolution_desc_t mkldnn_deconvolution_desc_t;
802 /** A descriptor of a shuffle operation. */
804 /** The kind of primitive. Used for self-identifying the primitive
805 * descriptor. Must be #mkldnn_shuffle. */
806 mkldnn_primitive_kind_t primitive_kind;
807 /** The kind of propagation. Possible values: #mkldnn_forward_training,
808 * #mkldnn_forward_inference, and #mkldnn_backward_data. */
809 mkldnn_prop_kind_t prop_kind;
810 /** Source and destination memory descriptor,
811 * and source and destination gradient memory descriptor. */
812 mkldnn_memory_desc_t data_desc;
813 /** axis for shuffling. */
815 /** number of groups in group convolution */
817 } mkldnn_shuffle_desc_t;
819 /** A descriptor of a element-wise operation. */
821 /** The kind of primitive. Used for self-identifying the primitive
822 * descriptor. Must be #mkldnn_eltwise. */
823 mkldnn_primitive_kind_t primitive_kind;
824 /** The kind of propagation. Possible values: #mkldnn_forward_training,
825 * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
827 mkldnn_prop_kind_t prop_kind;
828 /** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu,
829 * #mkldnn_eltwise_tanh, #mkldnn_eltwise_elu, #mkldnn_eltwise_square,
830 * #mkldnn_eltwise_abs, #mkldnn_eltwise_sqrt, #mkldnn_eltwise_linear,
831 * #mkldnn_eltwise_bounded_relu, #mkldnn_eltwise_soft_relu, and
832 * #mkldnn_eltwise_logistic. */
833 mkldnn_alg_kind_t alg_kind;
834 /** Source and destination memory descriptor. */
835 mkldnn_memory_desc_t data_desc;
836 /** Source and destination gradient memory descriptor. */
837 mkldnn_memory_desc_t diff_data_desc;
838 /** Algorithm specific parameter.
840 * - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored
841 * - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored
842 * - #mkldnn_eltwise_elu: @p alpha -- negative slope, @p beta ignored
843 * - #mkldnn_eltwise_square: @p alpha and @p beta ignored
844 * - #mkldnn_eltwise_abs: @p alpha and @p beta ignored
845 * - #mkldnn_eltwise_sqrt: @p alpha and @p beta ignored
846 * - #mkldnn_eltwise_linear: @p alpha -- scale, @p beta -- shift
847 * - #mkldnn_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored
848 * - #mkldnn_eltwise_soft_relu: @p alpha and @p beta ignored
849 * - #mkldnn_eltwise_logistic: @p alpha and @p beta ignored
852 } mkldnn_eltwise_desc_t;
854 /** A descriptor of a channel-wise operation. */
856 /** The kind of primitive. Used for self identifying the primitive
857 * descriptor. Must be #mkldnn_depthwise. */
858 mkldnn_primitive_kind_t primitive_kind;
859 /** The kind of propagation. Possible values: #mkldnn_forward_training,
860 * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. */
862 mkldnn_prop_kind_t prop_kind;
863 /** The kind of depthwise algorithm. Possible values: #mkldnn_depthwise_scale_shift
864 * and #mkldnn_depthwise_prelu. */
865 mkldnn_alg_kind_t alg_kind;
866 /** Source memory descriptor. */
867 mkldnn_memory_desc_t src_desc;
868 /** Destination memory descriptor. */
869 mkldnn_memory_desc_t dst_desc;
870 /** Weights memory descriptor. */
871 mkldnn_memory_desc_t weights_desc;
872 /** Bias memory descriptor. */
873 mkldnn_memory_desc_t bias_desc;
874 } mkldnn_depthwise_desc_t;
876 /** A descriptor of a Softmax operation. */
878 /** The kind of primitive. Used for self-identifying the primitive
879 * descriptor. Must be #mkldnn_softmax. */
880 mkldnn_primitive_kind_t primitive_kind;
881 /** The kind of propagation. Possible values: #mkldnn_forward_training and
882 * #mkldnn_forward_inference. */
883 mkldnn_prop_kind_t prop_kind;
884 /** Source and destination memory descriptor. */
885 mkldnn_memory_desc_t data_desc;
886 /** Source and destination gradient memory descriptor. */
887 mkldnn_memory_desc_t diff_desc;
888 /** The axis along which to perform the softmax. */
890 } mkldnn_softmax_desc_t;
892 /** A descriptor of a pooling operation. */
894 /** The kind of primitive. Used for self-identifying the primitive
895 * descriptor. Must be #mkldnn_pooling. */
896 mkldnn_primitive_kind_t primitive_kind;
897 /** The kind of propagation. Possible values: #mkldnn_forward_training,
898 * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. */
900 mkldnn_prop_kind_t prop_kind;
901 /** The kind of pooling algorithm. Possible values: #mkldnn_pooling_max and
902 * #mkldnn_pooling_avg. */
903 mkldnn_alg_kind_t alg_kind;
904 /** Source memory descriptor. */
905 mkldnn_memory_desc_t src_desc;
906 /** Source gradient memory descriptor. */
907 mkldnn_memory_desc_t diff_src_desc;
908 /** Destination memory descriptor. */
909 mkldnn_memory_desc_t dst_desc;
910 /** Destination gradient memory descriptor. */
911 mkldnn_memory_desc_t diff_dst_desc;
912 /** Pooling kernel strides for spatial dimensions. */
913 mkldnn_dims_t strides;
914 /** Pooling kernel spatial dimensions. */
915 mkldnn_dims_t kernel;
916 /** Padding in each spatial dimension. padding[0] is a padding in the
917 * beginning (@p padding_l), padding[1] is a padding in the end (@p padding_r). */
919 mkldnn_dims_t padding[2];
920 /** The kind of padding to use. */
921 mkldnn_padding_kind_t padding_kind;
922 /** The accumulator data type. Initialized automatically. */
923 mkldnn_data_type_t accum_data_type;
924 } mkldnn_pooling_desc_t;
926 /** A descriptor of a Local Response Normalization (LRN) operation. */
928 /** The kind of primitive. Used for self-identifying the primitive
929 * descriptor. Must be #mkldnn_lrn. */
930 mkldnn_primitive_kind_t primitive_kind;
931 /** The kind of propagation. Possible values: #mkldnn_forward_training,
932 * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. */
934 mkldnn_prop_kind_t prop_kind;
935 /** LRN algorithm. Possible values: #mkldnn_lrn_within_channel and
936 * #mkldnn_lrn_across_channels. */
937 mkldnn_alg_kind_t alg_kind;
938 /** Source and destination memory descriptor. */
939 mkldnn_memory_desc_t data_desc;
940 /** Source and destination gradient memory descriptor. */
941 mkldnn_memory_desc_t diff_data_desc;
942 /** The number of channels to sum over (for cross-channel LRN) or the side
943 * length of the square region to sum over (for within-channel LRN). */
945 /** LRN alpha parameter. */
947 /** LRN beta parameter. */
949 /** LRN k parameter. */
953 /** A descriptor of a Batch Normalization operation. */
955 /** The kind of primitive. Used for self-identifying the primitive
956 * descriptor. Must be #mkldnn_batch_normalization. */
957 mkldnn_primitive_kind_t primitive_kind;
958 /** The kind of propagation. Possible values: #mkldnn_forward_training,
959 * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. */
961 mkldnn_prop_kind_t prop_kind;
962 /** Source and destination memory descriptor. */
963 mkldnn_memory_desc_t data_desc;
964 /** Source and destination gradient memory descriptor. */
965 mkldnn_memory_desc_t diff_data_desc;
966 /** Scale and shift data and gradient memory descriptors.
968 * Scaleshift memory descriptor uses 2D #mkldnn_nc format[2,Channels]. 1-st
969 * dimension contains gamma parameter, 2-nd dimension contains beta parameter. */
971 mkldnn_memory_desc_t data_scaleshift_desc;
972 mkldnn_memory_desc_t diff_data_scaleshift_desc;
973 /** Mean and variance data memory descriptors.
975 * Mean and variance memory descriptors use 1D #mkldnn_x format[Channels]. */
977 mkldnn_memory_desc_t mean_desc;
978 mkldnn_memory_desc_t variance_desc;
979 /** Batch normalization epsilon parameter. */
980 float batch_norm_epsilon;
982 } mkldnn_batch_normalization_desc_t;
984 /** A descriptor of an inner product (fully connected) operation. */
986 /** The kind of primitive. Used for self-identifying the primitive
987 * descriptor. Must be #mkldnn_inner_product. */
988 mkldnn_primitive_kind_t primitive_kind;
989 /** The kind of propagation. Possible values: #mkldnn_forward_training,
990 * #mkldnn_forward_inference, #mkldnn_backward_data,
991 * #mkldnn_backward_weights, and #mkldnn_backward_bias. */
992 mkldnn_prop_kind_t prop_kind;
993 /** Source memory descriptor. */
994 mkldnn_memory_desc_t src_desc;
995 /** Source gradient memory descriptor. */
996 mkldnn_memory_desc_t diff_src_desc;
997 /** Weights memory descriptor. */
998 mkldnn_memory_desc_t weights_desc;
999 /** Weights gradient memory descriptor. */
1000 mkldnn_memory_desc_t diff_weights_desc;
1001 /** Bias memory descriptor. */
1002 mkldnn_memory_desc_t bias_desc;
1003 /** Bias gradient memory descriptor. */
1004 mkldnn_memory_desc_t diff_bias_desc;
1005 /** Destination memory descriptor. */
1006 mkldnn_memory_desc_t dst_desc;
1007 /** Destination gradient memory descriptor. */
1008 mkldnn_memory_desc_t diff_dst_desc;
1009 /** The accumulator data type. Initialized automatically. */
1010 mkldnn_data_type_t accum_data_type;
1011 } mkldnn_inner_product_desc_t;
1013 /** Flags for RNN cell. */
1015 mkldnn_rnn_cell_with_relu = 0x1U, /**< Cell applies ReLU (@c alpha in #mkldnn_rnn_cell_desc_t is its negative slope). */
1016 mkldnn_rnn_cell_with_clipping = 0x2U, /**< Cell applies clipping (see the clipping parameter in #mkldnn_rnn_cell_desc_t). */
1017 } mkldnn_rnn_cell_flags_t;
1020 /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn,
1021 * #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru,
1022 * or #mkldnn_gru_linear_before_reset. */
1023 mkldnn_alg_kind_t cell_kind;
1024 /** Activation function used. Must be either #mkldnn_eltwise_relu or
1025 * #mkldnn_eltwise_tanh. */
1026 mkldnn_alg_kind_t activation_kind;
1027 /** RNN cell flags (a combination of #mkldnn_rnn_cell_flags_t values). */
1029 /** @c alpha is a negative slope parameter (used only if
1030 * `(flags & #mkldnn_rnn_cell_with_relu) != 0`) */
1032 /** clipping parameter (used only if
1033 * `(flags & #mkldnn_rnn_cell_with_clipping) != 0`) */
1035 } mkldnn_rnn_cell_desc_t;
1037 /** A direction of RNN primitive execution. */
1039 /** Unidirectional execution of RNN primitive from left to right. */
1040 mkldnn_unidirectional_left2right,
1041 /** Unidirectional execution of RNN primitive from right to left. */
1042 mkldnn_unidirectional_right2left,
1043 /** Bidirectional execution of RNN primitive with concatenation of the outputs of the two directions. */
1045 mkldnn_bidirectional_concat,
1046 /** Bidirectional execution of RNN primitive with summation of the outputs of the two directions. */
1048 mkldnn_bidirectional_sum,
1049 mkldnn_unidirectional = mkldnn_unidirectional_left2right, /**< Alias for the default (left-to-right) unidirectional execution. */
1050 } mkldnn_rnn_direction_t;
1052 /** A descriptor for an RNN operation. */
1054 /** The kind of primitive. Used for self-identifying the primitive
1055 * descriptor. Must be #mkldnn_rnn. */
1056 mkldnn_primitive_kind_t primitive_kind;
1057 /** The kind of propagation. Possible values: #mkldnn_forward_training,
1058 * #mkldnn_forward_inference, and #mkldnn_backward. */
1059 mkldnn_prop_kind_t prop_kind;
1060 /** The RNN cell desc. */
1061 mkldnn_rnn_cell_desc_t cell_desc;
1062 /** The direction of RNN primitive execution. */
1063 mkldnn_rnn_direction_t direction;
1064 /** Source layer memory descriptor. */
1065 mkldnn_memory_desc_t src_layer_desc;
1066 /** Source iteration memory descriptor. */
1067 mkldnn_memory_desc_t src_iter_desc;
1068 /** Weights layer memory descriptor. */
1069 mkldnn_memory_desc_t weights_layer_desc;
1070 /** Weights iteration memory descriptor. */
1071 mkldnn_memory_desc_t weights_iter_desc;
1072 /** Bias memory descriptor. */
1073 mkldnn_memory_desc_t bias_desc;
1074 /** Destination layer memory descriptor. */
1075 mkldnn_memory_desc_t dst_layer_desc;
1076 /** Destination iteration memory descriptor. */
1077 mkldnn_memory_desc_t dst_iter_desc;
1078 /** Source gradient layer memory descriptor. */
1079 mkldnn_memory_desc_t diff_src_layer_desc;
1080 /** Source gradient iteration memory descriptor. */
1081 mkldnn_memory_desc_t diff_src_iter_desc;
1082 /** Weights gradient layer memory descriptor. */
1083 mkldnn_memory_desc_t diff_weights_layer_desc;
1084 /** Weights gradient iteration memory descriptor. */
1085 mkldnn_memory_desc_t diff_weights_iter_desc;
1086 /** Bias gradient memory descriptor. */
1087 mkldnn_memory_desc_t diff_bias_desc;
1088 /** Destination gradient layer memory descriptor. */
1089 mkldnn_memory_desc_t diff_dst_layer_desc;
1090 /** Destination gradient iteration memory descriptor. */
1091 mkldnn_memory_desc_t diff_dst_iter_desc;
1092 } mkldnn_rnn_desc_t;
1094 /** A descriptor of a ROI Pooling operation. */
1096 /** The kind of primitive. Used for self identifying the primitive
1097 * descriptor. Must be #mkldnn_roi_pooling. */
1098 mkldnn_primitive_kind_t primitive_kind;
1099 /** The kind of propagation. Possible values: #mkldnn_forward. */
1100 mkldnn_prop_kind_t prop_kind;
1101 /** Source memory descriptor.
 * NOTE(review): unlike every other descriptor in this file this member is a
 * raw pointer, presumably to an externally owned array of source descriptors
 * (ROI pooling consumes multiple inputs) -- confirm element count and
 * ownership/lifetime against the primitive creation code. */
1102 mkldnn_memory_desc_t* src_desc;
1103 /** Destination memory descriptor. */
1104 mkldnn_memory_desc_t dst_desc;
1106 /** Primitive parameters. */
1109 double spatial_scale;
1111 mkldnn_alg_kind_t alg_kind;
1112 } mkldnn_roi_pooling_desc_t;
1114 /** A descriptor of a binary convolution operation. */
1116 /** The kind of primitive. Used for self identifying the primitive
1117 * descriptor. Must be #mkldnn_binary_convolution. */
1118 mkldnn_primitive_kind_t primitive_kind;
1119 /** The kind of propagation. Possible values: #mkldnn_forward_training,
1120 * #mkldnn_forward_inference. */
1121 mkldnn_prop_kind_t prop_kind;
1122 /** The kind of the binary convolution algorithm. Possible values:
1123 * #mkldnn_binary_convolution_direct. */
1124 mkldnn_alg_kind_t alg_kind;
1125 /** Source memory descriptor. */
1126 mkldnn_memory_desc_t src_desc;
1127 /** Weights memory descriptor. */
1128 mkldnn_memory_desc_t weights_desc;
1129 /** Destination memory descriptor. */
1130 mkldnn_memory_desc_t dst_desc;
1131 /** Convolution strides in each spatial dimension. */
1132 mkldnn_dims_t strides;
1133 /** Convolution dilates in each spatial dimension. */
1134 mkldnn_dims_t dilates;
1135 /** Padding in each spatial dimension. padding[0] is a padding in the
1136 * beginning (@p padding_l), padding[1] is a padding in the end (@p padding_r). */
1138 mkldnn_dims_t padding[2];
1139 /** The accumulator data type. Initialized automatically. */
1140 mkldnn_data_type_t accum_data_type;
1141 /** Logic value of elements in padding area. */
1143 } mkldnn_binary_convolution_desc_t;
1145 /** A descriptor of a binarization operation. */
1147 /** The kind of primitive. Used for self identifying the primitive
1148 * descriptor. Must be #mkldnn_binarization. */
1149 mkldnn_primitive_kind_t primitive_kind;
1150 /** The kind of propagation. Possible values: #mkldnn_forward_training,
1151 * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. */
1153 mkldnn_prop_kind_t prop_kind;
1154 /** The kind of binarization algorithm. Possible values: #mkldnn_binarization_depthwise. */
1155 mkldnn_alg_kind_t alg_kind;
1156 /** Source memory descriptor. */
1157 mkldnn_memory_desc_t src_desc;
1158 /** Destination memory descriptor. */
1159 mkldnn_memory_desc_t dst_desc;
1160 /** Weights memory descriptor. */
1161 mkldnn_memory_desc_t weights_desc;
1162 } mkldnn_binarization_desc_t;
1166 /** @addtogroup c_api_engine_types Engine */
1169 /** @brief Kinds of engines. */
1171 /** An unspecified engine. */
1175 } mkldnn_engine_kind_t;
1177 /** @struct mkldnn_engine
1178 * @brief An opaque structure to describe an engine. */
1179 struct mkldnn_engine;
1180 /** @brief An engine handle. */
1181 typedef struct mkldnn_engine *mkldnn_engine_t;
1183 /* FIXME: looks like this never happens */
1184 /** @brief A constant engine handle. */
1185 typedef const struct mkldnn_engine *const_mkldnn_engine_t;
1190 /** @addtogroup c_api_primitive_desc_iterators Primitive descriptor iterators */
1193 /** @struct mkldnn_primitive_desc_iterator
1194 * @brief An opaque structure to describe a primitive descriptor iterator. */
1195 struct mkldnn_primitive_desc_iterator;
1197 /** @brief A primitive descriptor iterator handle. */
1198 typedef struct mkldnn_primitive_desc_iterator
1199 *mkldnn_primitive_desc_iterator_t;
1201 /** @brief A constant primitive descriptor iterator handle. */
1202 typedef const struct mkldnn_primitive_desc_iterator
1203 *const_mkldnn_primitive_desc_iterator_t;
1207 /** @addtogroup c_api_primitive_descs Primitive descriptors */
1210 /** @struct mkldnn_primitive_desc
1211 * @brief An opaque structure to describe a primitive descriptor. */
1212 struct mkldnn_primitive_desc;
1214 /** @brief A primitive descriptor handle. */
1215 typedef struct mkldnn_primitive_desc *mkldnn_primitive_desc_t;
1217 /** @brief A constant primitive descriptor handle. */
1218 typedef const struct mkldnn_primitive_desc *const_mkldnn_primitive_desc_t;
1222 /** @addtogroup c_api_primitive_attr Primitive descriptor attributes */
1225 /** @struct mkldnn_primitive_attr
1226 * @brief An opaque structure for primitive descriptor attributes.
1228 * Attributes may contain:
1229 * - rounding mode for integer based primitives (like convolution, reorders)
1230 * - output scales (to scale the result prior to storing it to the memory) */
1232 struct mkldnn_primitive_attr;
1234 /** @brief A primitive descriptor attributes handle that controls primitive behavior. */
1236 typedef struct mkldnn_primitive_attr *mkldnn_primitive_attr_t;
1238 /** @brief A constant primitive descriptor attributes handle. */
1239 typedef const struct mkldnn_primitive_attr *const_mkldnn_primitive_attr_t;
1241 /** @struct mkldnn_post_ops
1242 * @brief An opaque structure for a chain of post operations.
1244 * mkldnn_post_ops can be used to perform some (trivial) operations like
1245 * accumulation or eltwise after certain primitives like convolution.
1247 * Post operations might be combined together, making a chain of post
1248 * operations. For instance one can configure convolution followed by
1249 * accumulation followed by eltwise. This might be especially beneficial
1250 * for residual learning blocks.
1253 * Of course not all combinations are supported, so the user should handle
1254 * errors accordingly.
1256 * Supported post operations:
1257 * - accumulation (base primitive: convolution)
1258 * - eltwise (base primitive: convolution) */
1260 struct mkldnn_post_ops;
1262 /** @brief A post operation chain handle. */
1263 typedef struct mkldnn_post_ops *mkldnn_post_ops_t;
1265 /** @brief A constant post operation chain handle. */
1266 typedef const struct mkldnn_post_ops *const_mkldnn_post_ops_t;
1270 /** @addtogroup c_api_types_primitive Primitive */
1273 /** @struct mkldnn_primitive
1274 * An opaque structure to describe a primitive. */
1275 struct mkldnn_primitive;
1276 /** A primitive handle. */
1277 typedef struct mkldnn_primitive *mkldnn_primitive_t;
1278 /** A constant primitive handle. */
1279 typedef const struct mkldnn_primitive *const_mkldnn_primitive_t;
1281 /** A wrapper structure to specify a particular output of a primitive. */
1283 /** Primitive to specify the output for. */
1284 const_mkldnn_primitive_t primitive;
1285 /** Desired output index. */
1286 size_t output_index;
1287 } mkldnn_primitive_at_t;
1291 /** @addtogroup c_api_types_query Queries */
1294 /** Primitive descriptor query specification
1296 * For generic function mkldnn_primitive_desc_query(), the type of result must
1297 * agree with the queried argument. The correspondence table:
1298 * Query | type of result
1299 * --------------------------------------------------------------
1300 * #mkldnn_query_engine | mkldnn_engine_t *
1301 * #mkldnn_query_primitive_kind | mkldnn_primitive_kind_t *
1303 * *_s64 | ptrdiff_t *
1305 * *_str | const char **
1306 * #mkldnn_query_op_d | const_mkldnn_op_desc_t *
1307 * *_md | const mkldnn_memory_desc_t **
1308 * *_${op}_d | const mkldnn_${op}_desc_t **
1309 * *_pd | const_mkldnn_primitive_desc_t *
1312 * Rule of thumb: all opaque types and structures are returned by
1313 * reference. All numbers are returned by value.
1316 * All returned references point to constant objects and are valid only
1317 * during the lifetime of the queried primitive descriptor. Returned objects
1318 * must not be destroyed by the user. If you need to keep the object longer
1319 * than the lifetime of the queried primitive descriptor, use
1320 * mkldnn_primitive_desc_clone() to make a copy. */
1322 mkldnn_query_undef = 0, /**< no query */
1324 mkldnn_query_engine, /**< execution engine */
1325 mkldnn_query_primitive_kind, /**< primitive kind */
1327 mkldnn_query_num_of_inputs_s32, /**< number of inputs expected */
1328 mkldnn_query_num_of_outputs_s32, /**< number of outputs expected */
1330 mkldnn_query_time_estimate_f64, /**< runtime estimation (seconds) */
1331 mkldnn_query_memory_consumption_s64, /**< memory consumption -- extra
1332 (scratch) memory, additional to all
1333 inputs and outputs memory (bytes) */
1335 mkldnn_query_impl_info_str, /**< implementation name */
1337 /* memory and op descriptor section */
1338 mkldnn_query_some_d = 64, /**< stub */
1339 mkldnn_query_op_d, /**< op descriptor */
1340 mkldnn_query_memory_d, /**< memory descriptor for memory and view */
1341 mkldnn_query_convolution_d, /**< convolution descriptor */
1342 mkldnn_query_deconvolution_d, /**< deconvolution descriptor */
1343 mkldnn_query_shuffle_d, /**< shuffle descriptor */
1344 mkldnn_query_eltwise_d, /**< eltwise descriptor */
1345 mkldnn_query_softmax_d, /**< softmax descriptor */
1346 mkldnn_query_pooling_d, /**< pooling descriptor */
1347 mkldnn_query_lrn_d, /**< lrn descriptor */
1348 mkldnn_query_batch_normalization_d, /**< batch normalization descriptor */
1349 mkldnn_query_inner_product_d, /**< inner product descriptor */
1350 mkldnn_query_rnn_d, /**< rnn descriptor */
1351 mkldnn_query_roi_pooling_d, /**< roi pooling descriptor */
1352 mkldnn_query_depthwise_d, /**< depthwise descriptor */
1353 mkldnn_query_binary_convolution_d, /**< binary convolution descriptor */
1354 mkldnn_query_binarization_d, /**< binarization descriptor */
1356 /* (memory) primitive descriptor section */
1357 mkldnn_query_some_pd = 128, /**< stub */
1358 mkldnn_query_input_pd, /**< input memory primitive desc */
1359 mkldnn_query_output_pd, /**< output memory primitive desc */
1360 mkldnn_query_src_pd, /**< source memory primitive desc */
1361 mkldnn_query_diff_src_pd, /**< source gradient memory primitive desc */
1362 mkldnn_query_weights_pd, /**< weights memory primitive desc */
1363 mkldnn_query_diff_weights_pd, /**< weights grad. memory primitive desc */
1364 mkldnn_query_dst_pd, /**< destination memory primitive desc */
1365 mkldnn_query_diff_dst_pd, /**< destination grad. memory primitive desc */
1366 mkldnn_query_workspace_pd, /**< workspace memory primitive desc */
1371 /** @addtogroup c_api_types_stream Execution stream */
1374 /** @brief Kinds of streams. */
1376 /** An unspecified stream. */
1378 /** Eager stream. */
1382 } mkldnn_stream_kind_t;
1384 /** @struct mkldnn_stream
1385 * An opaque structure to describe an execution stream. */
1386 struct mkldnn_stream;
1387 /** An execution stream handle. */
1388 typedef struct mkldnn_stream *mkldnn_stream_t;
1389 /** A constant execution stream handle. */
1390 typedef const struct mkldnn_stream *const_mkldnn_stream_t;