1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
22 /* All symbols shall be internal unless marked as MKLDNN_API */
23 #if defined _WIN32 || defined __CYGWIN__
24 # define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport)
25 # define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport)
28 # define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default")))
29 # define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default")))
31 # define MKLDNN_HELPER_DLL_IMPORT
32 # define MKLDNN_HELPER_DLL_EXPORT
37 # ifdef MKLDNN_DLL_EXPORTS
38 # define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT
40 # define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT
46 #if defined (__GNUC__)
47 # define MKLDNN_DEPRECATED __attribute__((deprecated))
48 #elif defined(_MSC_VER)
49 # define MKLDNN_DEPRECATED __declspec(deprecated)
51 # define MKLDNN_DEPRECATED
54 #include "mkldnn_types.h"
55 #include "mkldnn_version.h"
56 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
/** @addtogroup c_api C API
 * @{ */

/** @addtogroup c_api_primitive Primitive operations
 * @{ */

/** @addtogroup c_api_primitive_common Common primitive operations
 * @{ */
/** Creates a primitive descriptor @p iterator for given @p op_desc, @p engine,
 * and optionally a hint primitive descriptor from forward propagation
 * (required for backward propagation). Pass @c NULL for forward propagation.
 */
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create(
        mkldnn_primitive_desc_iterator_t *iterator,
        const_mkldnn_op_desc_t op_desc, mkldnn_engine_t engine,
        const_mkldnn_primitive_desc_t hint_forward_primitive_desc);

/** Creates a primitive descriptor @p iterator for given @p op_desc, @p attr,
 * @p engine, and optionally a hint primitive descriptor from forward
 * propagation (required for backward propagation). Pass @c NULL for forward
 * propagation. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create_v2(
        mkldnn_primitive_desc_iterator_t *iterator,
        const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr,
        mkldnn_engine_t engine,
        const_mkldnn_primitive_desc_t hint_forward_primitive_desc);
/** Iterates over primitive descriptors. Returns #mkldnn_iterator_ends if no
 * more primitive descriptors are available. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next(
        mkldnn_primitive_desc_iterator_t iterator);

/** Fetches the current primitive descriptor.
 *
 * @note
 *     The user should delete the fetched primitive descriptor using
 *     mkldnn_primitive_desc_destroy() once it is no longer needed. */
mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch(
        const_mkldnn_primitive_desc_iterator_t iterator);

/** Deletes a primitive descriptor @p iterator */
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy(
        mkldnn_primitive_desc_iterator_t iterator);
/** Creates a @p primitive_desc using @p op_desc, @p engine, and optionally a
 * hint primitive descriptor from forward propagation. The call is equivalent
 * to creating a primitive descriptor iterator, immediately fetching a
 * primitive descriptor, and then destroying the iterator. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create(
        mkldnn_primitive_desc_t *primitive_desc,
        const_mkldnn_op_desc_t op_desc, mkldnn_engine_t engine,
        const_mkldnn_primitive_desc_t hint_forward_primitive_desc);

/** Creates a @p primitive_desc using @p op_desc, @p attr, @p engine, and
 * optionally a hint primitive descriptor from forward propagation. The call is
 * equivalent to creating a primitive descriptor iterator, immediately fetching
 * a primitive descriptor, and then destroying the iterator. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create_v2(
        mkldnn_primitive_desc_t *primitive_desc,
        const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr,
        mkldnn_engine_t engine,
        const_mkldnn_primitive_desc_t hint_forward_primitive_desc);

/** Makes a copy of a @p primitive_desc. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(
        mkldnn_primitive_desc_t *primitive_desc,
        const_mkldnn_primitive_desc_t existing_primitive_desc);

/** Returns a constant reference to the attribute of a @p primitive_desc.
 *
 * @warning
 *      The user should not destroy the obtained @p attr.
 *
 * @warning
 *      The lifetime of an @p attr is the same as that of a @p primitive_desc,
 *      so it is illegal to use the @p attr once @p primitive_desc has been
 *      destroyed. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr(
        const_mkldnn_primitive_desc_t primitive_desc,
        const_mkldnn_primitive_attr_t *attr);

/** Deletes a @p primitive_desc. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(
        mkldnn_primitive_desc_t primitive_desc);
/** Queries primitive descriptor
 *
 * One of the most typical use cases is to query a convolution primitive
 * descriptor created with source, weights, and destination formats equal
 * to #mkldnn_any about the corresponding memory primitive descriptors
 * (@p what equals #mkldnn_query_src_pd, #mkldnn_query_weights_pd, and
 * #mkldnn_query_dst_pd respectively) to be able to prepare memory and
 * create reorders if required.
 *
 * Another quite typical use case is to query an operation primitive
 * descriptor for a workspace (@p what equals #mkldnn_query_workspace_pd).
 * The returned status #mkldnn_not_required indicates that a workspace is
 * not required.
 *
 * A few other possibilities:
 *  - query a memory primitive descriptor for the underlying memory
 *    descriptor (#mkldnn_query_memory_d)
 *  - query an operation primitive descriptor for the underlying operation
 *    descriptor (#mkldnn_query_convolution_d, #mkldnn_query_eltwise_d,
 *    #mkldnn_query_rnn_d, etc.)
 *  - query an operation primitive descriptor for the implementation
 *    information string (#mkldnn_query_impl_info_str)
 *  - query an operation primitive descriptor for the number of inputs and
 *    outputs (#mkldnn_query_num_of_inputs_s32 and
 *    #mkldnn_query_num_of_outputs_s32 respectively)
 *
 * @sa mkldnn_query_t for more options
 */
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(
        const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what,
        int index, void *result);
/** Queries primitive descriptor for memory descriptor
 *
 * @returns NULL in case of any error (in particular if the queried entity is
 *   not of type mkldnn_memory_desc_t).
 *
 * This is just a specialized version of mkldnn_primitive_desc_query
 * used for convenience. */
const mkldnn_memory_desc_t MKLDNN_API *mkldnn_primitive_desc_query_memory_d(
        const_mkldnn_primitive_desc_t primitive_desc);
192 /** Queries primitive descriptor for primitive descriptor
194 * @returns NULL in case of any error (in particular if the queried entity is
195 * not of type const_mkldnn_primitive_desc_t).
197 * This is just a specialized version of mkldnn_primitive_desc_query
198 * used for convenience.
200 * Example: Query an operation primitive descriptor for a workspace
201 * (@p what equals #mkldnn_query_workspace_pd). Returned
202 * NULL indicates that the primitive does not require a workspace.
203 * Otherwise, a user should prepare the workspace and pass it
204 * to the corresponding primitive.
206 const_mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_query_pd(
207 const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what,
210 /** Queries primitive descriptor for signed 32bit int
212 * @returns 0 in case of any error (in particular if the queried entity is
213 * not of type int32_t). Note that 0 might also be the actual returned
216 * This is just a specialized version of mkldnn_primitive_desc_query
217 * used for convenience.
219 int MKLDNN_API mkldnn_primitive_desc_query_s32(
220 const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what,
/** Creates a @p primitive using a @p primitive_desc descriptor and arrays of
 * @p inputs and @p outputs. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_create(
        mkldnn_primitive_t *primitive,
        const_mkldnn_primitive_desc_t primitive_desc,
        const mkldnn_primitive_at_t *inputs,
        const_mkldnn_primitive_t *outputs);

/** Retrieves a reference to the @p primitive_desc descriptor of given @p
 * primitive.
 *
 * @warning
 *     The returned object must not be destroyed by the user. The @c const
 *     qualifier of the returned object prevents such attempts. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(
        const_mkldnn_primitive_t primitive,
        const_mkldnn_primitive_desc_t *primitive_desc);

/** For a @p primitive, returns @p input at the @p index position. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_input_at(
        const_mkldnn_primitive_t primitive, size_t index,
        mkldnn_primitive_at_t *input);

/** For a @p primitive, returns @p output at the @p index position. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_output(
        const_mkldnn_primitive_t primitive, size_t index,
        const_mkldnn_primitive_t *output);

/** Deletes a @p primitive. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(
        mkldnn_primitive_t primitive);

/** Creates an #mkldnn_primitive_at_t structure from a @p primitive and @p
 * output_index. This function only fills in the data structure
 * and does not check whether arguments are correct. The actual error checking
 * is done when the resulting #mkldnn_primitive_at structure is passed to a
 * primitive creation function. */
mkldnn_primitive_at_t MKLDNN_API mkldnn_primitive_at(
        const_mkldnn_primitive_t primitive, size_t output_index);
/** @addtogroup c_api_attributes Attributes
 * An extension for controlling primitive behavior.
 * @{ */

/** Creates an empty (default) @p attr attribute. All the parameters are set to
 * default values.
 *
 * An empty attribute is used in primitive descriptor creation whenever it
 * is not passed explicitly, e.g. in mkldnn_primitive_desc_create. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(
        mkldnn_primitive_attr_t *attr);

/** Makes a copy of an @p existing_attr. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone(
        mkldnn_primitive_attr_t *attr,
        const_mkldnn_primitive_attr_t existing_attr);

/** Deletes an @p attr. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(
        mkldnn_primitive_attr_t attr);

/** Returns integer output rounding mode @p round_mode for a given @p attr,
 * previously set by mkldnn_primitive_attr_set_int_output_round_mode. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_int_output_round_mode(
        const_mkldnn_primitive_attr_t attr, mkldnn_round_mode_t *round_mode);

/** Sets output rounding mode @p round_mode for integer operations for a given
 * @p attr.
 *
 * The default value is #mkldnn_round_nearest. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_int_output_round_mode(
        mkldnn_primitive_attr_t attr, mkldnn_round_mode_t round_mode);
/** Returns @p count, correspondence scale @p mask, and a pointer to a constant
 * floating point array of output @p scales for given @p attr, previously set
 * by mkldnn_primitive_attr_set_output_scales.
 *
 * @warning
 *      The @p scales array points to the internal @p attr field, so the user
 *      should not modify or destroy @p scales.
 *
 * @warning
 *      The lifetime of @p scales is the same as that of the @p attr to which it
 *      belongs, so it is illegal to use @p scales after @p attr is destroyed.
 */
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(
        const_mkldnn_primitive_attr_t attr, int *count, int *mask,
        const float **scales);

/** Sets output @p scales for primitive operations. The number of elements @p
 * count and correspondence scale @p mask are stored for future use.
 *
 * The @p mask argument defines the correspondence between the output tensor
 * dimensions and the @p scales array. Set the i-th bit of @p mask to 1 to use a
 * dedicated scaling factor for each slice of the output tensor over the i-th
 * dimension. Set @p mask to 0 to use a common scaling factor for the whole
 * output tensor.
 *
 * @note
 *      The dimension order is always native and does not depend on the actual
 *      layout used. Examples:
 *       - 2D dimensional data the order of dimensions is always: (n, c)
 *       - 4D dimensional data the order is always: (n, c, h, w)
 *       - 5D dimensional weights the order is always: (g, oc, ic, kh, kw)
 *
 * Example usage:
 * @code
 *      int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params
 *      float scales[oc] = { ... }; // unique output scales per output channel
 *      int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
 *
 *      mkldnn_convolution_desc_t cd; // create & configure convolution op_desc
 *
 *      mkldnn_primitive_attr_t attr;
 *      mkldnn_primitive_attr_create(&attr); // create default attributes
 *      mkldnn_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales);
 *
 *      mkldnn_primitive_desc_t cpd;
 *      mkldnn_primitive_desc_create_v2(&cpd, &cd, attr, NULL);
 * @endcode
 *
 * @note
 *      There is no way to check that @p count corresponds to @p mask until an
 *      actual primitive descriptor is created, so it is the user's
 *      responsibility to set proper values. The following formula must hold:
 *
 *      \f[count = \prod\limits_{d \in mask} output.dims[d]\f]
 */
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(
        mkldnn_primitive_attr_t attr, int count, int mask,
        const float *scales);
/** Returns @p post_ops for given @p attr.
 *
 * @warning
 *      @p post_ops points to the internal @p attr field, so the user should not
 *      modify or destroy @p post_ops. Also, the lifetime of @p post_ops is the
 *      same as that of the @p attr it belongs to, so it is illegal to use @p
 *      post_ops after @p attr has been destroyed. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(
        const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops);

/** Sets configured @p post_ops to an attribute @p attr for future use (when
 * primitive descriptor is being created).
 *
 * @note
 *      At this point in time, there is no way to check whether the primitive
 *      descriptor does or does not support a given sequence of post operations.
 *      Therefore the user should handle an error that might occur at the
 *      mkldnn_primitive_desc_create call. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(
        mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops);

/** @addtogroup c_api_attributes_post_ops Sequence of post operations
 * An extension for performing extra operations after a base operation.
 * @{ */

/** Creates an empty sequence of post operations @p post_ops. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops);

/** Deletes a @p post_ops sequence. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops);

/** Returns the @p length of post operations for given @p post_ops. */
int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops);

/** Returns the type of post operation with index @p index in given
 * @p post_ops. In case of error, returns #mkldnn_undefined_primitive. */
mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(
        const_mkldnn_post_ops_t post_ops, int index);
/** Appends accumulation (sum) post operation to the @p post_ops. Prior to
 * accumulating the result, the previous value would be multiplied by @p scale.
 *
 * The kind of this post operation is #mkldnn_sum.
 *
 * This feature might improve performance for cases like residual learning
 * blocks, where the result of convolution is accumulated to the previously
 * computed activations. The parameter @p scale might be extreme for the
 * integer-based computations when the result and previous activations have
 * different logical scaling factors.
 *
 * In the simplest case when the accumulation is the only post operation, the
 * computations would be:
 * dst[] <- scale * dst[] + op(...) // instead of dst[] <- op(...)
 *
 * @note
 *      This post operation (as well as all the others) disregards the original
 *      layout of the destination; that is, the layout of the original
 *      destination is expected to be the same as the layout of the stored
 *      destination. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(
        mkldnn_post_ops_t post_ops, float scale);

/** Gets the parameters of the accumulation (sum) post operation with index
 * @p index in the sequence of @p post_ops.
 *
 * @note
 *      If index @p index would not correspond to the accumulation post
 *      operation, the function returns #mkldnn_invalid_arguments. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(
        const_mkldnn_post_ops_t post_ops, int index, float *scale);

/** Appends eltwise post operation to the @p post_ops with given parameters
 * @p kind, @p alpha, and @p beta (@sa mkldnn_eltwise_forward_desc_init and
 * mkldnn_eltwise_desc_t).
 *
 * The kind of this post operation is #mkldnn_eltwise.
 *
 * In the simplest case when the eltwise is the only post operation, the
 * computations would be:
 * dst[] <- scale * eltwise_op ( op(...) ) // instead of dst[] <- op(...)
 * where eltwise_op is configured with the given parameters. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(
        mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg,
        float alpha, float beta);

/** Gets the eltwise parameters of the post operation with index @p index in
 * the sequence of @p post_ops. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(
        const_mkldnn_post_ops_t post_ops, int index, float *scale,
        mkldnn_alg_kind_t *alg, float *alpha, float *beta);
/** Appends depthwise post operation to the @p post_ops with given parameters
 * @p kind, @p weights and @p bias (@sa mkldnn_depthwise_forward_desc_init and
 * mkldnn_depthwise_desc_t).
 *
 * The kind of this post operation is #mkldnn_depthwise.
 *
 * In the simplest case when the depthwise is the only post operation, the
 * computations would be:
 * dst[] <- scale * depthwise_op ( op(...) ) // instead of dst[] <- op(...)
 * where depthwise_op is configured with given parameters. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_depthwise(
        mkldnn_post_ops_t post_ops, mkldnn_alg_kind_t alg,
        const float* weights_data, const float* biases_data);

/** Gets the depthwise parameters of the post operation with index @p index in
 * the sequence of @p post_ops. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_depthwise(
        const_mkldnn_post_ops_t post_ops, int index,
        mkldnn_alg_kind_t *alg, const float** weights_data,
        const float** biases_data);

/** Appends DW convolution post operation to the @p post_ops with given parameters
 * @p weights and @p bias.
 *
 * The kind of this post operation is #mkldnn_convolution. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_dw_conv(
        mkldnn_post_ops_t post_ops, int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
        const float* weights_data, const float* biases_data);

/** Gets the DW convolution parameters of the post operation with index @p index in
 * the sequence of @p post_ops. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_dw_conv(
        const_mkldnn_post_ops_t post_ops, int index, int* in_h, int* in_w,
        int* ker_h, int* ker_w, int* str_h, int* str_w, const float** weights_data,
        const float** biases_data);

/** Appends binarization post operation to the @p post_ops with given parameters
 * @p kind and @p weights (@sa mkldnn_binarization_forward_desc_init and
 * mkldnn_binarization_desc_t).
 *
 * The kind of this post operation is #mkldnn_binarization.
 *
 * In the simplest case when the binarization is the only post operation, the
 * computations would be:
 * dst[] <- binarization_op ( op(...) ) // instead of dst[] <- op(...)
 * where binarization_op is configured with given parameters. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_binarization(
        mkldnn_post_ops_t post_ops, mkldnn_alg_kind_t alg, const float* weights_data);

/** Gets the binarization parameters of the post operation with index @p index in
 * the sequence of @p post_ops. */
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_binarization(
        const_mkldnn_post_ops_t post_ops, int index,
        mkldnn_alg_kind_t *alg, const float** weights_data);
521 /** @addtogroup c_api_memory Memory
522 * A primitive to describe and store data.
524 * The library supports various data types and formats. Memory hierarchy
525 * consists of three levels of abstraction:
526 * 1. **Memory descriptor** -- engine agnostic logical description of data
527 * (number of dimensions, dimensions themselves, and data type), and
528 * optionally the format/layout that describes the physical representation
529 * of data in memory. If the format is not known yet, one can pass
530 * #mkldnn_any. This approach is used to allow compute-intensive
531 * primitives to specify the most appropriate format on their own with
532 * users required to reorder the data if the incoming format doesn't match
533 * the primitive's selection. Memory descriptor can be created with the
534 * mkldnn_memory_desc_init() function or by directly filling the
535 * mkldnn_memory_desc_t structure. The latter requires deep knowledge of
536 * how the physical data representation is mapped to the structure. The
537 * @ref understanding_memory_formats topic should shed some light on that.
538 * 2. **Memory primitive descriptor** -- logical description of data that is
539 * fully defined; that is, it cannot contain #mkldnn_any as a format. It
540 * also has the engine specified. A memory primitive descriptor is created
541 * by calling mkldnn_memory_primitive_desc_create() with two arguments: an
542 * mkldnn_memory_desc_t and an mkldnn_engine_t. It has the same type as
543 * other primitive descriptors and can be:
544 * - queried to return the underlying memory descriptor using
545 * mkldnn_primitive_desc_query() and
546 * mkldnn_primitive_desc_query_memory_d().
547 * - compared with another memory primitive descriptor using
548 * mkldnn_memory_primitive_desc_equal(). This is especially useful when
549 * checking whether a primitive requires reorder from the user's data
550 * format to the primitive's format.
551 * - queried to return the size of the data using
552 * mkldnn_memory_primitive_desc_get_size(). As described in
553 * @ref understanding_memory_formats, the size of data sometimes cannot
554 * be computed as the product of dimensions times the size of the data
555 * type. So users are encouraged to use this function for better code
557 * 3. **Memory primitive** or simply **memory** -- a pseudo-primitive that is
558 * defined by a memory primitive descriptor and a handle to the data
559 * itself. (In the case of CPU engine, the handle is simply a pointer to
560 * @c void.) The data handle can be queried using
561 * mkldnn_memory_get_data_handle() and set using
562 * mkldnn_memory_set_data_handle(). The latter function always sets the
563 * memory in the padding region to zero, which is the invariant maintained
564 * by all the primitives in Intel MKL-DNN. See
565 * @ref understanding_memory_formats for more details.
566 * A memory primitive can be created using mkldnn_primitive_create() with
567 * empty inputs and outputs. In this case, the memory primitive's data
568 * handle must be set manually using mkldnn_memory_set_data_handle().
570 * Along with ordinary memory with all dimensions being positive, Intel
571 * MKL-DNN supports *zero-volume* memory with one or more dimensions set to
572 * zero. This is to support the NumPy\* convention.
573 * If a *zero-volume* memory is passed to a primitive, the primitive does
574 * not perform any computations on this memory. For example:
575 * - Convolution with `(0 batch, 3 input channels, 13 height, 13 width)`
 *   source and `(16 output channels, 3 input channels, 3 height, 3 width)`
 *   weights would produce `(0 batch, 16 output channels, 11 height, 11 width)`
578 * destination (assuming strides are `1` and paddings are zero) and perform
579 * zero multiply-add operations.
580 * - Concatenation of three memories of shapes `(3, 4, 13, 13)`,
581 * `(3, 0, 13, 13)`, and `(3, 1, 13, 13)` along the second axis would produce
582 * the output of the shape `(3, 5, 13, 13)`, effectively ignoring the second
583 * input (however, if the user created a concatenation primitive descriptor
584 * with three inputs they should also provide all three memories to the
585 * concatenation primitive, including the one with zero second dimension).
586 * - However, Intel MKL-DNN would return an error when attempting to create a
587 * convolution with *zero-volume* memory passed for weights because such a
588 * convolution is not well-defined:
590 * dst(1, 16, 11, 11) <-- src(1, 0, 13, 13) (*) wei(16, 0, 3, 3)
592 * Should the values in the destination be zeroes or just not accessed at
593 * all? Moreover, backward pass w.r.t. weights in such cases is also not
596 * Data handle of *zero-volume* memory is never accessed and hence can be
597 * unset (NULL in case of CPU engine).
 * @sa @ref understanding_memory_formats
 * @{ */
/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p
 * data_type, and data @p format. @p format can be #mkldnn_any, which means
 * that specific data layouts are not permitted. */
mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init(
        mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims,
        mkldnn_data_type_t data_type, mkldnn_memory_format_t format);

/** Creates a @p memory_primitive_desc memory primitive descriptor using @p
 * memory_desc and @p engine. @p memory_desc cannot be uncertain; that is, it
 * cannot be initialized with #mkldnn_any. */
mkldnn_status_t MKLDNN_API mkldnn_memory_primitive_desc_create(
        mkldnn_primitive_desc_t *memory_primitive_desc,
        const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine);

/** Creates a @p view_primitive_desc for a given @p memory_primitive_desc, with
 * @p dims sizes and @p offsets offsets. May fail if the format used does not
 * allow obtaining the desired view. In this case, consider using the extract
 * primitive. */
mkldnn_status_t MKLDNN_API mkldnn_view_primitive_desc_create(
        mkldnn_primitive_desc_t *view_primitive_desc,
        const_mkldnn_primitive_desc_t memory_primitive_desc,
        const mkldnn_dims_t dims, const mkldnn_dims_t offsets);

/** Compares two descriptors of memory primitives.
 * @return 1 if the descriptors are the same.
 * @return 0 if the descriptors are different.
 *
 * Use this function to identify whether a reorder is required for the memory
 * primitives. @p lhs and @p rhs must be either memory or view primitive
 * descriptors. */
int MKLDNN_API mkldnn_memory_primitive_desc_equal(
        const_mkldnn_primitive_desc_t lhs,
        const_mkldnn_primitive_desc_t rhs);

/** Returns the size (in bytes) that is required for given @p
 * memory_primitive_desc */
size_t MKLDNN_API mkldnn_memory_primitive_desc_get_size(
        const_mkldnn_primitive_desc_t memory_primitive_desc);

/** For a @p memory primitive, returns the data @p handle. For the CPU engine,
 * the data handle is a pointer to the actual data. */
mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(
        const_mkldnn_primitive_t memory, void **handle);

/** For a @p memory primitive, sets the data @p handle. */
mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(
        mkldnn_primitive_t memory, void *handle);
/** @addtogroup c_api_reorder Reorder
 * A primitive to copy data between memory formats.
 * @{ */

/** Initializes a @p reorder_primitive_desc using descriptors of @p input and
 * @p output memory primitives.
 *
 * Inputs:
 *  - input (#mkldnn_query_input_pd, 0)
 *
 * Outputs:
 *  - output (#mkldnn_query_output_pd, 0)
 */
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(
        mkldnn_primitive_desc_t *reorder_primitive_desc,
        const_mkldnn_primitive_desc_t input,
        const_mkldnn_primitive_desc_t output);

/** Initializes a @p reorder_primitive_desc using an @p attr attribute and
 * descriptors of @p input and @p output memory primitives.
 *
 * Inputs:
 *  - input (#mkldnn_query_input_pd, 0)
 *
 * Outputs:
 *  - output (#mkldnn_query_output_pd, 0)
 */
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create_v2(
        mkldnn_primitive_desc_t *reorder_primitive_desc,
        const_mkldnn_primitive_desc_t input,
        const_mkldnn_primitive_desc_t output,
        const_mkldnn_primitive_attr_t attr);
/** @addtogroup c_api_concat Concat
 * A primitive to concatenate data by arbitrary dimension.
 * @{ */

/** Creates out-of-place @p concat_primitive_desc for concatenation of @p n
 * inputs by @p concat_dimension with resulting @p output_desc memory
 * descriptor. @p output_desc can be NULL or specified with the #mkldnn_any
 * format -- in this case, the appropriate memory format would be chosen
 * automatically.
 *
 * Inputs:
 *  - input 0 (#mkldnn_query_input_pd, 0)
 *  - input 1 (#mkldnn_query_input_pd, 1)
 *  - ...
 *  - input @p n - 1 (#mkldnn_query_input_pd, @p n - 1)
 *
 * Outputs:
 *  - output (#mkldnn_query_output_pd, 0)
 */
mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(
        mkldnn_primitive_desc_t *concat_primitive_desc,
        const mkldnn_memory_desc_t *output_desc, int n, int concat_dimension,
        const_mkldnn_primitive_desc_t *input_pds);

/** Creates in-place @p concat_primitive_desc for given @p n and @p inputs
 * memory primitive descriptors along @p concat_dimension. All inputs must have
 * the same memory format. Output memory format would be the same. Likewise, the
 * view_primitive_desc_create call may fail if the memory format of the inputs
 * does not allow in-place concatenation for the given sizes.
 *
 * @note This primitive is more like a synchronization stub for concatenation,
 *       because concat_inplace performs no operation during execution.
 *
 * @note Because no operation occurs, the user must ensure the input. */
mkldnn_status_t MKLDNN_API mkldnn_concat_inplace_by_input_primitive_desc_create(
        mkldnn_primitive_desc_t *concat_primitive_desc,
        int n, int concat_dimension, const_mkldnn_primitive_desc_t *inputs);
728 /** Creates in-place @p concat_primitive_desc for given @p output memory
729 * descriptor and @n inputs with @p sizes sizes along @p concat_dimension.
730 * Unlike out-of-place concatenation, @p output must be fully defined here.
731 * Likewise, the view_primitive_desc_create call may fail if the given memory
732 * format does not allow inplace concatenation for the given sizes.
734 * @note This primitive is more like a synchronization stub for concatenation,
735 * because concat_inplace performs no operation during execution. */
736 mkldnn_status_t MKLDNN_API mkldnn_concat_inplace_by_output_primitive_desc_create(
737 mkldnn_primitive_desc_t *concat_primitive_desc,
738 const mkldnn_primitive_desc_t output, int n, int concat_dimension,
/** @addtogroup c_api_sum Sum
 * A primitive to sum data.
 * @{ */

/** Creates out-of-place @p sum_primitive_desc for sum of @p n
 * inputs multiplied by scale with resulting @p output_desc memory
 * descriptor. @p output_desc can be NULL or specified with the #mkldnn_any
 * format -- in this case, the appropriate memory format would be chosen
 * automatically.
 *
 * Inputs:
 *  - input 0 (#mkldnn_query_input_pd, 0)
 *  - input 1 (#mkldnn_query_input_pd, 1)
 *  - ...
 *  - input @p n - 1 (#mkldnn_query_input_pd, @p n - 1)
 *
 * Outputs:
 *  - output (#mkldnn_query_output_pd, 0)
 */
mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(
        mkldnn_primitive_desc_t *sum_primitive_desc,
        const mkldnn_memory_desc_t *output_desc, int n, const float *scales,
        const_mkldnn_primitive_desc_t *input_pds);
770 /** @addtogroup c_api_convolution Convolution
771 * A primitive to compute convolution using different algorithms.
773 * \f[dst[n][oc][oh][ow] =
774 * \sum_{kw=0}^{KW}\sum_{kh=0}^{KH}\sum_{ic=0}^{IC}
775 * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_l[1] + kw]
776 * \cdot weights[g][oc][ic][kh][kw]
779 * where size of output spatial domain is given by
780 * \f$ OH = \left\lfloor{\frac{IH - KH + p_l[0] + p_r[0]}{s_h}}
781 * \right\rfloor + 1\f$,
782 * \f$ OW = \left\lfloor{\frac{IW - KW + p_l[1] + p_r[1]}{s_w}}
783 * \right\rfloor + 1\f$,
785 * and summation is carried over input channels \f$ic\f$ in
786 * group \f$g\f$, and \f$s_h, s_w\f$ are @p strides and
787 * \f$p_l, p_r\f$ are @p padding_l and @p padding_r.
790 /** Initializes a convolution descriptor @p conv_desc for forward propagation
791 * using @p prop_kind (possible values are #mkldnn_forward_training and
792 * #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, @p
793 * padding_l, @p padding_r, and @p padding_kind. In order to create a
794 * convolution without bias, @p bias_desc should either be @c NULL or point to
795 * a descriptor with memory format equal to #mkldnn_format_undef.
797 * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
799 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
800 * value of @p format_kind.
803 * - src (#mkldnn_query_src_pd, 0)
804 * - weights (#mkldnn_query_weights_pd, 0)
805 * - bias (#mkldnn_query_weights_pd, 1), if created with bias
808 * - dst (#mkldnn_query_dst_pd, 0)
810 mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(
811 mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
812 mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
813 const mkldnn_memory_desc_t *weights_desc,
814 const mkldnn_memory_desc_t *bias_desc,
815 const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
816 const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
817 mkldnn_padding_kind_t padding_kind);
819 /** Initializes a dilated convolution descriptor @p conv_desc for forward
820 * propagation using @p prop_kind (possible values are #mkldnn_forward_training
821 * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides,
822 * @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
823 * In order to create a dilated convolution without bias, @p bias_desc
824 * should either be @c NULL or point to a descriptor with memory format equal
825 * to #mkldnn_format_undef.
827 * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
829 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
830 * value of @p format_kind.
833 * - src (#mkldnn_query_src_pd, 0)
834 * - weights (#mkldnn_query_weights_pd, 0)
835 * - bias (#mkldnn_query_weights_pd, 1), if created with bias
838 * - dst (#mkldnn_query_dst_pd, 0)
840 mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(
841 mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
842 mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
843 const mkldnn_memory_desc_t *weights_desc,
844 const mkldnn_memory_desc_t *bias_desc,
845 const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
846 const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
847 const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
849 /** Initializes a convolution descriptor @p conv_desc for backward propagation
850 * with respect to data using @p alg_kind, memory descriptors, @p strides, @p
851 * padding_l, @p padding_r, and @p padding_kind.
853 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
854 * value of @p format_kind.
857 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
858 * - weights (#mkldnn_query_weights_pd, 0)
861 * - diff_src (#mkldnn_query_diff_src_pd, 0)
863 mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(
864 mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
865 const mkldnn_memory_desc_t *diff_src_desc,
866 const mkldnn_memory_desc_t *weights_desc,
867 const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
868 const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
869 mkldnn_padding_kind_t padding_kind);
871 /** Initializes a dilated convolution descriptor @p conv_desc for backward
872 * propagation with respect to data using @p alg_kind, memory descriptors, @p
873 * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
875 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
876 * value of @p format_kind.
879 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
880 * - weights (#mkldnn_query_weights_pd, 0)
883 * - diff_src (#mkldnn_query_diff_src_pd, 0)
885 mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(
886 mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
887 const mkldnn_memory_desc_t *diff_src_desc,
888 const mkldnn_memory_desc_t *weights_desc,
889 const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
890 const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
891 const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
893 /** Initializes a convolution descriptor @p conv_desc for backward propagation
894 * with respect to weights using @p alg_kind, memory descriptors, @p strides,
895 * @p padding_l, @p padding_r, and @p padding_kind.
897 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
898 * value of @p format_kind.
901 * - src (#mkldnn_query_src_pd, 0)
902 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
905 * - diff_weights (#mkldnn_query_diff_weights_pd, 0)
906 * - diff_bias (#mkldnn_query_diff_weights_pd, 1), if created with bias
908 mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(
909 mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
910 const mkldnn_memory_desc_t *src_desc,
911 const mkldnn_memory_desc_t *diff_weights_desc,
912 const mkldnn_memory_desc_t *diff_bias_desc,
913 const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
914 const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
915 mkldnn_padding_kind_t padding_kind);
917 /** Initializes a dilated convolution descriptor @p conv_desc for backward
918 * propagation with respect to weights using @p alg_kind, memory descriptors,
919 * @p strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
921 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
922 * value of @p format_kind.
925 * - src (#mkldnn_query_src_pd, 0)
926 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
929 * - diff_weights (#mkldnn_query_diff_weights_pd, 0)
930 * - diff_bias (#mkldnn_query_diff_weights_pd, 1), if created with bias
932 mkldnn_status_t MKLDNN_API
933 mkldnn_dilated_convolution_backward_weights_desc_init(
934 mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
935 const mkldnn_memory_desc_t *src_desc,
936 const mkldnn_memory_desc_t *diff_weights_desc,
937 const mkldnn_memory_desc_t *diff_bias_desc,
938 const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
939 const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
940 const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
944 /** @addtogroup c_api_deconvolution Deconvolution
945 * A primitive to compute deconvolution using different algorithms.
950 /** Initializes a deconvolution descriptor @p deconv_desc for forward
951 * propagation using @p prop_kind (possible values are #mkldnn_forward_training
952 * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides,
953 * @p padding_l, @p padding_r, and @p padding_kind. In order to create a
954 * deconvolution without bias, @p bias_desc should either be @c NULL or point to
955 * a descriptor with memory format equal to #mkldnn_format_undef.
957 * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
959 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
960 * value of @p format_kind.
963 * - src (#mkldnn_query_src_pd, 0)
964 * - weights (#mkldnn_query_weights_pd, 0)
965 * - bias (#mkldnn_query_weights_pd, 1), if created with bias
968 * - dst (#mkldnn_query_dst_pd, 0)
970 mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init(
971 mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
972 mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
973 const mkldnn_memory_desc_t *weights_desc,
974 const mkldnn_memory_desc_t *bias_desc,
975 const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
976 const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
977 mkldnn_padding_kind_t padding_kind);
979 /** Initializes a dilated deconvolution descriptor @p deconv_desc for forward
980 * propagation using @p prop_kind (possible values are #mkldnn_forward_training
981 * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides,
982 * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. In order to
983 * create a dilated deconvolution without bias, @p bias_desc should either be
984 * @c NULL or point to a descriptor with memory format equal to
985 * #mkldnn_format_undef.
987 * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
989 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
990 * value of @p format_kind.
993 * - src (#mkldnn_query_src_pd, 0)
994 * - weights (#mkldnn_query_weights_pd, 0)
995 * - bias (#mkldnn_query_weights_pd, 1), if created with bias
998 * - dst (#mkldnn_query_dst_pd, 0)
1000 mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init(
1001 mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
1002 mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
1003 const mkldnn_memory_desc_t *weights_desc,
1004 const mkldnn_memory_desc_t *bias_desc,
1005 const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
1006 const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
1007 const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
1009 /** Initializes a deconvolution descriptor @p conv_desc for backward propagation
1010 * with respect to data using @p alg_kind, memory descriptors, @p strides, @p
1011 * padding_l, @p padding_r, and @p padding_kind.
1013 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
1014 * value of @p format_kind.
1017 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1018 * - weights (#mkldnn_query_weights_pd, 0)
1021 * - diff_src (#mkldnn_query_diff_src_pd, 0)
1023 mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init(
1024 mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
1025 const mkldnn_memory_desc_t *diff_src_desc,
1026 const mkldnn_memory_desc_t *weights_desc,
1027 const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
1028 const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
1029 mkldnn_padding_kind_t padding_kind);
1031 /** Initializes a dilated deconvolution descriptor @p conv_desc for backward
1032 * propagation with respect to data using @p alg_kind, memory descriptors, @p
1033 * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
1035 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
1036 * value of @p format_kind.
1039 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1040 * - weights (#mkldnn_query_weights_pd, 0)
1043 * - diff_src (#mkldnn_query_diff_src_pd, 0)
1045 mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init(
1046 mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
1047 const mkldnn_memory_desc_t *diff_src_desc,
1048 const mkldnn_memory_desc_t *weights_desc,
1049 const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
1050 const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
1051 const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
1053 /** Initializes a deconvolution descriptor @p conv_desc for backward propagation
1054 * with respect to weights using @p alg_kind, memory descriptors, @p strides,
1055 * @p padding_l, @p padding_r, and @p padding_kind.
1057 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
1058 * value of @p format_kind.
1061 * - src (#mkldnn_query_src_pd, 0)
1062 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1065 * - diff_weights (#mkldnn_query_diff_weights_pd, 0)
1066 * - diff_bias (#mkldnn_query_diff_weights_pd, 1), if created with bias
1068 mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init(
1069 mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
1070 const mkldnn_memory_desc_t *src_desc,
1071 const mkldnn_memory_desc_t *diff_weights_desc,
1072 const mkldnn_memory_desc_t *diff_bias_desc,
1073 const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
1074 const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
1075 mkldnn_padding_kind_t padding_kind);
1077 /** Initializes a dilated deconvolution descriptor @p conv_desc for backward
1078 * propagation with respect to weights using @p alg_kind, memory descriptors,
1079 * @p strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
1081 * @note Memory descriptors are allowed to be initialized with #mkldnn_any
1082 * value of @p format_kind.
1085 * - src (#mkldnn_query_src_pd, 0)
1086 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1089 * - diff_weights (#mkldnn_query_diff_weights_pd, 0)
1090 * - diff_bias (#mkldnn_query_diff_weights_pd, 1), if created with bias
1092 mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init(
1093 mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
1094 const mkldnn_memory_desc_t *src_desc,
1095 const mkldnn_memory_desc_t *diff_weights_desc,
1096 const mkldnn_memory_desc_t *diff_bias_desc,
1097 const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
1098 const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
1099 const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
1103 /** @addtogroup c_api_shuffle Shuffle
1104 * A primitive to shuffle data along the axis.
1107 /** Initializes a @p shuffle_desc for forward propagation using @p prop_kind,
1108 * memory descriptor @p data_desc, @p axis, and @p group_size.
1111 * - src (#mkldnn_query_src_pd, 0)
1114 * - dst (#mkldnn_query_dst_pd, 0)
1117 mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init(
1118 mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind,
1119 const mkldnn_memory_desc_t *data_desc, int axis, int group_size);
1121 /** Initializes a @p shuffle_desc for backward propagation using memory
1122 * descriptor @p diff_data_desc, @p axis, and @p group_size.
1126 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1129 * - diff_src (#mkldnn_query_diff_src_pd, 0)
1132 mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init(
1133 mkldnn_shuffle_desc_t *shuffle_desc,
1134 const mkldnn_memory_desc_t *diff_data_desc, int axis, int group_size);
1138 /** @addtogroup c_api_eltwise Eltwise
1139 * A primitive to compute element-wise operations like parametric rectifier
1140 * linear unit (ReLU).
1142 * Both forward and backward passes support in-place operation; that is, src
1143 * and dst point to the same memory for forward pass, and diff_dst and diff_src
1144 * point to the same memory for backward pass.
1146 * @warning Because the original src is required for backward pass, in-place
1147 * forward pass in general cannot be applied during training. However, for some
1148 * kinds of element-wise operations (namely ReLU with alpha parameter equals 0),
1149 * dst and src can be interchangeable for the backward pass, which enables
1150 * performing in-place forward even for training.
1154 /** Initializes an @p eltwise_desc for forward propagation using @p prop_kind
1155 * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference),
1156 * @p alg_kind algorithm, memory descriptor @p data_desc, @p alpha, and
1157 * @p beta parameters.
1159 * @sa mkldnn_eltwise_desc_t for details.
1162 * - src (#mkldnn_query_src_pd, 0)
1165 * - dst (#mkldnn_query_dst_pd, 0)
1167 mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(
1168 mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind,
1169 mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc,
1170 float alpha, float beta);
1172 /** Initializes an @p eltwise_desc for backward propagation using @p alg_kind
1173 * algorithm, memory descriptors @p diff_data_desc and @p data_desc, and the
1174 * @p alpha and @p beta parameters.
1176 * @sa mkldnn_eltwise_desc_t for details.
1179 * - src (#mkldnn_query_src_pd, 0)
1180 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1183 * - diff_src (#mkldnn_query_diff_src_pd, 0)
1185 mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(
1186 mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind,
1187 const mkldnn_memory_desc_t *diff_data_desc,
1188 const mkldnn_memory_desc_t *data_desc, float alpha, float beta);
1192 /** @addtogroup c_api_depthwise Depthwise
1193 * A primitive to compute channel-wise operations such as scale and shift.
1196 /** Initializes a @p depthwise_desc for forward propagation using @p prop_kind
1197 * (possible values are #mkldnn_forward_training or #mkldnn_forward_inference),
1198 * @p alg_kind algorithm, memory descriptor @p data_desc.
1199 * @sa mkldnn_depthwise_desc_t for details */
1200 mkldnn_status_t MKLDNN_API mkldnn_depthwise_forward_desc_init(
1201 mkldnn_depthwise_desc_t *depthwise_desc, mkldnn_prop_kind_t prop_kind,
1202 mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *dst_desc,
1203 const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc);
1207 /** @addtogroup c_api_softmax Softmax
1208 * A primitive to perform softmax.
1210 * \f[dst[ou][c][in] =
1211 * \frac{\exp(src[ou][c][in]) - \max\limits_{c}(src[ou][c][in])}
1212 * {\sum\limits_{c}\{\exp(src[ou][c][in])
1213 * - \max\limits_{c}(src[ou][c][in])\}},\f]
1215 * where \f$ou, in\f$ are outer and inner indices respectively, defined
1216 * by @p data_desc.dims and @p softmax_axis.
1219 /** Initializes a @p softmax_desc for forward propagation using @p prop_kind
1220 * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference)
1221 * and memory descriptor @p data_desc.
1224 * - src (#mkldnn_query_src_pd, 0)
1227 * - dst (#mkldnn_query_dst_pd, 0)
1229 mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(
1230 mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind,
1231 const mkldnn_memory_desc_t *data_desc, int softmax_axis);
1233 /** Initializes a @p softmax_desc for backward propagation using memory
1234 * descriptors @p diff_desc and @p data_desc.
1237 * - dst (#mkldnn_query_dst_pd, 0)
1238 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1241 * - diff_src (#mkldnn_query_diff_src_pd, 0)
1243 mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init(
1244 mkldnn_softmax_desc_t *softmax_desc,
1245 const mkldnn_memory_desc_t *diff_desc,
1246 const mkldnn_memory_desc_t *data_desc, int softmax_axis);
1250 /** @addtogroup c_api_pooling Pooling
1251 * A primitive to perform max or average pooling.
1254 * \f[dst[n][oc][oh][ow] =
1255 * \max\limits_{kw,kh}
1256 * (src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_l[1] + kw]),\f]
1259 * \f[dst[n][oc][oh][ow] =
1260 * \frac{1}{KW \cdot KH}\sum\limits_{kw,kh}
1261 * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_l[1] + kw],\f]
1263 * where \f$p_l, p_r\f$ are @p padding_l and @p padding_r respectively, and
1264 * output spatial dimensions are calculated similarly to how they are done in
1267 * During training, max pooling requires a workspace on forward
1268 * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes to
1269 * save indices where maximum was found. The workspace layout is opaque, and
1270 * the indices cannot be restored from it. However, one can use backward
1271 * pooling to perform up-sampling (used in some detection topologies).
1275 /** Initializes a pooling descriptor @p pool_desc for forward propagation using
1276 * @p prop_kind (possible values are #mkldnn_forward_training and
1277 * #mkldnn_forward_inference), @p alg_kind, memory descriptors, and pooling
1278 * parameters in the spatial domain: @p strides, @p kernel sizes, @p padding_l,
1279 * @p padding_r, and @p padding_kind.
1281 * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
1284 * - src (#mkldnn_query_src_pd, 0)
1287 * - dst (#mkldnn_query_dst_pd, 0)
1288 * - workspace (#mkldnn_query_workspace_pd, 0),
1289 * if @p alg_kind = #mkldnn_pooling_max and
1290 * @p prop_kind = #mkldnn_forward_training
1292 mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(
1293 mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind,
1294 mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
1295 const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
1296 const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l,
1297 const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
1299 /** Initializes a pooling descriptor @p pool_desc for backward propagation
1300 * using @p alg_kind, memory descriptors, and pooling parameters in the spatial
1301 * domain: @p strides, @p kernel sizes, @p padding_l, @p padding_r, and @p
1304 * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
1307 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1308 * - workspace (#mkldnn_query_workspace_pd, 0),
1309 * if @p alg_kind = #mkldnn_pooling_max
1312 * - diff_src (#mkldnn_query_diff_src_pd, 0)
1314 mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(
1315 mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind,
1316 const mkldnn_memory_desc_t *diff_src_desc,
1317 const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
1318 const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l,
1319 const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
1323 /** @addtogroup c_api_lrn LRN
1324 * A primitive to perform local response normalization (LRN) across or within
1327 * LRN across channels:
1328 * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}}
1329 * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2}
1330 * (src[n][c+i][h][w])^2\right\}^{-\beta}
1331 * src[n][c][h][w],\f]
1333 * LRN within channels:
1334 * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}}
1335 * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2}
1336 * (src[n][c][h+i][w+i])^2\right\}^{-\beta}
1337 * src[n][c][h][w],\f]
1339 * where \f$n_{l}\f$ is the @p local_size.
1341 * During training, LRN might or might not require a workspace on forward
1342 * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes. The
1343 * behavior is implementation specific. Optimized implementations typically
1344 * require a workspace and use it to save some intermediate results from the
1345 * forward pass that accelerate computations on the backward pass.
1347 * To check whether a workspace is required, query the LRN primitive descriptor
1348 * for the workspace (#mkldnn_query_workspace_pd). Success indicates that the
1349 * workspace is required and its description will be returned.
1350 * @sa mkldnn_primitive_desc_query and mkldnn_primitive_desc_query_pd
1354 /** Initializes an @p lrn_desc for forward propagation using @p prop_kind
1355 * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference),
1356 * @p alg_kind, memory descriptor @p data_desc, and regularization
1357 * parameters @p local_size, @p alpha, @p beta, and @p k.
1360 * - src (#mkldnn_query_src_pd, 0)
1363 * - dst (#mkldnn_query_dst_pd, 0)
1364 * - workspace (#mkldnn_query_workspace_pd, 0),
1365 * if the underlying implementation requires
1367 mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(
1368 mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind,
1369 mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc,
1370 int local_size, float alpha, float beta, float k);
1372 /** Initializes an @p lrn_desc for backward propagation using @p alg_kind,
1373 * memory descriptors @p data_desc and @p diff_data_desc, and regularization
1374 * parameters @p local_size, @p alpha, @p beta, and @p k.
1377 * - src (#mkldnn_query_src_pd, 0)
1378 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1379 * - workspace (#mkldnn_query_workspace_pd, 0),
1380 * if the underlying implementation requires
1383 * - diff_src (#mkldnn_query_diff_src_pd, 0)
1385 mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(
1386 mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind,
1387 const mkldnn_memory_desc_t *diff_data_desc,
1388 const mkldnn_memory_desc_t *data_desc, int local_size, float alpha,
1389 float beta, float k);
1393 /** @addtogroup c_api_batch_normalization Batch Normalization
1394 * A primitive to perform batch normalization.
1396 * \f[dst[n][c][h][w] = \gamma[c] \frac{src[n][c][h][w] - \mu[c]}
1397 * {\sqrt{\sigma[c] + eps}} + \beta[c],\f]
1399 * where \f$\gamma[c], \beta[c]\f$ are weights and bias for a channel and,
1401 * \f$\mu[c] = \frac{1}{NHW} \sum\limits_{whn} src[n][c][h][w]\f$,
1402 * \f$\sigma[c] = \frac{1}{NHW} \sum\limits_{whn}
1403 * (src[n][c][h][w] - \mu[c])^2\f$,
1405 * and @c eps is a constant to improve numerical stability.
1407 * Both forward and backward passes support in-place operation; that is, src
1408 * and dst point to the same memory for forward pass, and diff_dst and diff_src
1409 * point to the same memory for backward pass.
1411 * Batch normalization supports different flavors controlled by
1412 * mkldnn_batch_normalization_desc_t. For example, batch normalization can
1413 * compute the mean and variance on its own or take them as inputs. It can
1414 * either perform scaling and shifting using gamma and beta parameters or not.
1415 * Optionally it can also perform a fused ReLU, which in case of training would
1416 * also require a workspace.
1418 * @sa mkldnn_batch_normalization_desc_t
1421 /** Initializes a batch normalization descriptor @p bnrm_desc for forward
1422 * propagation using @p prop_kind (possible values are
1423 * #mkldnn_forward_training and #mkldnn_forward_inference), memory descriptor
1424 * @p data_desc, normalization parameter @p epsilon, and @p flags set using bit
1425 * flags of type mkldnn_batch_normalization_desc_t.
1428 * - src (#mkldnn_query_src_pd, 0)
1429 * - mean (#mkldnn_query_src_pd, 1),
1430 * if #mkldnn_use_global_stats bit-flags is set in @p flags
1431 * - variance (#mkldnn_query_src_pd, 2),
1432 * if #mkldnn_use_global_stats bit-flags is set in @p flags
1433 * - scale_and_shift (#mkldnn_query_weights_pd, 0),
1434 * if #mkldnn_use_scaleshift bit-flags is set in @p flags
1437 * - dst (#mkldnn_query_dst_pd, 0)
1438 * - mean (#mkldnn_query_dst_pd, 1),
1439 * if #mkldnn_use_global_stats bit-flags is not set in @p flags
1440 * and @p prop_kind = #mkldnn_forward_training
1441 * - variance (#mkldnn_query_dst_pd, 2),
1442 * if #mkldnn_use_global_stats bit-flags is not set in @p flags
1443 * and @p prop_kind = #mkldnn_forward_training
1444 * - workspace (#mkldnn_query_workspace_pd, 0),
1445 * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags
1446 * and @p prop_kind = #mkldnn_forward_training
1448 * @note In-place operation is supported; that is, dst points to the same memory
1451 * @sa mkldnn_batch_normalization_desc_t
1453 mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(
1454 mkldnn_batch_normalization_desc_t *bnrm_desc,
1455 mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc,
1456 float epsilon, unsigned flags);
1458 /** Initializes a batch normalization descriptor @p bnrm_desc for backward
1459 * propagation with respect to data and scale-shift parameters using memory
1460 * descriptors @p data_desc and @p diff_data_desc, normalization parameter
1461 * @p epsilon, and @p flags set using bit flags of type
1462 * mkldnn_batch_normalization_desc_t.
1465 * - src (#mkldnn_query_src_pd, 0)
1466 * - mean (#mkldnn_query_src_pd, 1)
1467 * - variance (#mkldnn_query_src_pd, 2)
1468 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1469 * - scale_and_shift (#mkldnn_query_weights_pd, 0),
1470 * if #mkldnn_use_scaleshift bit-flags is set in @p flags
1471 * - workspace (#mkldnn_query_workspace_pd, 0),
1472 * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags
1475 * - diff_src (#mkldnn_query_diff_src_pd, 0)
1476 * - diff_scale_and_shift (#mkldnn_query_diff_weights_pd, 0),
1477 * if #mkldnn_use_scaleshift bit-flags is set in @p flags
1478 * and @p prop_kind = #mkldnn_backward
1480 * @note in-place operation is supported,
1481 * i.e. diff_src points to the same memory as diff_dst.
1483 * @sa mkldnn_batch_normalization_desc_t
1485 mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(
1486 mkldnn_batch_normalization_desc_t *bnrm_desc,
1487 mkldnn_prop_kind_t prop_kind,
1488 const mkldnn_memory_desc_t *diff_data_desc,
1489 const mkldnn_memory_desc_t *data_desc,
1490 float epsilon, unsigned flags);
1494 /** @addtogroup c_api_inner_product Inner product
1495 * A primitive to compute an inner product.
1497 * Inner product layer is also known as fully connected layer.
1498 * With spatial dimension:
1500 * \f[dst[n][oc] = \sum\limits_{ic, kh, kw}
1501 * src[n][ic][kh][kw] \cdot weights[oc][ic][kh][kw]
1505 /** Initializes an inner product descriptor @p ip_desc for forward propagation
1506 * using @p prop_kind (possible values are #mkldnn_forward_training and
1507 * #mkldnn_forward_inference) and memory descriptors. In order to create an
1508 * inner product without bias, @p bias_desc should be either @c NULL or a
1509 * pointer to a descriptor with memory format equal to #mkldnn_format_undef.
1512 * Memory descriptors are allowed to be initialized with #mkldnn_any value
1513 * of @p format_kind.
1516 * - src (#mkldnn_query_src_pd, 0)
1517 * - weights (#mkldnn_query_weights_pd, 0)
1518 * - bias (#mkldnn_query_weights_pd, 1), if created with bias
1521 * - dst (#mkldnn_query_dst_pd, 0)
1523 mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(
1524 mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind,
1525 const mkldnn_memory_desc_t *src_desc,
1526 const mkldnn_memory_desc_t *weights_desc,
1527 const mkldnn_memory_desc_t *bias_desc,
1528 const mkldnn_memory_desc_t *dst_desc);
1530 /** Initializes an inner product descriptor @p ip_desc for backward propagation
1531 * with respect to data using memory descriptors.
1534 * Memory descriptors are allowed to be initialized with #mkldnn_any value
1535 * of @p format_kind.
1538 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1539 * - weights (#mkldnn_query_weights_pd, 0)
1542 * - diff_src (#mkldnn_query_diff_src_pd, 0)
1544 mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(
1545 mkldnn_inner_product_desc_t *ip_desc,
1546 const mkldnn_memory_desc_t *diff_src_desc,
1547 const mkldnn_memory_desc_t *weights_desc,
1548 const mkldnn_memory_desc_t *diff_dst_desc);
1550 /** Initializes an inner product descriptor @p ip_desc for backward propagation
1551 * with respect to weights using memory descriptors.
1554 * Memory descriptors are allowed to be initialized with #mkldnn_any value
1555 * of @p format_kind.
1558 * - src (#mkldnn_query_src_pd, 0)
1559 * - diff_dst (#mkldnn_query_diff_dst_pd, 0)
1562 * - diff_weights (#mkldnn_query_diff_weights_pd, 0)
1563 * - diff_bias (#mkldnn_query_diff_weights_pd, 1), if created with bias
1565 mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(
1566 mkldnn_inner_product_desc_t *ip_desc,
1567 const mkldnn_memory_desc_t *src_desc,
1568 const mkldnn_memory_desc_t *diff_weights_desc,
1569 const mkldnn_memory_desc_t *diff_bias_desc,
1570 const mkldnn_memory_desc_t *diff_dst_desc);
1574 /** @addtogroup c_api_rnn RNN
1575 * A primitive to compute the common recurrent layer.
1576 * @todo add additional description for the group
 * Initializes a recurrent cell descriptor @p rnn_cell_desc using
 * @p kind (possible values are #mkldnn_vanilla_rnn, #mkldnn_vanilla_lstm,
 * #mkldnn_vanilla_gru, and #mkldnn_gru_linear_before_reset),
 * @p f (possible values are #mkldnn_eltwise_relu and
 * #mkldnn_eltwise_tanh), @p flags, @p alpha, and @p clipping.
1587 mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(
1588 mkldnn_rnn_cell_desc_t *rnn_cell_desc,
1589 mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f,
1590 unsigned int flags, float alpha, float clipping);
1592 /** Returns the number of gates of a particular @p rnn_cell_desc. */
1593 int MKLDNN_API mkldnn_rnn_cell_get_gates_count(
1594 const mkldnn_rnn_cell_desc_t *rnn_cell_desc);
1596 /** Returns the number of states of a particular @p rnn_cell_desc. */
1597 int MKLDNN_API mkldnn_rnn_cell_get_states_count(
1598 const mkldnn_rnn_cell_desc_t *rnn_cell_desc);
1600 /** Sets quantization @p scale and @p shift for RNN data tensors.
1601 * For performance reasons, low precision configuration of RNN primitive
1602 * expects input activations to have unsigned int8 data type. Scale and shift
1603 * used to quantize floating point data to unsigned integer must be passed to
1604 * RNN primitive using attributes.
1608 * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
1609 * // activations quantization parameters
 * float scale = ..., shift = ...;
1612 * mkldnn_primitive_attr_t rnn_attr;
1613 * // create default attributes
1614 * mkldnn_primitive_attr_create(&rnn_attr);
1616 * // set scale and shift for int8 quantization of activation
1617 * mkldnn_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift);
1619 * // create & configure rnn op_desc
1620 * mkldnn_rnn_desc_t rnn_d;
1621 * mkldnn_primitive_desc_t rnn_pd;
 * mkldnn_primitive_desc_create_v2(&rnn_pd, &rnn_d, rnn_attr, NULL);
1625 * Quantization scale and shift are common for src_layer, src_iter,
1626 * dst_iter and dst_layer.
1628 mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams(
1629 mkldnn_primitive_attr_t attr, const float scale, const float shift);
1631 /** Sets quantization scales @p weights_scales for RNN weights tensors.
1632 * Low precision configuration of RNN primitive expects input weights to have
1633 * signed int8 data type. Scales used to quantize floating point data
1634 * to signed integer must be passed to RNN primitive using attributes.
1635 * The @p mask argument defines correspondence between output tensor dimensions
1636 * and the @p weights_scales array. Set i-th bit of @p mask to 1 to use
1637 * dedicated scaling factor for each slice of the output tensor over i-th
1638 * dimension. Set @p mask to 0 to use common scaling factor for the whole output
1639 * tensor. Example usage:
1642 * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
1643 * // unique output scales per output channel
1644 * float weights_scales[dic * n_gates] = { ... };
1645 * // mask that specifies last two dimensions of ldigo format
1648 * mkldnn_primitive_attr_t attr;
1649 * // create default attributes
1650 * mkldnn_primitive_attr_create(&attr);
1652 * // set output channel-wise weights scales
1653 * mkldnn_primitive_attr_set_rnn_weights_qparams(attr, dic * n_gates, mask,
1656 * // create & configure rnn op_desc
1657 * mkldnn_rnn_desc_t rnn_d;
1658 * mkldnn_primitive_desc_t rnn_pd;
1659 * mkldnn_primitive_desc_create_v2(&rnn_pd, &rnn_d, attr, NULL);
1662 * The dimension order is always native and does not depend on the actual
1663 * layout used. For example, 5 dimensional weights always have
1664 * (l, d, i, g, o) logical dimension ordering.
 * Quantization scales are common for weights_layer and weights_iter.
1668 * There is no way to check that @p count corresponds to @p mask until an
1669 * actual primitive descriptor is created, so it is user's responsibility
1670 * to set proper values. The following formula must be held:
1672 * \f[count = \prod\limits_{d \in mask} output.dims[d]\f]
1674 mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams (
1675 mkldnn_primitive_attr_t attr, int count, int mask,
1676 const float *weights_scales);
1678 /** Initializes a rnn descriptor @p rnn_desc for forward propagation
1679 * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors.
1680 * @note If @p prop_kind equals #mkldnn_forward_training, you must query a
1681 * workspace memory descriptor before creating the primitive.
1683 * @p src_iter_desc, @p bias_desc, and @p dst_iter_desc are allowed to either be
1684 * @c NULL or point to a zero memory descriptor, which would indicate that the
1685 * RNN primitive should not use them.
1687 * @note All memory descriptors except @p src_iter_desc are allowed to be
1688 * initialized with #mkldnn_any value of @p format_kind.
1691 * - src_layer (#mkldnn_query_src_pd, 0)
1692 * - src_iter (#mkldnn_query_src_pd, 1), if used
1693 * - weights_layer (#mkldnn_query_weights_pd, 0)
1694 * - weights_iter (#mkldnn_query_weights_pd, 1)
1695 * - bias (#mkldnn_query_weights_pd, 2), if used
1698 * - dst_layer (#mkldnn_query_dst_pd, 0)
1699 * - dst_iter (#mkldnn_query_dst_pd, 1), if used
1700 * - workspace (#mkldnn_query_workspace_pd, 0),
1701 * if @p prop_kind equals #mkldnn_forward_training
1703 mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(
1704 mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
1705 const mkldnn_rnn_cell_desc_t *rnn_cell_desc,
1706 const mkldnn_rnn_direction_t direction,
1707 const mkldnn_memory_desc_t *src_layer_desc,
1708 const mkldnn_memory_desc_t *src_iter_desc,
1709 const mkldnn_memory_desc_t *weights_layer_desc,
1710 const mkldnn_memory_desc_t *weights_iter_desc,
1711 const mkldnn_memory_desc_t *bias_desc,
1712 const mkldnn_memory_desc_t *dst_layer_desc,
1713 const mkldnn_memory_desc_t *dst_iter_desc);
1715 /** Initializes a rnn descriptor @p rnn_desc for backward propagation
1716 * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors.
1717 * @note All memory descriptors are allowed to be initialized with
1718 * #mkldnn_any value of @p format_kind.
1720 * @p src_iter_desc (simultaneously with @p diff_src_iter_desc),
1721 * @p bias_desc (simultaneously with @p diff_bias_desc), and
 * @p dst_iter_desc (simultaneously with @p diff_dst_iter_desc) are allowed to
1723 * either be @c NULL or point to a zero memory descriptor, which would indicate
1724 * that the RNN primitive should not use them.
1727 * - src_layer (#mkldnn_query_src_pd, 0)
1728 * - src_iter (#mkldnn_query_src_pd, 1), if used
1729 * - weights_layer (#mkldnn_query_weights_pd, 0)
1730 * - weights_iter (#mkldnn_query_weights_pd, 1)
1731 * - bias (#mkldnn_query_weights_pd, 2), if used
1732 * - dst_layer (#mkldnn_query_dst_pd, 0)
1733 * - dst_iter (#mkldnn_query_dst_pd, 1), if used
1734 * - diff_dst_layer (#mkldnn_query_diff_dst_pd, 0)
1735 * - diff_dst_iter (#mkldnn_query_diff_dst_pd, 1), if used
1736 * - workspace (#mkldnn_query_workspace_pd, 0)
1739 * - diff_src_layer (#mkldnn_query_diff_src_pd, 0)
1740 * - diff_src_iter (#mkldnn_query_diff_src_pd, 1), if used
1741 * - diff_weights_layer (#mkldnn_query_diff_weights_pd, 0)
1742 * - diff_weights_iter (#mkldnn_query_diff_weights_pd, 1)
1743 * - diff_bias (#mkldnn_query_diff_weights_pd, 2), if used
1745 mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(
1746 mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
1747 const mkldnn_rnn_cell_desc_t *rnn_cell_desc,
1748 const mkldnn_rnn_direction_t direction,
1749 const mkldnn_memory_desc_t *src_layer_desc,
1750 const mkldnn_memory_desc_t *src_iter_desc,
1751 const mkldnn_memory_desc_t *weights_layer_desc,
1752 const mkldnn_memory_desc_t *weights_iter_desc,
1753 const mkldnn_memory_desc_t *bias_desc,
1754 const mkldnn_memory_desc_t *dst_layer_desc,
1755 const mkldnn_memory_desc_t *dst_iter_desc,
1756 const mkldnn_memory_desc_t *diff_src_layer_desc,
1757 const mkldnn_memory_desc_t *diff_src_iter_desc,
1758 const mkldnn_memory_desc_t *diff_weights_layer_desc,
1759 const mkldnn_memory_desc_t *diff_weights_iter_desc,
1760 const mkldnn_memory_desc_t *diff_bias_desc,
1761 const mkldnn_memory_desc_t *diff_dst_layer,
1762 const mkldnn_memory_desc_t *diff_dst_iter_desc);
1766 /** @addtogroup c_api_roi_pooling ROI Pooling
1767 * A primitive to perform roi pooling.
/** Initializes a @p roi_pooling_desc for forward propagation using @p prop_kind
 * (the only possible value is #mkldnn_forward_inference), @p algorithm, and
 * memory descriptors: @p src_descs (an array of @p num_inputs source
 * descriptors) and @p dst_desc. */
1773 mkldnn_status_t MKLDNN_API mkldnn_roi_pooling_forward_desc_init(
1774 mkldnn_roi_pooling_desc_t *roi_pooling_desc, mkldnn_prop_kind_t prop_kind,
1775 mkldnn_alg_kind_t algorithm,
1776 mkldnn_memory_desc_t *src_descs, int num_inputs,
1777 const mkldnn_memory_desc_t *dst_desc,
1778 int pooled_h, int pooled_w, double spatial_scale);
1782 /** @addtogroup c_api_binary_convolution Binary convolution
1783 * A primitive to compute binary convolution using different algorithms.
1786 /** Initializes a dilated binary convolution descriptor @p bin_conv_desc for forward
1787 * propagation using @p prop_kind (possible values are #mkldnn_forward_training
1788 * or #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides,
1789 * @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
1791 * @note if @p padding_r is @c NULL, the padding is supposed to be symmetric
1793 * @note memory descriptors are allowed to be initialized with #mkldnn_any
1794 * value of @p format_kind.
1797 * - src (#mkldnn_query_src_pd, 0)
1798 * - weights (#mkldnn_query_weights_pd, 0)
1801 * - dst (#mkldnn_query_dst_pd, 0)
1803 mkldnn_status_t MKLDNN_API mkldnn_dilated_binary_convolution_forward_desc_init(
1804 mkldnn_binary_convolution_desc_t *bin_conv_desc, mkldnn_prop_kind_t prop_kind,
1805 mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
1806 const mkldnn_memory_desc_t *weights_desc,
1807 const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
1808 const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
1809 const mkldnn_dims_t padding_r, float pad_value);
1813 /** @addtogroup c_api_binarization Binarization
1814 * A primitive to binarize input using different approaches
1817 /** Initializes a @p binarization_desc for forward propagation using @p prop_kind
1818 * (possible values are #mkldnn_forward_training or #mkldnn_forward_inference),
1819 * @p alg_kind algorithm and memory descriptors.
1820 * @sa mkldnn_binarization_desc_t for details */
1821 mkldnn_status_t MKLDNN_API mkldnn_binarization_forward_desc_init(
1822 mkldnn_binarization_desc_t *binarization_desc, mkldnn_prop_kind_t prop_kind,
1823 mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
1824 const mkldnn_memory_desc_t *dst_desc, const mkldnn_memory_desc_t *weights_desc);
1828 /** @addtogroup c_api_engine Engine operations
1831 /** Returns the number of engines of a particular @p kind. */
1832 size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind);
1834 /** Creates an @p engine of particular @p kind and @p index. */
1835 mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine,
1836 mkldnn_engine_kind_t kind, size_t index);
1838 /** Returns the kind of an @p engine. */
1839 mkldnn_status_t MKLDNN_API mkldnn_engine_get_kind(mkldnn_engine_t engine,
1840 mkldnn_engine_kind_t *kind);
1842 /** Destroys an @p engine. */
1843 mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine);
1847 /** @addtogroup c_api_stream Execution stream operations
1850 /** Creates an execution @p stream of @p stream_kind. */
1851 mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream,
1852 mkldnn_stream_kind_t stream_kind);
1854 /** Submits @p primitives to an execution @p stream. The number of primitives
1855 * is @p n. All or none of the primitives can be lazy. In case of an error,
1856 * returns the offending @p error_primitive if it is not @c NULL. */
1857 mkldnn_status_t MKLDNN_API mkldnn_stream_submit(mkldnn_stream_t stream,
1858 size_t n, mkldnn_primitive_t primitives[],
1859 mkldnn_primitive_t *error_primitive);
1861 /** Waits for all primitives in the execution @p stream to finish. Returns
1862 * immediately if @p block is zero. In case of an error, returns
1863 * the offending @p error_primitive if it is not @c NULL. */
1864 mkldnn_status_t MKLDNN_API mkldnn_stream_wait(mkldnn_stream_t stream,
1865 int block, mkldnn_primitive_t *error_primitive);
1867 /** Reruns all the primitives within the @p stream. In case of an error,
1868 * returns the offending @p error_primitive if it is not @c NULL. */
1869 mkldnn_status_t MKLDNN_API mkldnn_stream_rerun(mkldnn_stream_t stream,
1870 mkldnn_primitive_t *error_primitive);
1872 /** Destroys an execution @p stream. */
1873 mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream);
1877 /** @addtogroup c_api_service Service functions
1880 /** Sets verbosity level (print information to stdout).
1881 * Possible levels are:
1882 * - 0 -- no verbose output (default)
1883 * - 1 -- primitive information at execution
1884 * - 2 -- primitive information at creation and execution
1887 * Dumping information might affect performance.
1888 * This setting overrides the MKLDNN_VERBOSE environment variable. */
1889 mkldnn_status_t MKLDNN_API mkldnn_set_verbose(int level);
1891 /** Sets jit dump control.
1893 * - zero -- turn jit dump off (default)
1894 * - non-zero -- turn jit dump on
1897 * This setting overrides the MKLDNN_JIT_DUMP environment variable. */
1898 mkldnn_status_t MKLDNN_API mkldnn_set_jit_dump(int dump);
1900 /** Gets library version information.
1901 * Version information includes:
1902 * - major -- major version number
1903 * - minor -- minor version number
1904 * - patch -- patch release number
1905 * - hash -- git commit hash */
1906 const mkldnn_version_t MKLDNN_API *mkldnn_version();
1908 /** Returns cache size for specified level in bytes.
1910 * Currently, if it is not able to fetch the cache topology
1911 * function defaults to 32KB of L1, 512KB of L2 and 1MB of L3 per core. */
1912 unsigned int MKLDNN_API mkldnn_get_cache_size(int level, int per_core);
1916 /** @addtogroup c_api_blas BLAS functions
 * A subset of Basic Linear Algebra (BLAS) functions to perform
1918 * matrix-matrix multiplication.
1921 /** SGEMM performs a matrix-matrix multiplication operation defined as
1923 * C := alpha*op( A )*op( B ) + beta*C
1926 * - op( X ) is one of op( X ) = X or op( X ) = X**T,
1927 * - alpha and beta are scalars,
1928 * - A, B and C are matrices, with op( A ) an m by k matrix, op( B ) a k by n matrix
1929 * and C an m by n matrix.
1931 * The matrices are assumed to be stored in column-major order (the elements
1932 * in a matrix columns are contiguous in memory).
1935 * The API is different from the standard BLAS routine
1936 * because it returns mkldnn_status_t for error handling.
1937 * XERBLA is not supported: no error message will be printed
1938 * in case of incorrect parameters. */
1939 mkldnn_status_t MKLDNN_API mkldnn_sgemm(const char *transa, const char *transb,
1940 const int *M, const int *N, const int *K,
1941 const float *alpha, const float *A, const int *lda,
1942 const float *B, const int *ldb,
1943 const float *beta, float *C, const int *ldc);
1945 /** gemm_s8u8s32 and gemm_s8s8s32 perform a matrix-matrix multiplication
1946 * operation and add the result to a scalar-matrix product. For the final
1947 * result, a vector is added to each row or column of the output matrix.
1948 * The operation is defined as:
1950 * C := alpha*(op(A) + A_offset) * (op(B) + B_offset) + beta*C + C_offset
1953 * - op( X ) = X or op( X ) = X**T,
1954 * - A_offset is an m-by-k matrix with every element equal to the value oa,
1955 * - B_offset is an k-by-n matrix with every element equal to the value ob,
1956 * - C_offset is an m-by-n matrix defined by the oc array, size len:
1957 * - if offsetc = F: len must be at least 1
1958 * - if offsetc = C: len must be at least max(1, m)
1959 * - if offsetc = R: len must be at least max(1, n)
1960 * - alpha and beta are scalars, and A, B and C are matrices, with op( A )
1961 * an m-by-k matrix, op( B ) a k-by-n matrix and C an m-by-n matrix.
1963 * The matrices are assumed to be stored in column-major order (the elements
1964 * in a matrix columns are contiguous in memory).
1967 * The API is different compared with the standard BLAS routine
1968 * because it returns mkldnn_status_t for error handling.
1969 * XERBLA is not supported: no error message will be printed
1970 * in case of incorrect parameters. */
1971 mkldnn_status_t MKLDNN_API mkldnn_gemm_s8u8s32(const char *transa,
1972 const char *transb, const char *offsetc, const int *M, const int *N,
1973 const int *K, const float *alpha, const int8_t *A, const int *lda,
1974 const int8_t *ao, const uint8_t *B, const int *ldb, const int8_t *bo,
1975 const float *beta, int32_t *c, const int *ldc, const int32_t *co);
/** Same operation as mkldnn_gemm_s8u8s32 (see the description above), except
 * that matrix B holds signed int8 data. All other parameters, offset handling,
 * and error semantics are identical. */
mkldnn_status_t MKLDNN_API mkldnn_gemm_s8s8s32(const char *transa,
        const char *transb, const char *offsetc, const int *M, const int *N,
        const int *K, const float *alpha, const int8_t *A, const int *lda,
        const int8_t *ao, const int8_t *B, const int *ldb, const int8_t *bo,
        const float *beta, int32_t *c, const int *ldc, const int32_t *co);