Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / include / mkldnn_types.h
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef MKLDNN_TYPES_H
18 #define MKLDNN_TYPES_H
19
20 #ifdef __cplusplus
21 extern "C" {
22 #endif
23
24 #ifndef DOXYGEN_SHOULD_SKIP_THIS
25 #include <stddef.h>
26 #include <stdint.h>
27 #endif
28
29 /** @addtogroup c_api C API
30  *  @{
31  *
32  *  @addtogroup c_api_types Types
33  *  @{
34  *
35  *  @addtogroup c_api_types_generic Generic
36  *  @{ */
37
38 /** Intel(R) MKL-DNN Version type */
39 typedef struct {
40     int    major;
41     int    minor;
42     int    patch;
43     const char *hash;
44 } mkldnn_version_t;
45
46 /** Status values returned by Intel(R) MKL-DNN functions. */
47 typedef enum {
48     /** The operation was successful */
49     mkldnn_success = 0,
50     /** The operation failed due to an out-of-memory condition */
51     mkldnn_out_of_memory = 1,
52     /** The operation failed and should be retried */
53     mkldnn_try_again = 2,
54     /** The operation failed because of incorrect function arguments  */
55     mkldnn_invalid_arguments = 3,
56     /** The operation failed because a primitive was not ready for execution */
57     mkldnn_not_ready = 4,
58     /** The operation failed because requested functionality is not implemented
59      */
60     mkldnn_unimplemented = 5,
61     /** Primitive iterator passed over last primitive descriptor */
62     mkldnn_iterator_ends = 6,
63     /** Primitive or engine failed on execution */
64     mkldnn_runtime_error = 7,
65     /** Queried element is not required for given primitive */
66     mkldnn_not_required = 8,
67 } mkldnn_status_t;
68
69 /** Data type specification */
70 typedef enum {
71     /** Undefined data type, used for empty memory descriptors. */
72     mkldnn_data_type_undef = 0,
73     /** 32-bit/single-precision floating point. */
74     mkldnn_f32 = 1,
75     /** 32-bit signed integer. */
76     mkldnn_s32 = 2,
77     /** 16-bit signed integer. */
78     mkldnn_s16 = 4,
79     /** 8-bit signed integer. */
80     mkldnn_s8 = 5,
81     /** 8-bit unsigned integer. */
82     mkldnn_u8 = 6,
83     /** 1-bit integer. */
84     mkldnn_bin = 7,
85 } mkldnn_data_type_t;
86
87 /** Rounding mode */
88 typedef enum {
89     /** Round nearest */
90     mkldnn_round_nearest = 1,
91     /** Round down */
92     mkldnn_round_down = 2,
93 } mkldnn_round_mode_t;
94
95 /** Memory format specification.
96  *
97  * Intel MKL-DNN formats describe physical data layout. The physical layout
98  * is described as a sequence of the dimensions as they are laid out in the
99  * memory (from the outer-most to the inner-most). Note that this order
100  * doesn't affect the logical order of the dimensions that is kept in the
101  * `dims` field of the mkldnn_memory_desc_t structure. The logical order of the
102  * dimensions is specified by the type of tensor.
103  *
104  * For example, CNN 5D tensor always has its logical dimensions in the order
105  * `(batch, channels, depth, height, width)`, while the physical layout might be
106  * #mkldnn_ncdhw or #mkldnn_ndhwc:
107  *
108  * ~~~cpp
109  * int batch = 2, channels = 16, depth = 13, height = 13, width = 13;
110  *
111  * int ndims = 5; // 5D tensor
112  * mkldnn_dims_t dims = {batch, channels, depth, height, width};
113  *
114  * mkldnn_memory_desc_t data_in_ncdhw;
115  * mkldnn_memory_desc_init(&data_in_ncdhw, 5, dims, mlkdnn_ncdhw);
116  *
117  * // note that in both cases dims passed are the same
118  * mkldnn_memory_desc_t data_in_ndhwc;
119  * mkldnn_memory_desc_init(&data_in_ndhwc, 5, dims, mlkdnn_ndhwc);
120  * ~~~
121  *
122  * The following notation applies to memory format names:
123  *  - @c 'n' denotes the mini-batch dimension
124  *  - @c 'c' denotes a channels dimension
125  *  - When there are multiple channel dimensions (for example, in convolution
126  *    weights tensor), @c 'i' and @c 'o' denote dimensions of input and output
127  *    channels
128  *  - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
129  *    respectively
130  *  - Upper-case letters indicate that the data is laid out in blocks
131  *    for a particular dimension. In such cases, the format name contains both
132  *    upper- and lower-case letters for that dimension with a lower-case letter
133  *    preceded by the block size. For example: @c 'mkldnn_nChw8c' describes a
134  *    format where the outermost dimension is mini-batch, followed by the
135  *    channel block number, followed by the spatial height and width, and
136  *    finally followed by 8-element channel blocks.
137  *
138  * @note
139  *    Channel designations can be different. For example, both the @c
140  *    'mkldnn_nc' and @c 'mkldnn_io' formats can be used to describe a 2D
141  *    tensor.
142  *
143  * @sa @ref understanding_memory_formats
144  */
145 typedef enum {
146     /** Undefined memory format, used for empty memory descriptors. */
147     mkldnn_format_undef = 0,
148     /** Unspecified format. The primitive selects a format
149      * automatically. */
150     mkldnn_any,
151     /** A tensor in a generic format described by the stride and blocking
152      * values in each dimension. See #mkldnn_blocking_desc_t for more
153      * information. */
154     mkldnn_blocked,
155     /** 1D data tensor. */
156     mkldnn_x,
157     /** 2D data tensor. */
158     mkldnn_nc,
159     /** 3D data tensor with the physical layout @c ncw.
160      * Logical dimensions come in the order: (n, c, w) */
161     mkldnn_ncw,
162     /** 3D data tensor with the physical layout @c nwc.
163      * Logical dimensions come in the order: (n, c, w) */
164     mkldnn_nwc,
165     /** 4D data tensor with the physical layout @c nchw, used in Caffe.
166      * Logical dimensions come in the order: (n, c, h, w) */
167     mkldnn_nchw,
168     /** 4D data tensor with the physical layout @c nhwc, used in TensorFlow.
169      * Logical dimensions come in the order: (n, c, h, w) */
170     mkldnn_nhwc,
171     /** 4D data tensor with the physical layout @c chwn, used in Neon.
172      * Logical dimensions come in the order: (n, c, h, w) */
173     mkldnn_chwn,
174     /** 5D data tensor with the physical layout @c ncdhw.
175      * Logical dimensions come in the order: (n, c, d, h, w) */
176     mkldnn_ncdhw,
177     /** 5D data tensor with the physical layout @c ndhwc, used in TensorFlow.
178      * Logical dimensions come in the order: (n, c, d, h, w) */
179     mkldnn_ndhwc,
180     /** 2D weights tensor with physical layout @c oi.
181      * Logical dimensions come in the order: (o, i) */
182     mkldnn_oi,
183     /** 2D weights tensor with physical layout @c io.
184      * Logical dimensions come in the order: (o, i) */
185     mkldnn_io,
186     /** 3D weights tensor with physical layout @c oiw.
187      * Logical dimensions come in the order: (o, i, w) */
188     mkldnn_oiw,
189     /** 3D weights tensor with physical layout @c wio.
190      * Logical dimensions come in the order: (o, i, w) */
191     mkldnn_wio,
192     /** 4D weights tensor with physical layout @c oihw, used in Caffe.
193      * Logical dimensions come in the order: (o, i, h, w) */
194     mkldnn_oihw,
195     /** 4D weights tensor with physical layout @c hwio, used in TensorFlow.
196      * Logical dimensions come in the order: (o, i, h, w) */
197     mkldnn_hwio,
198     /** 4D weights tensor with physical layout @c ihwo.
199      * Logical dimensions come in the order: (o, i, h, w) */
200     mkldnn_ihwo,
201     /** 4D weights tensor with physical layout @c iohw.
202      * Logical dimensions come in the order: (o, i, h, w) */
203     mkldnn_iohw,
204     /** 5D weights tensor with physical layout @c iodhw, used in Caffe.
205      * Logical dimensions come in the order: (o, i, d, h, w) */
206     mkldnn_oidhw,
207     /** 5D weights tensor with physical layout @c dhwio, used in TensorFlow.
208      * Logical dimensions come in the order: (o, i, d, h, w) */
209     mkldnn_dhwio,
210     /** 4D grouped weights tensor with the physical layout @c goiw.
211      * Logical dimensions come in the order: (g, o, i, w) */
212     mkldnn_goiw,
213     /** 5D grouped weights tensor with the physical layout @c goihw,
214      * used in Caffe.
215      * Logical dimensions come in the order: (g, o, i, h, w) */
216     mkldnn_goihw,
217     /** 5D grouped weights tensor with the physical layout @c hwigo,
218      * used in TensorFlow.
219      * Logical dimensions come in the order: (g, o, i, h, w) */
220     mkldnn_hwigo,
221     /** 5D grouped weights tensor with the physical layout @c giohw.
222      * Logical dimensions come in the order: (g, o, i, h, w) */
223     mkldnn_giohw,
224     /** 6D grouped weights tensor with the physical layout @c goidhw,
225      * used in Caffe.
226      * Logical dimensions come in the order: (g, o, i, d, h, w) */
227     mkldnn_goidhw,
228     /** 3D RNN data tensor in the format (batch, seq_length, input channels). */
229     mkldnn_ntc,
230     /** 3D RNN data tensor in the format (seq_length, batch, input channels). */
231     mkldnn_tnc,
232     /** 5D RNN states tensor in the format (num_layers, num_directions,
233      * num_states, batch, state channels). */
234     mkldnn_ldsnc,
235     /** 5D RNN weights tensor in the format (num_layers, num_directions,
236      *  input_channels, num_gates, output_channels).
237      *
238      *  - For LSTM cells, the gates order is input, forget, candidate
239      *    and output gate.
240      *  - For GRU cells, the gates order is update, reset and output gate. */
241     mkldnn_ldigo,
242     /** 5D RNN weights tensor in the format (num_layers, num_directions,
243      * num_gates, output_channels, input_channels).
244      *
245      *  - For LSTM cells, the gates order is input, forget, candidate
246      *    and output gate.
247      *  - For GRU cells, the gates order is update, reset and output gate. */
248     mkldnn_ldgoi,
249     /** 4D RNN bias tensor in the format (num_layers, num_directions,
250      * num_gates, output_channels).
251      *
252      *  - For LSTM cells, the gates order is input, forget, candidate
253      *    and output gate.
254      *  - For GRU cells, the gates order is update, reset and output gate. */
255     mkldnn_ldgo,
256
257     /* Opaque data types, are not to be used explicitly */
258
259     /* data */
260     mkldnn_nCw4c /** blocked data format */,
261     mkldnn_nCw8c /** blocked data format */,
262     mkldnn_nCw16c /** blocked data format */,
263     mkldnn_nChw4c /** blocked data format */,
264     mkldnn_nChw8c /** blocked data format */,
265     mkldnn_nChw16c /** blocked data format */,
266     mkldnn_nCdhw4c /** blocked data format */,
267     mkldnn_nCdhw8c /** blocked data format */,
268     mkldnn_nCdhw16c /** blocked data format */,
269
270     /* weights, 3D */
271     mkldnn_Owi4o /** blocked weights format */,
272     mkldnn_OIw4i4o /** blocked weights format */,
273     mkldnn_Owi8o /** blocked weights format */,
274     mkldnn_OIw8i8o /** blocked weights format */,
275     mkldnn_OIw8o8i /** blocked weights format */,
276     mkldnn_OIw16i16o /** blocked weights format */,
277     mkldnn_OIw16o16i /** blocked weights format */,
278     mkldnn_Oiw4o /** blocked weights format */,
279     mkldnn_Oiw16o /** blocked weights format */,
280     mkldnn_Owi16o /** blocked weights format */,
281     mkldnn_OIw8i16o2i /** blocked weights format */,
282     mkldnn_OIw8o16i2o /** blocked weights format */,
283     mkldnn_IOw16o16i /** blocked weights format */,
284
285     /* weights, 4D */
286     /** weights format with additional buffer
287      * size equal to the number of output channels
288      * and containing the values:
289      * O[i:0,OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
290     mkldnn_hwio_s8s8,
291     mkldnn_oIhw8i /** blocked weights format */,
292     mkldnn_oIhw16i /** blocked weights format */,
293     mkldnn_OIhw4i4o /** blocked weights format */,
294     mkldnn_OIhw8i8o /** blocked weights format */,
295     mkldnn_OIhw16i16o /** blocked weights format */,
296     mkldnn_OIhw4i16o4i /** blocked weights format */,
297     /** blocked weights format with additional buffer
298      * with size equal to the number of output channels
299      * and containing the values:
300      * O[i:0,OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
301     mkldnn_OIhw4i16o4i_s8s8,
302     mkldnn_OIhw8i16o2i /** blocked weights format */,
303     mkldnn_OIhw8o16i2o /** blocked weights format */,
304     mkldnn_OIhw8o8i /** blocked weights format */,
305     mkldnn_OIhw16o16i /** blocked weights format */,
306     mkldnn_IOhw16o16i /** blocked weights format */,
307     mkldnn_Oihw8o /** blocked weights format */,
308     mkldnn_Oihw4o /** blocked weights format */,
309     mkldnn_Oihw16o /** blocked weights format */,
310     mkldnn_Ohwi8o /** blocked weights format */,
311     mkldnn_Ohwi4o /** blocked weights format */,
312     mkldnn_Ohwi16o /** blocked weights format */,
313     mkldnn_OhIw16o4i /** blocked weights format */,
314     mkldnn_OhIw8o4i /** blocked weights format */,
315     /** blocked weights format with additional buffer
316      * with size equal to the number of output channels
317      * and containing the values:
318      * O[i:0,OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
319     mkldnn_OhIw8o4i_s8s8,
320     mkldnn_OhIw8o32i /** blocked weights format */,
321     mkldnn_OhIw16o32i /** blocked weights format */,
322
323     /* weights, 5D */
324     mkldnn_oIdhw8i /** blocked weights format */,
325     mkldnn_oIdhw16i /** blocked weights format */,
326     mkldnn_OIdhw4i4o /** blocked weights format */,
327     mkldnn_Odhwi4o /** blocked weights format */,
328     mkldnn_OIdhw8i8o /** blocked weights format */,
329     mkldnn_OIdhw8o8i /** blocked weights format */,
330     mkldnn_Odhwi8o /** blocked weights format */,
331     mkldnn_OIdhw16i16o /** blocked weights format */,
332     mkldnn_OIdhw16o16i /** blocked weights format */,
333     mkldnn_Oidhw4o /** blocked weights format */,
334     mkldnn_Oidhw16o /** blocked weights format */,
335     mkldnn_Odhwi16o /** blocked weights format */,
336     mkldnn_OIdhw8i16o2i /** blocked weights format */,
337
338     /* weights w/ groups, 4D */
339     mkldnn_gOwi4o /** blocked weights format */,
340     mkldnn_gOIw4i4o /** blocked weights format */,
341     mkldnn_gOwi8o /** blocked weights format */,
342     mkldnn_gOIw8o8i /** blocked weights format */,
343     mkldnn_gOIw8i8o /** blocked weights format */,
344     mkldnn_gOIw16i16o /** blocked weights format */,
345     mkldnn_gOIw16o16i /** blocked weights format */,
346     mkldnn_gOiw4o /** blocked weights format */,
347     mkldnn_gOiw16o /** blocked weights format */,
348     mkldnn_gOwi16o /** blocked weights format */,
349     mkldnn_gOIw8i16o2i /** blocked weights format */,
350     mkldnn_gOIw8o16i2o /** blocked weights format */,
351     mkldnn_gIOw16o16i /** blocked weights format */,
352
353     /* weights w/ groups, 5D */
354     /** weights format with additional buffer
355      * size equal to the number of output channels
356      * multiplied by number of groups and containing the values:
357      * O[i:0,G*OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
358     mkldnn_hwigo_s8s8,
359     mkldnn_gOIhw4i4o /** blocked weights format */,
360     mkldnn_gOIhw8i8o /** blocked weights format */,
361     mkldnn_gOIhw16i16o /** blocked weights format */,
362     mkldnn_gOIhw4i16o4i /** blocked weights format */,
363     /** blocked weights format with additional buffer
364      * with size equal to the number of output channels
365      * multiplied by number of groups and containing the values:
366      * O[i:0,G*OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
367     mkldnn_gOIhw4i16o4i_s8s8,
368     mkldnn_gOIhw2i8o4i /** blocked weights format */,
369     /** blocked weights format with additional buffer
370      * with size equal to the number of output channels
371      * multiplied by number of groups and containing the values:
372      * O[i:0,G*OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
373     mkldnn_gOIhw2i8o4i_s8s8,
374     mkldnn_gOIhw8i16o2i /** blocked weights format */,
375     mkldnn_gOIhw8o16i2o /** blocked weights format */,
376     mkldnn_gOIhw4o4i /** blocked weights format */,
377     /** blocked weights format with additional buffer
378      * with size equal to the number of output channels
379      * and containing the values:
380      * O[i:0,OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
381     mkldnn_gOIhw4o4i_s8s8 /** blocked weights format */,
382     mkldnn_gOIhw8o8i /** blocked weights format */,
383     mkldnn_gOIhw16o16i /** blocked weights format */,
384     mkldnn_gIOhw16o16i /** blocked weights format */,
385     mkldnn_gOihw8o /** blocked weights format */,
386     mkldnn_gOihw4o /** blocked weights format */,
387     mkldnn_gOihw16o /** blocked weights format */,
388     mkldnn_gOhwi8o /** blocked weights format */,
389     mkldnn_gOhwi4o /** blocked weights format */,
390     mkldnn_gOhwi16o /** blocked weights format */,
391     mkldnn_Goihw8g /** blocked weights format */,
392     mkldnn_Goihw16g /** blocked weights format */,
393     /** blocked weights format with additional buffer
394      * with size equal to the number of groups and containing the values:
395      * O[i:0,G] = -128 * SUM(h:0,H;w:0,W)(weights(i,i,h,w))*/
396     mkldnn_Goihw16g_s8s8,
397     mkldnn_gOhIw16o4i /** blocked weights format */,
398     mkldnn_gOhIw8o4i /** blocked weights format */,
399     /** blocked weights format with additional buffer
400      * with size equal to the number of output channels
401      * multiplied by number of groups and containing the values:
402      * O[i:0,G*OC] = -128 * SUM(j:0,IC;h:0,H;w:0,W)(weights(i,j,h,w))*/
403     mkldnn_gOhIw8o4i_s8s8,
404
405     /* weights w/ groups, 6D */
406     mkldnn_gOIdhw4i4o /** blocked weights format */,
407     mkldnn_gOdhwi4o /** blocked weights format */,
408     mkldnn_gOIdhw8i8o /** blocked weights format */,
409     mkldnn_gOIdhw8o8i /** blocked weights format */,
410     mkldnn_gOdhwi8o /** blocked weights format */,
411     mkldnn_gOIdhw8i16o2i /** blocked weights format */,
412     mkldnn_gOIdhw16i16o /** blocked weights format */,
413     mkldnn_gOIdhw16o16i /** blocked weights format */,
414     mkldnn_gOidhw4o /** blocked weights format */,
415     mkldnn_gOidhw16o /** blocked weights format */,
416     mkldnn_gOdhwi16o /** blocked weights format */,
417
418     mkldnn_wino_fmt /** Weights format used in 8bit Winograd convolution */,
419
420     mkldnn_rnn_packed /** Packed weights format used in RNN */,
421
422     /** Just a sentinel, not real memory format. Must be changed after new
423      * format is added. */
424     mkldnn_format_last,
425 } mkldnn_memory_format_t;
426
427 /** Kinds of padding. Define how to interpret the data in padding regions. */
428 typedef enum {
429     /** The data in padding regions is zero. */
430     mkldnn_padding_zero,
431 } mkldnn_padding_kind_t;
432
433 /** Kinds of propagation. */
434 typedef enum {
435     /* TODO: suggest renames */
436     /** Undefined propagation type. */
437     mkldnn_prop_kind_undef = 0,
438     /** Forward data propagation (training mode). In this mode primitives
439      * perform computations necessary for subsequent backward propagation. */
440     mkldnn_forward_training = 64,
441     /** Forward data propagation (inference mode). In this mode primitives
442      * perform only computations that are necessary for inference and omit
443      * computations that are necessary only for backward propagation. */
444     mkldnn_forward_inference = 96,
445     /** Forward data propagation (alias for @c mkldnn_forward_inference) */
446     mkldnn_forward_scoring = mkldnn_forward_inference,
447    /** Forward data propagation (alias for @c mkldnn_forward_training) */
448     mkldnn_forward = mkldnn_forward_training,
449     /** Backward propagation (with respect to all parameters */
450     mkldnn_backward = 128,
451     /** Backward data propagation */
452     mkldnn_backward_data = 160,
453     /** Backward weights propagation */
454     mkldnn_backward_weights = 192,
455     /** Backward bias propagation */
456     mkldnn_backward_bias = 193,
457 } mkldnn_prop_kind_t;
458
459 /** Kinds of primitives. Used to implement a way to extend the library with new
460  * primitives without changing the ABI. */
461 typedef enum {
462     /** Undefined primitive (XXX: why do we have it?). */
463     mkldnn_undefined_primitive,
464     /** A memory primitive. */
465     mkldnn_memory,
466     /** A view primitive. */
467     mkldnn_view,
468     /** A reorder primitive.*/
469     mkldnn_reorder,
470     /** A shuffle primitive.*/
471     mkldnn_shuffle,
472     /** A (out-of-place) concat primitive. */
473     mkldnn_concat,
474     /** A (in-place) concat primitive. */
475     mkldnn_concat_inplace,
476     /** A sum primitive. */
477     mkldnn_sum,
478     /** A convolution primitive. */
479     mkldnn_convolution,
480     /** A deconvolution primitive. */
481     mkldnn_deconvolution,
482     /** An element-wise primitive. */
483     mkldnn_eltwise,
484     /** A Softmax primitive. */
485     mkldnn_softmax,
486     /** A pooling primitive. */
487     mkldnn_pooling,
488     /** An LRN primitive. */
489     mkldnn_lrn,
490     /** An batch normalization primitive. */
491     mkldnn_batch_normalization,
492     /** An inner product primitive. */
493     mkldnn_inner_product,
494     /** A rnn primitive. */
495     mkldnn_rnn,
496     /** A ROI pooling primitive. */
497     mkldnn_roi_pooling,
498     /** An channel-wise primitive. */
499     mkldnn_depthwise,
500     /** A binary convolution primitive. */
501     mkldnn_binary_convolution,
502     /** A binarization primitive. */
503     mkldnn_binarization,
504 } mkldnn_primitive_kind_t;
505
506 /** Kinds of algorithms. */
507 typedef enum {
508     mkldnn_alg_kind_undef,
509     /** Direct convolution */
510     mkldnn_convolution_direct = 0x1,
511     /** Winograd convolution */
512     mkldnn_convolution_winograd = 0x2,
513     /** Convolution algorithm(either direct or Winograd) is chosen just in time **/
514     mkldnn_convolution_auto = 0x3,
515     /** Direct deconvolution */
516     mkldnn_deconvolution_direct = 0xa,
517     /** Winograd deconvolution */
518     mkldnn_deconvolution_winograd = 0xb,
519     /** Eltwise: ReLU */
520     mkldnn_eltwise_relu = 0x1f,
521     /** Eltwise: hyperbolic tangent non-linearity (tanh) */
522     mkldnn_eltwise_tanh = 0x2f,
523     /** Eltwise: parametric exponential linear unit (elu) */
524     mkldnn_eltwise_elu = 0x3f,
525     /** Eltwise: square */
526     mkldnn_eltwise_square = 0x4f,
527     /** Eltwise: abs */
528     mkldnn_eltwise_abs = 0x5f,
529     /** Eltwise: square root */
530     mkldnn_eltwise_sqrt = 0x6f,
531     /** Eltwise: linear */
532     mkldnn_eltwise_linear = 0x7f,
533     /** Eltwise: bounded_relu */
534     mkldnn_eltwise_bounded_relu = 0x8f,
535     /** Eltwise: soft_relu */
536     mkldnn_eltwise_soft_relu = 0x9f,
537     /** Eltwise: logistic */
538     mkldnn_eltwise_logistic = 0xaf,
539     /** Eltwise: clamp */
540     mkldnn_eltwise_clamp = 0xbf,
541     /** Eltwise: exp */
542     mkldnn_eltwise_exp = 0xcf,
543     /** Eltwise: not */
544     mkldnn_eltwise_not = 0xdf,
545     /** Max pooling */
546     mkldnn_pooling_max = 0x1ff,
547     /** Average pooling include padding */
548     mkldnn_pooling_avg_include_padding = 0x2ff,
549     /** Average pooling exclude padding */
550     mkldnn_pooling_avg_exclude_padding = 0x3ff,
551     mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding,
552     /** Local response normalization (LRN) across multiple channels */
553     mkldnn_lrn_across_channels = 0xaff,
554     /** LRN within a single channel */
555     mkldnn_lrn_within_channel = 0xbff,
556     /** RNN cell */
557     mkldnn_vanilla_rnn = 0x1fff,
558     /** LSTM cell */
559     mkldnn_vanilla_lstm = 0x2fff,
560     /** GRU cell */
561     mkldnn_vanilla_gru = 0x3fff,
562     /** GRU cell with linear before reset
563      *
564      * Modification of original GRU cell. Differs from #mkldnn_vanilla_gru
565      * in how the new memory gate is calculated:
566      * \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f]
567      * Primitive expects 4 biases on input:
568      * \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
569      * */
570     mkldnn_gru_linear_before_reset = 0x4fff,
571     /** ROI max pooling **/
572     mkldnn_roi_pooling_max = 0xafff,
573     /** ROI pooling with bilinear interpolation**/
574     mkldnn_roi_pooling_bilinear = 0xbfff,
575     /** Depthwise: scale_shift */
576     mkldnn_depthwise_scale_shift = 0x1ffff,
577     /** Depthwise: prelu */
578     mkldnn_depthwise_prelu = 0x2ffff,
579     /** Direct binary convolution */
580     mkldnn_binary_convolution_direct = 0x1fffff,
581     /** Depthwise binarization */
582     mkldnn_binarization_depthwise = 0xafffff
583 } mkldnn_alg_kind_t;
584
585 /** Flags for batch-normalization primititve. */
586 typedef enum {
587     /** Use global statistics
588      *
589      * If specified
590      *  - on forward propagation use mean and variance provided by user (input)
591      *  - on backward propagation reduces the amount of computations, since
592      *    mean and variance are considered as constants
593      *
594      *  If not specified:
595      *   - on forward propagation mean and variance are computed and stored in
596      *     output
597      *   - on backward propagation compute full derivative wrt to data
598      */
599     mkldnn_use_global_stats = 0x1U,
600     /** Use scale and shift parameters
601      *
602      * If specified:
603      *  - on forward propagation use scale and shift (aka scale and bias) for
604      *    the batch normalization results
605      *  - on backward propagation (for prop_kind == #mkldnn_backward) compute
606      *    diff wrt to scale and shift (hence one extra output used)
607      *
608      * If no specified:
609      *  - on backward propagation prop_kind == #mkldnn_backward_data has the
610      *    same behavior as prop_kind == #mkldnn_backward
611      */
612     mkldnn_use_scaleshift = 0x2U,
613     /** Fuse with ReLU
614      *
615      * If specified:
616      *  - on inference this option behaves the same as if the primitive were
617      *    fused with ReLU via post ops API
618      *  - on training primitive requires workspace (required to be able to
619      *    perform backward pass)
620      */
621     mkldnn_fuse_bn_relu = 0x4U,
622 } mkldnn_batch_normalization_flag_t;
623
624 /** @} */
625
626 /** @addtogroup c_api_types_memory Auxiliary types for memory description
627  *  @{ */
628
629 /** Maximum number of dimensions a tensor can have. Only restricts the amount
630  * of space used for the tensor description. Individual computational
631  * primitives may support only tensors of certain dimensions. */
632 #define TENSOR_MAX_DIMS 12
633
634 /** A type to describe tensor dimensions. */
635 typedef ptrdiff_t mkldnn_dims_t[TENSOR_MAX_DIMS];
636 /** A type to describe strides within a tensor. */
637 typedef ptrdiff_t mkldnn_strides_t[TENSOR_MAX_DIMS];
638
639 /** Generic description of blocked data layout for most memory formats.
640  *
641  * @sa @ref understanding_memory_formats */
642 typedef struct {
643     /** Block size for each of the dimensions. */
644     mkldnn_dims_t block_dims;
645     /** strides[0]: stride between the first elements of adjacent blocks.
646      * @n strides[1]: strides between elements in the same block. */
647     mkldnn_strides_t strides[2];
648     /** Size of the data including padding in each dimension. */
649     mkldnn_dims_t padding_dims;
650     /** Per-dimension offset from the padding to actual data, the top-level
651      * tensor with offsets applied must lie within the padding area. */
652     mkldnn_dims_t offset_padding_to_data;
653     /** Offset from memory origin to the current block, non-zero only in
654      * a description of a memory sub-block. */
655     ptrdiff_t offset_padding;
656 } mkldnn_blocking_desc_t;
657
658 typedef enum {
659     /** Undefined memory format, used for empty memory descriptors. */
660     mkldnn_wino_undef = 0,
661     /** Tensors of weights for 2x3 winograd convolutions. */
662     mkldnn_wino_wei_aaOIoi,
663     mkldnn_wino_wei_aaOio,
664     mkldnn_wino_wei_aaOBiOo,
665     /** Tensor of weights for 4x3 convolution. */
666     mkldnn_wino_wei_OBaaIBOIio
667 } mkldnn_wino_memory_format_t;
668
669 /** Description of tensor of weights for winograd 2x3 convolution. */
670 typedef struct {
671     mkldnn_wino_memory_format_t wino_format;
672     int r;
673     int alpha;
674     int ic;
675     int oc;
676     int ic_block;
677     int oc_block;
678     int ic2_block;
679     int oc2_block;
680     float adj_scale;
681     size_t size;
682 } mkldnn_wino_desc_t;
683
684 typedef enum {
685     mkldnn_packed_format_undef = 0,
686     mkldnn_ldigo_p,
687     mkldnn_ldgoi_p
688 } mkldnn_rnn_packed_memory_format_t;
689
690 /* Maximum number of parts of RNN weights tensor that require separate
691  * computation. */
692 #define MKLDNN_RNN_MAX_N_PARTS 4
693
694 /** Description of tensor of packed weights for rnn. */
695 typedef struct {
696     mkldnn_rnn_packed_memory_format_t format;
697     int n_parts;
698     int n;
699     int parts[MKLDNN_RNN_MAX_N_PARTS];
700     size_t part_pack_size[MKLDNN_RNN_MAX_N_PARTS];
701     size_t offset_compensation;
702     size_t size;
703 } mkldnn_rnn_packed_desc_t;
704
705 /** @addtogroup c_api_types_op_descs Operation descriptors
706  *  @{*/
707
708 /** A pointer to any of the operation descriptors. */
709 typedef void *mkldnn_op_desc_t;
710 /** A pointer to any of the operation descriptors (constant variant). */
711 typedef const void *const_mkldnn_op_desc_t;
712
713 /** Memory descriptor. The description is based on a number of dimensions,
714  * dimensions themselves, plus information about elements type and memory
715  * format. Additionally, contains format-specific descriptions of the data
716  * layout. */
717 typedef struct {
718     /** The kind of primitive. Used for self-identifying the primitive
719      * descriptor. Must be #mkldnn_memory. */
720     mkldnn_primitive_kind_t primitive_kind;
721     /** Number of dimensions */
722     int ndims;
723     /** Dimensions in the following order:
724      * - CNN data tensors: mini-batch, channel, spatial
725      *   (<code>{N, C, [[D,] H,] W}</code>)
726      * - CNN weight tensors: group (optional), output channel, input channel,
727      *   spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
728      * - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
729      *   or layers, directions, states, mini-batch, channels (<code>{L, D, S, N, C}</code>)
730      * - RNN weight tensor: layers, directions, input channel, gates, output channels
731      *   (<code>{L, D, I, G, O}</code>).
732      *
733      * @note
734      *    The order of dimensions does not depend on the memory format, so
735      *    whether the data is laid out in #mkldnn_nchw or #mkldnn_nhwc
736      *    the dims for 4D CN data tensor would be <code>{N, C, H, W}</code>.
737      */
738     mkldnn_dims_t dims;
739     /** Data type of the tensor elements. */
740     mkldnn_data_type_t data_type;
741     /** Memory format. */
742     mkldnn_memory_format_t format;
743     union {
744         /** Description of the data layout for memory formats that use
745          * blocking. */
746         mkldnn_blocking_desc_t blocking;
747         /** Tensor of weights for integer 8bit winograd convolution. */
748         mkldnn_wino_desc_t wino_desc;
749         /** Tensor of packed weights for RNN. */
750         mkldnn_rnn_packed_desc_t rnn_packed_desc;
751         /* ... other descriptions possible */
752     } layout_desc;
753 } mkldnn_memory_desc_t;
754
755 /** @} */
756
757 /** A descriptor of a convolution operation. */
758 typedef struct {
759     /** The kind of primitive. Used for self-identifying the primitive
760      * descriptor. Must be #mkldnn_convolution. */
761     mkldnn_primitive_kind_t primitive_kind;
762     /** The kind of propagation. Possible values: #mkldnn_forward_training,
763      * #mkldnn_forward_inference, #mkldnn_backward_data,
764      * #mkldnn_backward_weights, and #mkldnn_backward_bias. */
765     mkldnn_prop_kind_t prop_kind;
766     /** The kind of the convolution algorithm. Possible values:
767      * #mkldnn_convolution_direct. */
768     mkldnn_alg_kind_t alg_kind;
769     /** Source memory descriptor. */
770     mkldnn_memory_desc_t src_desc;
771     /** Source gradient memory descriptor. */
772     mkldnn_memory_desc_t diff_src_desc;
773     /** Weights memory descriptor. */
774     mkldnn_memory_desc_t weights_desc;
775     /** Weights gradient memory descriptor. */
776     mkldnn_memory_desc_t diff_weights_desc;
777     /** Bias memory descriptor. */
778     mkldnn_memory_desc_t bias_desc;
779     /** Bias gradient memory descriptor. */
780     mkldnn_memory_desc_t diff_bias_desc;
781     /** Destination memory descriptor. */
782     mkldnn_memory_desc_t dst_desc;
783     /** Destination gradient memory descriptor. */
784     mkldnn_memory_desc_t diff_dst_desc;
785     /** Convolution strides in each spatial dimension. */
786     mkldnn_dims_t strides;
787     /** Convolution dilates in each spatial dimension. */
788     mkldnn_dims_t dilates;
789     /** Padding in each spatial dimension. padding[0] is a padding in the
790      * beginning (@p padding_l), padding[1] is a padding in the end (@p
791      * padding_r). */
792     mkldnn_dims_t padding[2];
793     /** The kind of padding to use. */
794     mkldnn_padding_kind_t padding_kind;
795     /** The accumulator data type. Initialized automatically. */
796     mkldnn_data_type_t accum_data_type;
797 } mkldnn_convolution_desc_t;
798
799 /** A descriptor of a deconvolution operation. */
800 typedef mkldnn_convolution_desc_t mkldnn_deconvolution_desc_t;
801
802 /** A descriptor of a shuffle operation. */
803 typedef struct {
804     /** The kind of primitive. Used for self-identifying the primitive
805      * descriptor. Must be #mkldnn_convolution. */
806     mkldnn_primitive_kind_t primitive_kind;
807     /** The kind of propagation. Possible values: #mkldnn_forward_training,
808      * #mkldnn_forward_inference, and #mkldnn_backward_data. */
809     mkldnn_prop_kind_t prop_kind;
810     /** Source and destination memory descriptor,
811      *  and source and destination gradient memory descriptor. */
812     mkldnn_memory_desc_t data_desc;
813     /** axis for shuffling. */
814     int axis;
815     /** number of groups in group convolution */
816     int group_size;
817 } mkldnn_shuffle_desc_t;
818
819 /** A descriptor of a element-wise operation. */
820 typedef struct {
821     /** The kind of primitive. Used for self-identifying the primitive
822      * descriptor. Must be #mkldnn_eltwise. */
823     mkldnn_primitive_kind_t primitive_kind;
824     /** The kind of propagation. Possible values: #mkldnn_forward_training,
825      * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
826      */
827     mkldnn_prop_kind_t prop_kind;
828     /** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu,
829      * #mkldnn_eltwise_tanh, #mkldnn_eltwise_elu, #mkldnn_eltwise_square,
830      * #mkldnn_eltwise_abs, #mkldnn_eltwise_sqrt, #mkldnn_eltwise_linear,
831      * #mkldnn_eltwise_bounded_relu, #mkldnn_eltwise_soft_relu, and
832      * #mkldnn_eltwise_logistic. */
833     mkldnn_alg_kind_t alg_kind;
834     /** Source and destination memory descriptor. */
835     mkldnn_memory_desc_t data_desc;
836     /** Source and destination gradient memory descriptor. */
837     mkldnn_memory_desc_t diff_data_desc;
838     /** Algorithm specific parameter.
839      * Accordance table:
840      *  - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored
841      *  - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored
842      *  - #mkldnn_eltwise_elu: @p alpha -- negative slope, @p beta ignored
843      *  - #mkldnn_eltwise_square: @p alpha and @p beta ignored
844      *  - #mkldnn_eltwise_abs: @p alpha and @p beta ignored
845      *  - #mkldnn_eltwise_sqrt: @p alpha and @p beta ignored
846      *  - #mkldnn_eltwise_linear: @p alpha -- scale, @p beta -- shift
847      *  - #mkldnn_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored
848      *  - #mkldnn_eltwise_soft_relu: @p alpha and @p beta ignored
849      *  - #mkldnn_eltwise_logistic: @p alpha and @p beta ignored
850      */
851     float alpha, beta;
852 } mkldnn_eltwise_desc_t;
853
854 /** A descriptor of a channel-wise operation. */
855 typedef struct {
856     /** The kind of primitive. Used for self identifying the primitive
857      * descriptor. Must be #mkldnn_depthwise. */
858     mkldnn_primitive_kind_t primitive_kind;
859     /** The kind of propagation. Possible values: #mkldnn_forward_training,
860      * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
861      */
862     mkldnn_prop_kind_t prop_kind;
863     /** The kind of depthwise algorithm. Possible values: #mkldnn_depthwise_scale_shift
864      * #mkldnn_depthwise_prelu */
865     mkldnn_alg_kind_t alg_kind;
866         /** Source memory descriptor. */
867         mkldnn_memory_desc_t src_desc;
868         /** Destination memory descriptor. */
869         mkldnn_memory_desc_t dst_desc;
870     /** Weights memory descriptor. */
871     mkldnn_memory_desc_t weights_desc;
872     /** Bias memory descriptor. */
873     mkldnn_memory_desc_t bias_desc;
874 } mkldnn_depthwise_desc_t;
875
876 /** A descriptor of a Softmax operation. */
877 typedef struct {
878     /** The kind of primitive. Used for self-identifying the primitive
879     * descriptor. Must be #mkldnn_softmax. */
880     mkldnn_primitive_kind_t primitive_kind;
881     /** The kind of propagation. Possible values: #mkldnn_forward_training and
882      * #mkldnn_forward_inference. */
883     mkldnn_prop_kind_t prop_kind;
884     /** Source and destination memory descriptor. */
885     mkldnn_memory_desc_t data_desc;
886     /** Source and Destination of gradient memory descriptor. */
887     mkldnn_memory_desc_t diff_desc;
888     /** The axis along which to perform the softmax. */
889     int softmax_axis;
890 } mkldnn_softmax_desc_t;
891
892 /** A descriptor of a pooling operation. */
893 typedef struct {
894     /** The kind of primitive. Used for self-identifying the primitive
895      * descriptor. Must be #mkldnn_pooling. */
896     mkldnn_primitive_kind_t primitive_kind;
897     /** The kind of propagation. Possible values: #mkldnn_forward_training,
898      * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
899      */
900     mkldnn_prop_kind_t prop_kind;
901     /** The kind of pooling algorithm. Possible values: #mkldnn_pooling_max and
902      * #mkldnn_pooling_avg. */
903     mkldnn_alg_kind_t alg_kind;
904     /** Source memory descriptor. */
905     mkldnn_memory_desc_t src_desc;
906     /** Source gradient memory descriptor. */
907     mkldnn_memory_desc_t diff_src_desc;
908     /** Destination memory descriptor. */
909     mkldnn_memory_desc_t dst_desc;
910     /** Destination gradient memory descriptor. */
911     mkldnn_memory_desc_t diff_dst_desc;
912     /** Pooling kernel strides for spatial dimensions. */
913     mkldnn_dims_t strides;
914     /** Pooling kernel spatial dimensions. */
915     mkldnn_dims_t kernel;
916     /** Padding in each spatial dimension. padding[0] is a padding in the
917      * beginning (@p padding_l), padding[1] is a padding in the end (@p
918      * padding_r). */
919     mkldnn_dims_t padding[2];
920     /** The kind of padding to use. */
921     mkldnn_padding_kind_t padding_kind;
922     /** The accumulator data type. Initialized automatically. */
923     mkldnn_data_type_t accum_data_type;
924 } mkldnn_pooling_desc_t;
925
926 /** A descriptor of a Local Response Normalization (LRN) operation. */
927 typedef struct {
928     /** The kind of primitive. Used for self-identifying the primitive
929      * descriptor. Must be #mkldnn_lrn. */
930     mkldnn_primitive_kind_t primitive_kind;
931     /** The kind of propagation. Possible values: #mkldnn_forward_training,
932      * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
933      */
934     mkldnn_prop_kind_t prop_kind;
935     /** LRN algorithm. Possible values: #mkldnn_lrn_within_channel and
936      * #mkldnn_lrn_across_channels. */
937     mkldnn_alg_kind_t alg_kind;
938     /** Source and destination memory descriptor. */
939     mkldnn_memory_desc_t data_desc;
940     /** Source and destination gradient memory descriptor. */
941     mkldnn_memory_desc_t diff_data_desc;
942     /** The number of channels to sum over (for cross-channel LRN) or the side
943      * length of the square region to sum over (for within-channel LRN). */
944     int local_size;
945     /** LRN alpha parameter. */
946     float lrn_alpha;
947     /** LRN beta parameter. */
948     float lrn_beta;
949     /** LRN k parameter. */
950     float lrn_k;
951 } mkldnn_lrn_desc_t;
952
953 /** A descriptor of a Batch Normalization operation. */
954 typedef struct {
955     /** The kind of primitive. Used for self-identifying the primitive
956      * descriptor. Must be #mkldnn_batch_normalization. */
957     mkldnn_primitive_kind_t primitive_kind;
958     /** The kind of propagation. Possible values: #mkldnn_forward_training,
959      * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
960      */
961     mkldnn_prop_kind_t prop_kind;
962     /** Source and destination memory descriptor. */
963     mkldnn_memory_desc_t data_desc;
964     /** Source and destination gradient memory descriptor. */
965     mkldnn_memory_desc_t diff_data_desc;
966     /** Scale and shift data and gradient memory descriptors.
967      *
968      * Scaleshift memory descriptor uses 2D #mkldnn_nc format[2,Channels]. 1-st
969      * dimension contains gamma parameter, 2-nd dimension contains beta
970      * parameter. */
971     mkldnn_memory_desc_t data_scaleshift_desc;
972     mkldnn_memory_desc_t diff_data_scaleshift_desc;
973     /** Mean and variance data memory descriptors.
974      *
975      * Mean and variance memory descriptors use 1D #mkldnn_x format[Channels].
976      */
977     mkldnn_memory_desc_t mean_desc;
978     mkldnn_memory_desc_t variance_desc;
979     /** Batch normalization epsilon parameter. */
980     float batch_norm_epsilon;
981     unsigned flags;
982 } mkldnn_batch_normalization_desc_t;
983
984 /** A descriptor of an inner product operation. */
985 typedef struct {
986     /** The kind of primitive. Used for self-identifying the primitive
987      * descriptor. Must be #mkldnn_inner_product. */
988     mkldnn_primitive_kind_t primitive_kind;
989     /** The kind of propagation. Possible values: #mkldnn_forward_training,
990      * #mkldnn_forward_inference, #mkldnn_backward_data,
991      * #mkldnn_backward_weights, and #mkldnn_backward_bias. */
992     mkldnn_prop_kind_t prop_kind;
993     /** Source memory descriptor. */
994     mkldnn_memory_desc_t src_desc;
995     /** Source gradient memory descriptor. */
996     mkldnn_memory_desc_t diff_src_desc;
997     /** Weights memory descriptor. */
998     mkldnn_memory_desc_t weights_desc;
999     /** Weights gradient memory descriptor. */
1000     mkldnn_memory_desc_t diff_weights_desc;
1001     /** Bias memory descriptor. */
1002     mkldnn_memory_desc_t bias_desc;
1003     /** Bias gradient memory descriptor. */
1004     mkldnn_memory_desc_t diff_bias_desc;
1005     /** Destination memory descriptor. */
1006     mkldnn_memory_desc_t dst_desc;
1007     /** Destination gradient memory descriptor. */
1008     mkldnn_memory_desc_t diff_dst_desc;
1009     /** The accumulator data type. Initialized automatically. */
1010     mkldnn_data_type_t accum_data_type;
1011 } mkldnn_inner_product_desc_t;
1012
1013 /** Flags for RNN cell. */
1014 typedef enum {
1015     mkldnn_rnn_cell_with_relu = 0x1U,
1016     mkldnn_rnn_cell_with_clipping = 0x2U,
1017 } mkldnn_rnn_cell_flags_t;
1018
1019 typedef struct {
1020     /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn,
1021      * #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru,
1022      * or #mkldnn_gru_linear_before_reset. */
1023     mkldnn_alg_kind_t cell_kind;
1024     /** Activation function used. Must be either #mkldnn_eltwise_relu or
1025      * #mkldnn_eltwise_tanh. */
1026     mkldnn_alg_kind_t activation_kind;
1027     /** RNN cell flags */
1028     unsigned int flags;
1029     /** @c alpha is a negative slope parameter (used only if
1030      * `(flags & #mkldnn_rnn_cell_with_relu) != 0`) */
1031     float alpha;
1032     /** clipping parameter (used only if
1033      * `(flags & #mkldnn_rnn_cell_with_clipping) != 0`) */
1034     float clipping;
1035 } mkldnn_rnn_cell_desc_t;
1036
1037 /** A direction of RNN primitive execution. */
1038 typedef enum {
1039     /* Unidirectional execution of RNN primitive from left to right. */
1040     mkldnn_unidirectional_left2right,
1041     /* Unidirectional execution of RNN primitive from right to left. */
1042     mkldnn_unidirectional_right2left,
1043     /* Bidirectional execution of RNN primitive with concatenation of the
1044      * results. */
1045     mkldnn_bidirectional_concat,
1046     /* Bidirectional execution of RNN primitive with summation of the
1047      * results. */
1048     mkldnn_bidirectional_sum,
1049     mkldnn_unidirectional = mkldnn_unidirectional_left2right,
1050 } mkldnn_rnn_direction_t;
1051
1052 /** A descriptor for an RNN operation. */
1053 typedef struct {
1054     /** The kind of primitive. Used for self-identifying the primitive
1055      * descriptor. Must be #mkldnn_rnn. */
1056     mkldnn_primitive_kind_t primitive_kind;
1057     /** The kind of propagation. Possible values: #mkldnn_forward_training,
1058      * #mkldnn_forward_inference, and #mkldnn_backward. */
1059     mkldnn_prop_kind_t prop_kind;
1060     /** The RNN cell desc. */
1061     mkldnn_rnn_cell_desc_t cell_desc;
1062     /** The direction of RNN primitive execution. */
1063     mkldnn_rnn_direction_t direction;
1064     /** Source layer memory descriptor. */
1065     mkldnn_memory_desc_t src_layer_desc;
1066     /** Source iteration memory descriptor. */
1067     mkldnn_memory_desc_t src_iter_desc;
1068     /** Weights layer memory descriptor. */
1069     mkldnn_memory_desc_t weights_layer_desc;
1070     /** Weights iteration memory descriptor. */
1071     mkldnn_memory_desc_t weights_iter_desc;
1072     /** Bias memory descriptor. */
1073     mkldnn_memory_desc_t bias_desc;
1074     /** Destination layer memory descriptor. */
1075     mkldnn_memory_desc_t dst_layer_desc;
1076     /** Destination iter memory descriptor. */
1077     mkldnn_memory_desc_t dst_iter_desc;
1078     /** Source gradient layer memory descriptor. */
1079     mkldnn_memory_desc_t diff_src_layer_desc;
1080     /** Source gradient iter memory descriptor. */
1081     mkldnn_memory_desc_t diff_src_iter_desc;
1082     /** Weights gradient layer memory descriptor. */
1083     mkldnn_memory_desc_t diff_weights_layer_desc;
1084     /** Weights gradient iter memory descriptor. */
1085     mkldnn_memory_desc_t diff_weights_iter_desc;
1086     /** Bias gradient memory descriptor. */
1087     mkldnn_memory_desc_t diff_bias_desc;
1088     /** Destination gradient layer memory descriptor. */
1089     mkldnn_memory_desc_t diff_dst_layer_desc;
1090     /** Destination gradient iteration memory descriptor. */
1091     mkldnn_memory_desc_t diff_dst_iter_desc;
1092 } mkldnn_rnn_desc_t;
1093
1094 /** A descriptor of a ROI Pooling operation. */
1095 typedef struct {
1096     /** The kind of primitive. Used for self identifying the primitive
1097      * descriptor. Must be #mkldnn_roi_pooling. */
1098     mkldnn_primitive_kind_t primitive_kind;
1099     /** The kind of propagation. Possible values: #mkldnn_forward. */
1100     mkldnn_prop_kind_t prop_kind;
1101     /** Source memory descriptor. */
1102     mkldnn_memory_desc_t* src_desc;
1103     /** Destination memory descriptor. */
1104     mkldnn_memory_desc_t dst_desc;
1105
1106     /** Primitive parameters. */
1107     int pooled_h;
1108     int pooled_w;
1109     double spatial_scale;
1110     int num_src;
1111     mkldnn_alg_kind_t alg_kind;
1112 } mkldnn_roi_pooling_desc_t;
1113
1114 /** A descriptor of a binary convolution operation. */
1115 typedef struct {
1116     /** The kind of primitive. Used for self identifying the primitive
1117      * descriptor. Must be #mkldnn_binary_convolution. */
1118     mkldnn_primitive_kind_t primitive_kind;
1119     /** The kind of propagation. Possible values: #mkldnn_forward_training,
1120      * #mkldnn_forward_inference */
1121     mkldnn_prop_kind_t prop_kind;
1122     /** The kind of the binary convolution algorithm. Possible values:
1123      * #mkldnn_binary_convolution_direct. */
1124     mkldnn_alg_kind_t alg_kind;
1125     /** Source memory descriptor. */
1126     mkldnn_memory_desc_t src_desc;
1127     /** Weights memory descriptor. */
1128     mkldnn_memory_desc_t weights_desc;
1129     /** Destination memory descriptor. */
1130     mkldnn_memory_desc_t dst_desc;
1131     /** Convolution strides in each spatial dimension. */
1132     mkldnn_dims_t strides;
1133     /** Convolution dilates in each spatial dimension. */
1134     mkldnn_dims_t dilates;
1135     /** Padding in each spatial dimension. padding[0] is a padding in the
1136      * beginning (@p padding_l), padding[1] is a padding in the end (@p
1137      * padding_r). */
1138     mkldnn_dims_t padding[2];
1139     /** The accumulator data type. Initialized automatically. */
1140     mkldnn_data_type_t accum_data_type;
1141     /** Logic value of elements in padding area */
1142     float pad_value;
1143 } mkldnn_binary_convolution_desc_t;
1144
1145 /** A descriptor of a binarization operation. */
1146 typedef struct {
1147     /** The kind of primitive. Used for self identifying the primitive
1148      * descriptor. Must be #mkldnn_binarization. */
1149     mkldnn_primitive_kind_t primitive_kind;
1150     /** The kind of propagation. Possible values: #mkldnn_forward_training,
1151      * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
1152      */
1153     mkldnn_prop_kind_t prop_kind;
1154     /** The kind of binarization algorithm. Possible values: #mkldnn_binarization_depthwise */
1155     mkldnn_alg_kind_t alg_kind;
1156     /** Source memory descriptor. */
1157     mkldnn_memory_desc_t src_desc;
1158     /** Destination memory descriptor. */
1159     mkldnn_memory_desc_t dst_desc;
1160     /** Weights memory descriptor. */
1161     mkldnn_memory_desc_t weights_desc;
1162 } mkldnn_binarization_desc_t;
1163
1164 /** @} */
1165
1166 /** @addtogroup c_api_engine_types Engine
1167  * @{ */
1168
1169 /** @brief Kinds of engines. */
1170 typedef enum {
1171     /** An unspecified engine. */
1172     mkldnn_any_engine,
1173     /** CPU engine. */
1174     mkldnn_cpu,
1175 } mkldnn_engine_kind_t;
1176
1177 /** @struct mkldnn_engine
1178  * @brief An opaque structure to describe an engine. */
1179 struct mkldnn_engine;
1180 /** @brief An engine handle. */
1181 typedef struct mkldnn_engine *mkldnn_engine_t;
1182 #if 0
1183 /* FIXME: looks like this never happens */
1184 /** @brief A constant engine handle. */
1185 typedef const struct mkldnn_engine *const_mkldnn_engine_t;
1186 #endif
1187
1188 /** @} */
1189
1190 /** @addtogroup c_api_primitive_desc_iterators Primitive descriptor iterators
1191  * @{ */
1192
1193 /** @struct mkldnn_primitive_desc_iterator
1194  * @brief An opaque structure to describe a primitive descriptor iterator. */
1195 struct mkldnn_primitive_desc_iterator;
1196
1197 /** @brief A primitive descriptor iterator handle. */
1198 typedef struct mkldnn_primitive_desc_iterator
1199     *mkldnn_primitive_desc_iterator_t;
1200
1201 /** @brief A constant primitive descriptor iterator handle. */
1202 typedef const struct mkldnn_primitive_desc_iterator
1203     *const_mkldnn_primitive_desc_iterator_t;
1204
1205 /** @} */
1206
1207 /** @addtogroup c_api_primitive_descs Primitive descriptors
1208  * @{ */
1209
1210 /** @struct mkldnn_primitive_desc
1211  * @brief An opaque structure to describe a primitive descriptor. */
1212 struct mkldnn_primitive_desc;
1213
1214 /** @brief A primitive descriptor handle. */
1215 typedef struct mkldnn_primitive_desc *mkldnn_primitive_desc_t;
1216
1217 /** @brief A constant primitive descriptor handle. */
1218 typedef const struct mkldnn_primitive_desc *const_mkldnn_primitive_desc_t;
1219
1220 /** @} */
1221
1222 /** @addtogroup c_api_primitive_attr Primitive descriptor attributes
1223  * @{ */
1224
1225 /** @struct mkldnn_primitive_attr
1226  * @brief An opaque structure for primitive descriptor attributes.
1227  *
1228  * Attributes may contain:
1229  *  - rounding mode for integer based primitives (like convolution, reorders)
1230  *  - output scales (to scale the result prior to storing it to the memory)
1231  */
1232 struct mkldnn_primitive_attr;
1233
1234 /** @brief A primitive descriptor attributes handle that controls primitive
1235  * behavior. */
1236 typedef struct mkldnn_primitive_attr *mkldnn_primitive_attr_t;
1237
1238 /** @brief A constant primitive descriptor attributes handle. */
1239 typedef const struct mkldnn_primitive_attr *const_mkldnn_primitive_attr_t;
1240
1241 /** @struct mkldnn_post_ops
1242  * @brief An opaque structure for a chain of post operations.
1243  *
1244  * mkldnn_post_ops can be used to perform some (trivial) operations like
1245  * accumulation or eltwise after certain primitives like convolution.
1246  *
1247  * Post operations might be combined together, making a chain of post
1248  * operations. For instance one can configure convolution followed by
1249  * accumulation followed by eltwise. This might be especially beneficial
1250  * for residual learning blocks.
1251  *
1252  * @warning
1253  *      Of course not all combinations are supported, so the user should handle
1254  *      errors accordingly.
1255  *
1256  * Supported post operations:
1257  *  - accumulation (base primitive: convolution)
1258  *  - eltwise (base primitive: convolution)
1259  */
1260 struct mkldnn_post_ops;
1261
1262 /** @brief A post operation chain handle. */
1263 typedef struct mkldnn_post_ops *mkldnn_post_ops_t;
1264
1265 /** @brief A constant post operation chain handle. */
1266 typedef const struct mkldnn_post_ops *const_mkldnn_post_ops_t;
1267
1268 /** @} */
1269
1270 /** @addtogroup c_api_types_primitive Primitive
1271  * @{ */
1272
1273 /** @struct mkldnn_primitive
1274  * An opaque structure to describe a primitive. */
1275 struct mkldnn_primitive;
1276 /** A primitive handle. */
1277 typedef struct mkldnn_primitive *mkldnn_primitive_t;
1278 /** A constant primitive handle. */
1279 typedef const struct mkldnn_primitive *const_mkldnn_primitive_t;
1280
1281 /** A wrapper structure to specify a particular output of a primitive. */
1282 typedef struct {
1283     /** Primitive to specify the output for. */
1284     const_mkldnn_primitive_t primitive;
1285     /** Desired output index. */
1286     size_t output_index;
1287 } mkldnn_primitive_at_t;
1288
1289 /** @} */
1290
1291 /** @addtogroup c_api_types_query Queries
1292  * @{ */
1293
1294 /** Primitive descriptor query specification
1295  *
1296  * For generic function mkldnn_primitive_desc_query(), the type of result must
1297  * agree with the queried argument. The correspondence table:
1298  *      Query                        | type of result
1299  *      --------------------------------------------------------------
1300  *      #mkldnn_query_engine         | mkldnn_engine_t *
1301  *      #mkldnn_query_primitive_kind | mkldnn_primitive_kind_t *
1302  *      *_s32                        | int *
1303  *      *_s64                        | ptrdiff_t *
1304  *      *_f64                        | double *
1305  *      *_str                        | const char **
1306  *      #mkldnn_query_op_d           | const_mkldnn_op_desc_t *
1307  *      *_md                         | const mkldnn_memory_desc_t **
1308  *      *_${op}_d                    | const mkldnn_${op}_desc_t **
1309  *      *_pd                         | const_mkldnn_primitive_desc_t *
1310  *
1311  * @note
1312  *     Rule of thumb: all opaque types and structures are returned by
1313  *     reference. All numbers are returned by value.
1314  *
1315  * @warning
1316  *     All returned references point to constant objects and are valid only
1317  *     during the lifetime of the queried primitive descriptor. Returned objects
1318  *     must not be destroyed by the user. If you need to keep the object longer
1319  *     than the lifetime of the queried primitive descriptor, use
1320  *     mkldnn_primitive_desc_clone() to make a copy. */
1321 typedef enum {
1322     mkldnn_query_undef = 0,  /**< no query */
1323
1324     mkldnn_query_engine, /**< execution engine */
1325     mkldnn_query_primitive_kind, /**< primitive kind */
1326
1327     mkldnn_query_num_of_inputs_s32, /**< number of inputs expected */
1328     mkldnn_query_num_of_outputs_s32, /**< number of outputs expected */
1329
1330     mkldnn_query_time_estimate_f64, /**< runtime estimation (seconds) */
1331     mkldnn_query_memory_consumption_s64, /**< memory consumption -- extra
1332                                            (scratch) memory, additional to all
1333                                            inputs and outputs memory (bytes) */
1334
1335     mkldnn_query_impl_info_str, /**< implementation name */
1336
1337     /* memory and op descriptor section */
1338     mkldnn_query_some_d = 64, /**< stub */
1339     mkldnn_query_op_d, /**< op descriptor */
1340     mkldnn_query_memory_d, /**< memory descriptor for memory and view */
1341     mkldnn_query_convolution_d, /**< convolution descriptor */
1342     mkldnn_query_deconvolution_d, /**< deconvolution descriptor */
1343     mkldnn_query_shuffle_d, /**< shuffle descriptor */
1344     mkldnn_query_eltwise_d, /**< eltwise descriptor */
1345     mkldnn_query_softmax_d, /**< softmax descriptor */
1346     mkldnn_query_pooling_d, /**< pooling descriptor */
1347     mkldnn_query_lrn_d, /**< lrn descriptor */
1348     mkldnn_query_batch_normalization_d, /**< batch normalization descriptor */
1349     mkldnn_query_inner_product_d, /**< inner product descriptor */
1350     mkldnn_query_rnn_d, /**< rnn descriptor */
1351     mkldnn_query_roi_pooling_d, /**< roi descriptor */
1352     mkldnn_query_depthwise_d, /**< eltwise descriptor */
1353     mkldnn_query_binary_convolution_d, /**< binary convolution descriptor */
1354     mkldnn_query_binarization_d, /**< binarization descriptor */
1355
1356     /* (memory) primitive descriptor section */
1357     mkldnn_query_some_pd = 128, /**< stub */
1358     mkldnn_query_input_pd, /**< input memory primitive desc */
1359     mkldnn_query_output_pd, /**< output memory primitive desc */
1360     mkldnn_query_src_pd, /**< source memory primitive desc */
1361     mkldnn_query_diff_src_pd, /**< source gradient memory primitive desc */
1362     mkldnn_query_weights_pd, /**< weights memory primitive descriptor desc */
1363     mkldnn_query_diff_weights_pd, /**< weights grad. memory primitive desc */
1364     mkldnn_query_dst_pd, /**< destination memory primitive desc */
1365     mkldnn_query_diff_dst_pd, /**< destination grad. memory primitive desc */
1366     mkldnn_query_workspace_pd, /**< workspace memory primitive desc */
1367 } mkldnn_query_t;
1368
1369 /** @} */
1370
1371 /** @addtogroup c_api_types_stream Execution stream
1372  * @{ */
1373
1374 /** @brief Kinds of streams. */
1375 typedef enum {
1376     /** An unspecified engine. */
1377     mkldnn_any_stream,
1378     /** Eager stream. */
1379     mkldnn_eager,
1380     /** Lazy stream. */
1381     mkldnn_lazy,
1382 } mkldnn_stream_kind_t;
1383
1384 /** @struct mkldnn_stream
1385  * An opaque structure to describe an execution stream. */
1386 struct mkldnn_stream;
1387 /** An execution stream handle. */
1388 typedef struct mkldnn_stream *mkldnn_stream_t;
1389 /** A constant execution stream handle. */
1390 typedef const struct mkldnn_stream *const_mkldnn_stream_t;
1391
1392 /** @} */
1393 /** @} */
1394 /** @} */
1395
1396 #ifdef __cplusplus
1397 }
1398 #endif
1399
1400
1401 #endif