1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
33 /// @addtogroup cpp_api C++ API
36 /// @addtogroup cpp_api_utils Utils
/// A traits class that provides the destructor for an Intel(R) MKL-DNN C
/// handle. The primary template is empty; #mkldnn::handle reads
/// `traits::destructor` when (re)setting a wrapped handle, so each C handle
/// type wrapped with the default traits needs an explicit specialization
/// that supplies a static `destructor` member.
///
/// Declared with `struct` so the class-key matches the explicit
/// specializations (which all use `struct`).
template <typename T> struct handle_traits {};
42 /// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base
43 /// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and
44 /// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class
45 /// can be passed by value. This class enables wrapping:
46 /// - Newly constructed handles.
47 /// @n In this case, the constructed handle uses reference counting provided
48 /// by @p std::shared_ptr with a proper deleter function specified through
49 /// the @p handle_traits class.
50 /// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for
51 /// example, through #mkldnn_primitive_get_output()).
52 /// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a
53 /// deleter because it is assumed that the handle wrapper for the original
54 /// object deletes the handle (this model is similar to @p std::weak_ptr).
55 template <typename T, typename traits=handle_traits<T>> class handle {
57 std::shared_ptr<typename std::remove_pointer<T>::type> _data;
58 handle(const handle &&) = delete;
59 handle &operator=(const handle &&other) = delete;
61 bool operator==(const T other) const { return other == _data.get(); }
62 bool operator!=(const T other) const { return !(*this == other); }
64 /// Constructs a C handle wrapper.
65 /// @param t The C handle to wrap.
66 /// @param weak A flag to specify whether to construct a weak wrapper.
67 handle(T t = 0, bool weak = false): _data(0) {
71 handle(const handle &other): _data(other._data) {}
72 handle &operator=(const handle &other) {
76 /// Resets the value of a C handle.
77 /// @param t The new value of the C handle.
78 /// @param weak A flag to specify whether the wrapper should be weak.
79 void reset(T t, bool weak = false) {
80 if (weak) _data.reset(t, [](T) { return decltype(traits::destructor(0))(0); });
81 else _data.reset(t, traits::destructor);
84 /// Returns the value of the underlying C handle.
85 T get() const { return _data.get(); }
87 bool operator==(const handle &other) const { return other._data.get() == _data.get(); }
88 bool operator!=(const handle &other) const { return !(*this == other); }
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93 static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
96 template <> struct handle_traits<mkldnn_primitive_t> {
97 static constexpr auto destructor = &mkldnn_primitive_destroy;
100 template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
101 static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
105 /// Base class for all computational primitives.
106 class primitive: public handle<mkldnn_primitive_t> {
108 friend struct stream;
109 friend class primitive_at;
110 using handle::handle;
112 /// A proxy to the C primitive kind enum.
114 undefined_primitive = mkldnn_undefined_primitive,
115 memory = mkldnn_memory,
117 reorder = mkldnn_reorder,
118 concat = mkldnn_concat,
119 concat_inplace = mkldnn_concat_inplace,
121 convolution = mkldnn_convolution,
122 deconvolution = mkldnn_deconvolution,
123 shuffle = mkldnn_shuffle,
124 eltwise = mkldnn_eltwise,
125 depthwise = mkldnn_depthwise,
126 softmax = mkldnn_softmax,
127 pooling = mkldnn_pooling,
129 batch_normalization = mkldnn_batch_normalization,
130 inner_product = mkldnn_inner_product,
132 binary_convolution = mkldnn_binary_convolution,
133 binarization = mkldnn_binarization,
136 /// A wrapper structure to specify a particular output of a primitive.
138 /// The underlying C API structure.
139 mkldnn_primitive_at_t data;
140 /// Constructs a wrapper specifying @p aprimitive output with index @p at.
143 /// @param aprimitive The target primitive.
144 /// @param at The output index.
146 at(const primitive &aprimitive, size_t at = 0)
147 : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
148 /// Returns the specified output.
149 inline operator primitive() const;
152 /// Returns the descriptor of the underlying C API primitive.
153 inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
154 // TODO: use the C++ API wrapper structure.
157 inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) {
158 return static_cast<mkldnn_primitive_kind_t>(akind);
160 /// Intel(R) MKL-DNN exception class.
162 /// This class captures the status returned by the failed C API function, the error
163 /// message, and, optionally, the handle of the primitive that caused the error.
164 struct error: public std::exception {
165 mkldnn_status_t status;
167 primitive error_primitive;
169 /// Constructs an error instance.
171 /// @param astatus The error status returned by the C API.
172 /// @param amessage The error message.
173 /// @param aerror_primitive (optional) A C handle of the primitive that
174 /// caused the error.
176 error(mkldnn_status_t astatus, std::string amessage,
177 mkldnn_primitive_t aerror_primitive = 0)
180 , error_primitive(aerror_primitive, true)
183 /// A convenience function for wrapping calls to the C API. Checks the
184 /// return status and throws an #error in case of failure.
186 /// @param status The error status returned by the C API.
187 /// @param message The error message.
188 /// @param error_primitive (optional) A pointer to the C handle of the
189 /// primitive that caused the error.
191 static void wrap_c_api(mkldnn_status_t status,
192 const std::string &message,
193 mkldnn_primitive_t *error_primitive = 0)
195 if (status != mkldnn_success) {
196 if (nullptr != error_primitive)
197 throw error(status, message, *error_primitive);
199 throw error(status, message, nullptr);
204 inline primitive::at::operator primitive() const {
205 const_mkldnn_primitive_t output;
207 mkldnn_primitive_get_output(data.primitive,
208 data.output_index, &output),
209 "could not get an output primitive");
210 return primitive(const_cast<mkldnn_primitive_t>(output), true);
213 const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const {
214 const_mkldnn_primitive_desc_t pd;
215 error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd),
216 "could not get primitive descriptor by primitive");
221 /// @addtogroup cpp_api_enums Common data types and enumerations
222 /// A proxy to @ref c_api_types in @ref c_api.
227 round_nearest = mkldnn_round_nearest,
228 round_down = mkldnn_round_down,
231 inline mkldnn_round_mode_t convert_to_c(round_mode mode) {
232 return static_cast<mkldnn_round_mode_t>(mode);
236 zero = mkldnn_padding_zero
239 inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) {
240 return static_cast<mkldnn_padding_kind_t>(kind);
244 forward_training = mkldnn_forward_training,
245 forward_scoring = mkldnn_forward_scoring,
246 forward_inference = mkldnn_forward_inference,
247 forward = mkldnn_forward,
248 backward = mkldnn_backward,
249 backward_data = mkldnn_backward_data,
250 backward_weights = mkldnn_backward_weights,
251 backward_bias = mkldnn_backward_bias
254 inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) {
255 return static_cast<mkldnn_prop_kind_t>(kind);
259 algorithm_undef = mkldnn_alg_kind_undef,
260 convolution_auto = mkldnn_convolution_auto,
261 convolution_direct = mkldnn_convolution_direct,
262 convolution_winograd = mkldnn_convolution_winograd,
263 deconvolution_direct = mkldnn_deconvolution_direct,
264 deconvolution_winograd = mkldnn_deconvolution_winograd,
265 eltwise_relu = mkldnn_eltwise_relu,
266 eltwise_tanh = mkldnn_eltwise_tanh,
267 eltwise_elu = mkldnn_eltwise_elu,
268 eltwise_square = mkldnn_eltwise_square,
269 eltwise_abs = mkldnn_eltwise_abs,
270 eltwise_sqrt = mkldnn_eltwise_sqrt,
271 eltwise_linear = mkldnn_eltwise_linear,
272 eltwise_bounded_relu = mkldnn_eltwise_bounded_relu,
273 eltwise_soft_relu = mkldnn_eltwise_soft_relu,
274 eltwise_logistic = mkldnn_eltwise_logistic,
275 eltwise_clamp = mkldnn_eltwise_clamp,
276 eltwise_exp = mkldnn_eltwise_exp,
277 eltwise_not = mkldnn_eltwise_not,
278 depthwise_scale_shift = mkldnn_depthwise_scale_shift,
279 depthwise_prelu = mkldnn_depthwise_prelu,
280 lrn_across_channels = mkldnn_lrn_across_channels,
281 lrn_within_channel = mkldnn_lrn_within_channel,
282 pooling_max = mkldnn_pooling_max,
283 pooling_avg = mkldnn_pooling_avg,
284 pooling_avg_include_padding = mkldnn_pooling_avg_include_padding,
285 pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding,
286 vanilla_rnn = mkldnn_vanilla_rnn,
287 vanilla_lstm = mkldnn_vanilla_lstm,
288 vanilla_gru = mkldnn_vanilla_gru,
289 gru_linear_before_reset = mkldnn_gru_linear_before_reset,
290 roi_pooling_max = mkldnn_roi_pooling_max,
291 roi_pooling_bilinear = mkldnn_roi_pooling_bilinear,
292 binary_convolution_direct = mkldnn_binary_convolution_direct,
293 binarization_depthwise = mkldnn_binarization_depthwise,
296 inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
297 return static_cast<mkldnn_alg_kind_t>(aalgorithm);
300 enum batch_normalization_flag {
301 use_global_stats = mkldnn_use_global_stats,
302 use_scale_shift = mkldnn_use_scaleshift,
303 fuse_bn_relu = mkldnn_fuse_bn_relu
306 inline mkldnn_batch_normalization_flag_t convert_to_c(
307 batch_normalization_flag aflag) {
308 return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
312 unidirectional_left2right = mkldnn_unidirectional_left2right,
313 unidirectional_right2left = mkldnn_unidirectional_right2left,
314 unidirectional = mkldnn_unidirectional,
315 bidirectional_concat = mkldnn_bidirectional_concat,
316 bidirectional_sum = mkldnn_bidirectional_sum,
319 inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) {
320 return static_cast<mkldnn_rnn_direction_t>(adir);
324 undef = mkldnn_query_undef,
326 eengine = mkldnn_query_engine,
327 primitive_kind = mkldnn_query_primitive_kind,
329 num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32,
330 num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32,
332 time_estimate_f64 = mkldnn_query_time_estimate_f64,
333 memory_consumption_s64 = mkldnn_query_memory_consumption_s64,
335 impl_info_str = mkldnn_query_impl_info_str,
337 op_d = mkldnn_query_op_d,
338 memory_d = mkldnn_query_memory_d,
339 convolution_d = mkldnn_query_convolution_d,
340 deconvolution_d = mkldnn_query_deconvolution_d,
341 shuffle_d = mkldnn_query_shuffle_d,
342 eltwise_d = mkldnn_query_eltwise_d,
343 depthwise_d = mkldnn_query_depthwise_d,
344 softmax_d = mkldnn_query_softmax_d,
345 pooling_d = mkldnn_query_pooling_d,
346 lrn_d = mkldnn_query_lrn_d,
347 batch_normalization_d = mkldnn_query_batch_normalization_d,
348 inner_product_d = mkldnn_query_inner_product_d,
349 rnn_d = mkldnn_query_rnn_d,
350 binary_convolution_d = mkldnn_query_binary_convolution_d,
351 binarization_d = mkldnn_query_binarization_d,
353 input_pd = mkldnn_query_input_pd,
354 output_pd = mkldnn_query_output_pd,
355 src_pd = mkldnn_query_src_pd,
356 diff_src_pd = mkldnn_query_diff_src_pd,
357 weights_pd = mkldnn_query_weights_pd,
358 diff_weights_pd = mkldnn_query_diff_weights_pd,
359 dst_pd = mkldnn_query_dst_pd,
360 diff_dst_pd = mkldnn_query_diff_dst_pd,
361 workspace_pd = mkldnn_query_workspace_pd,
364 inline mkldnn_query_t convert_to_c(query aquery) {
365 return static_cast<mkldnn_query_t>(aquery);
370 /// @addtogroup cpp_api_attr Attributes
371 /// An extension for controlling primitive behavior.
373 /// @sa @ref c_api_attributes in @ref c_api
376 #ifndef DOXYGEN_SHOULD_SKIP_THIS
377 template <> struct handle_traits<mkldnn_post_ops_t> {
378 static constexpr auto destructor = &mkldnn_post_ops_destroy;
382 struct post_ops: public handle<mkldnn_post_ops_t> {
384 mkldnn_post_ops_t result;
385 error::wrap_c_api(mkldnn_post_ops_create(&result),
386 "could not create post operation sequence");
390 int len() const { return mkldnn_post_ops_len(get()); }
392 primitive::kind kind(int index) const {
394 index < len() ? mkldnn_success : mkldnn_invalid_arguments,
395 "post_ops index is out of range");
396 return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
400 void append_sum(float scale = 1.) {
401 error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale),
402 "could not append sum");
405 void get_params_sum(int index, float &scale) const {
406 error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale),
407 "could not get sum params");
410 void append_eltwise(float scale, algorithm alg, float alpha,
412 error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale,
413 convert_to_c(alg), alpha, beta),
414 "could not append eltwise");
417 void get_params_eltwise(int index, float &scale, algorithm &alg,
418 float &alpha, float &beta) const {
419 mkldnn_alg_kind_t c_alg;
420 error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index,
421 &scale, &c_alg, &alpha, &beta),
422 "could not get eltwise params");
423 alg = static_cast<algorithm>(c_alg);
426 void append_depthwise(algorithm alg, const float* weights_data,
427 const float* biases_data) {
428 error::wrap_c_api(mkldnn_post_ops_append_depthwise(get(),
429 convert_to_c(alg), weights_data, biases_data),
430 "could not append depthwise");
433 void get_params_depthwise(int index, algorithm &alg,
434 const float** weights_data, const float** biases_data) const {
435 mkldnn_alg_kind_t c_alg;
436 error::wrap_c_api(mkldnn_post_ops_get_params_depthwise(get(), index,
437 &c_alg, weights_data, biases_data),
438 "could not get depthwise params");
439 alg = static_cast<algorithm>(c_alg);
442 void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
443 const float* weights_data, const float* biases_data) {
444 error::wrap_c_api(mkldnn_post_ops_append_dw_conv(get(),
445 in_h, in_w, ker_h, ker_w, str_h, str_w, weights_data, biases_data),
446 "could not append dw conv");
449 void get_params_dw_conv(int index, int &in_h, int &in_w, int &ker_h, int &ker_w, int &str_h, int &str_w,
450 const float** weights_data, const float** biases_data) const {
451 error::wrap_c_api(mkldnn_post_ops_get_params_dw_conv(get(), index,
452 &in_h, &in_w, &ker_h, &ker_w, &str_h, &str_w, weights_data, biases_data),
453 "could not get dw conv params");
456 void append_binarization(algorithm alg, const float* weights_data, const float* output_mask) {
457 error::wrap_c_api(mkldnn_post_ops_append_binarization(get(), convert_to_c(alg), weights_data, output_mask),
458 "could not append binarization");
461 void get_params_binarization(int index, algorithm &alg, const float** weights_data, const float** output_mask) const {
462 mkldnn_alg_kind_t c_alg;
463 error::wrap_c_api(mkldnn_post_ops_get_params_binarization(get(), index, &c_alg, weights_data, output_mask),
464 "could not get binarization params");
465 alg = static_cast<algorithm>(c_alg);
469 #ifndef DOXYGEN_SHOULD_SKIP_THIS
470 template <> struct handle_traits<mkldnn_primitive_attr_t> {
471 static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
475 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
477 mkldnn_primitive_attr_t result;
478 error::wrap_c_api(mkldnn_primitive_attr_create(&result),
479 "could not create a primitive attr");
483 round_mode get_int_output_round_mode() const {
484 mkldnn_round_mode_t result;
485 error::wrap_c_api(mkldnn_primitive_attr_get_int_output_round_mode(
486 get(), &result), "could not get int output round mode");
487 return round_mode(result);
490 void set_int_output_round_mode(round_mode mode) {
491 error::wrap_c_api(mkldnn_primitive_attr_set_int_output_round_mode(
492 get(), mkldnn::convert_to_c(mode)),
493 "could not set int output round mode");
496 void get_output_scales(int &mask, std::vector<float> &scales) const
499 const float *c_scales;
500 error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(),
501 &count, &c_mask, &c_scales),
502 "could not get int output scales");
503 scales.resize(count);
506 for (int c = 0; c < count; ++c)
507 scales[c] = c_scales[c];
510 void set_output_scales(int mask, const std::vector<float> &scales)
512 error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(),
513 (int)scales.size(), mask, &scales[0]),
514 "could not set int output scales");
517 const post_ops get_post_ops() const {
519 const_mkldnn_post_ops_t c_result;
520 error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result),
521 "could not get post operation sequence");
522 result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
526 void set_post_ops(post_ops ops) {
527 error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()),
528 "could not set post operation sequence");
531 void set_rnn_data_qparams(const float scale, const float shift)
533 error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(),
534 scale, shift), "could not set rnn data int scale/shift");
537 void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
539 error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(),
540 (int)scales.size(), mask, &scales[0]),
541 "could not set rnn weights int scales");
547 /// @addtogroup cpp_api_engine Engine
548 /// Engine operations.
550 /// @sa @ref c_api_engine in @ref c_api
553 #ifndef DOXYGEN_SHOULD_SKIP_THIS
554 template <> struct handle_traits<mkldnn_engine_t> {
555 static constexpr auto destructor = &mkldnn_engine_destroy;
559 /// An execution engine.
560 struct engine: public handle<mkldnn_engine_t> {
561 friend class primitive;
562 // gcc bug??? using handle::handle;
564 /// Kinds of engines.
566 /// An unspecified engine
567 any = mkldnn_any_engine,
572 /// Returns the number of engines of a certain kind.
574 /// @param akind The kind of engines to count.
576 static size_t get_count(kind akind) {
577 return mkldnn_engine_get_count(convert_to_c(akind));
580 /// Constructs an engine.
582 /// @param akind The kind of engine to construct.
583 /// @param index The index of the engine. Must be less than the value
584 /// returned by #get_count() for this particular kind of engine.
586 engine(kind akind, size_t index) {
587 mkldnn_engine_t aengine;
589 mkldnn_engine_create(&aengine,
590 convert_to_c(akind), index),
591 "could not create an engine");
595 explicit engine(const mkldnn_engine_t& aengine)
596 : handle(aengine, true) {}
598 engine(const handle<mkldnn_primitive_desc_t> &pd) {
599 mkldnn_engine_t engine_q;
601 mkldnn_primitive_desc_query(pd.get(),
602 mkldnn::convert_to_c(eengine), 0, &engine_q),
603 "could not get engine from primitive_desc");
604 reset(engine_q, true);
607 template <class primitive_desc>
608 static engine query(const primitive_desc &pd) {
609 mkldnn_engine_t engine_q;
611 mkldnn_primitive_desc_query(pd.get(),
612 mkldnn::convert_to_c(eengine), 0, &engine_q),
613 "could not get engine from primitive_desc");
615 return engine(engine_q);
619 static mkldnn_engine_kind_t convert_to_c(kind akind) {
620 return static_cast<mkldnn_engine_kind_t>(akind);
626 /// @addtogroup cpp_api_memory_related Memory and memory-related operations
629 /// @addtogroup cpp_api_memory Memory
630 /// A primitive to describe and store data.
632 /// For more information, refer to @ref c_api_memory in @ref c_api.
635 /// Memory primitive that describes the data.
636 struct memory: public primitive {
638 std::shared_ptr<char> _handle;
641 typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
643 template <typename T> static void validate_dims(std::vector<T> v) {
644 if (v.size() > TENSOR_MAX_DIMS)
645 throw error(mkldnn_invalid_arguments,
646 "invalid dimensions");
649 /// Data type specification. See #mkldnn_data_type_t for a detailed description.
652 data_undef = mkldnn_data_type_undef,
661 /// Memory format specification. See #mkldnn_memory_format_t
662 /// for a detailed description.
664 format_undef = mkldnn_format_undef,
666 blocked = mkldnn_blocked,
671 nCw16c = mkldnn_nCw16c,
675 nCw4c = mkldnn_nCw4c,
676 nCw8c = mkldnn_nCw8c,
677 nChw4c = mkldnn_nChw4c,
678 nChw8c = mkldnn_nChw8c,
679 nChw16c = mkldnn_nChw16c,
680 ncdhw = mkldnn_ncdhw,
681 ndhwc = mkldnn_ndhwc,
682 nCdhw4c = mkldnn_nCdhw4c,
683 nCdhw8c = mkldnn_nCdhw8c,
684 nCdhw16c = mkldnn_nCdhw16c,
689 Owi4o = mkldnn_Owi4o,
690 OIw4i4o = mkldnn_OIw4i4o,
691 Owi8o = mkldnn_Owi8o,
692 OIw8o8i = mkldnn_OIw8o8i,
693 OIw8i8o = mkldnn_OIw8i8o,
694 OIw16i16o = mkldnn_OIw16i16o,
695 OIw16o16i = mkldnn_OIw16o16i,
696 Oiw4o = mkldnn_Oiw4o,
697 Oiw16o = mkldnn_Oiw16o,
698 Owi16o = mkldnn_Owi16o,
699 OIw8i16o2i = mkldnn_OIw8i16o2i,
700 OIw8o16i2o = mkldnn_OIw8o16i2o,
701 IOw16o16i = mkldnn_IOw16o16i,
706 hwio_s8s8 = mkldnn_hwio_s8s8,
707 dhwio = mkldnn_dhwio,
708 oidhw = mkldnn_oidhw,
709 OIdhw4i4o = mkldnn_OIdhw4i4o,
710 Odhwi4o = mkldnn_Odhwi4o,
711 OIdhw8i8o = mkldnn_OIdhw8i8o,
712 OIdhw8o8i = mkldnn_OIdhw8o8i,
713 Odhwi8o = mkldnn_Odhwi8o,
714 OIdhw16i16o = mkldnn_OIdhw16i16o,
715 OIdhw16o16i = mkldnn_OIdhw16o16i,
716 Oidhw4o = mkldnn_Oidhw4o,
717 Oidhw16o = mkldnn_Oidhw16o,
718 Odhwi16o = mkldnn_Odhwi16o,
719 oIhw8i = mkldnn_oIhw8i,
720 oIhw16i = mkldnn_oIhw16i,
721 oIdhw8i = mkldnn_oIdhw8i,
722 oIdhw16i = mkldnn_oIdhw16i,
723 OIhw4i4o = mkldnn_OIhw4i4o,
724 OIhw8i8o = mkldnn_OIhw8i8o,
725 OIhw16i16o = mkldnn_OIhw16i16o,
726 OIhw8o8i = mkldnn_OIhw8o8i,
727 OIhw16o16i = mkldnn_OIhw16o16i,
728 IOhw16o16i = mkldnn_IOhw16o16i,
729 OIhw8i16o2i = mkldnn_OIhw8i16o2i,
730 OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
731 OIhw8o16i2o = mkldnn_OIhw8o16i2o,
732 OIhw4i16o4i = mkldnn_OIhw4i16o4i,
733 OIhw4i16o4i_s8s8 = mkldnn_OIhw4i16o4i_s8s8,
734 Oihw8o = mkldnn_Oihw8o,
735 Oihw4o = mkldnn_Oihw4o,
736 Oihw16o = mkldnn_Oihw16o,
737 Ohwi8o = mkldnn_Ohwi8o,
738 Ohwi4o = mkldnn_Ohwi4o,
739 Ohwi16o = mkldnn_Ohwi16o,
740 OhIw16o4i = mkldnn_OhIw16o4i,
741 OhIw8o4i = mkldnn_OhIw8o4i,
742 OhIw8o32i = mkldnn_OhIw8o32i,
743 OhIw16o32i = mkldnn_OhIw16o32i,
744 OhIw8o4i_s8s8 = mkldnn_OhIw8o4i_s8s8,
746 gOwi4o = mkldnn_gOwi4o,
747 gOIw4i4o = mkldnn_gOIw4i4o,
748 gOwi8o = mkldnn_gOwi8o,
749 gOIw8o8i = mkldnn_gOIw8o8i,
750 gOIw8i8o = mkldnn_gOIw8i8o,
751 gOIw16i16o = mkldnn_gOIw16i16o,
752 gOIw16o16i = mkldnn_gOIw16o16i,
753 gOiw4o = mkldnn_gOiw4o,
754 gOiw16o = mkldnn_gOiw16o,
755 gOwi16o = mkldnn_gOwi16o,
756 gOIw8i16o2i = mkldnn_gOIw8i16o2i,
757 gIOw16o16i = mkldnn_gIOw16o16i,
758 gOIw8o16i2o = mkldnn_gOIw8o16i2o,
759 goihw = mkldnn_goihw,
760 hwigo = mkldnn_hwigo,
761 giohw = mkldnn_giohw,
762 hwigo_s8s8 = mkldnn_hwigo_s8s8,
763 gOIdhw4i4o = mkldnn_gOIdhw4i4o,
764 gOdhwi4o = mkldnn_gOdhwi4o,
765 gOIdhw8i8o = mkldnn_gOIdhw8i8o,
766 gOIdhw8o8i = mkldnn_gOIdhw8o8i,
767 gOdhwi8o = mkldnn_gOdhwi8o,
768 gOIhw4i4o = mkldnn_gOIhw4i4o,
769 gOIhw8i8o = mkldnn_gOIhw8i8o,
770 gOIhw16i16o = mkldnn_gOIhw16i16o,
771 gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
772 gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
773 gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
774 gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
775 gOIhw4i16o4i_s8s8 = mkldnn_gOIhw4i16o4i_s8s8,
776 gOIhw2i8o4i = mkldnn_gOIhw2i8o4i,
777 gOIhw2i8o4i_s8s8 = mkldnn_gOIhw2i8o4i_s8s8,
778 gOihw8o = mkldnn_gOihw8o,
779 gOihw4o = mkldnn_gOihw4o,
780 gOihw16o = mkldnn_gOihw16o,
781 gOhwi4o = mkldnn_gOhwi4o,
782 gOhwi8o = mkldnn_gOhwi8o,
783 gOhwi16o = mkldnn_gOhwi16o,
784 Goihw8g = mkldnn_Goihw8g,
785 Goihw16g = mkldnn_Goihw16g,
786 Goihw16g_s8s8 = mkldnn_Goihw16g_s8s8,
787 gOIhw4o4i = mkldnn_gOIhw4o4i,
788 gOIhw4o4i_s8s8 = mkldnn_gOIhw4o4i_s8s8,
789 gOIhw8o8i = mkldnn_gOIhw8o8i,
790 gOIhw16o16i = mkldnn_gOIhw16o16i,
791 gIOhw16o16i = mkldnn_gIOhw16o16i,
792 gOhIw16o4i = mkldnn_gOhIw16o4i,
793 gOhIw8o4i = mkldnn_gOhIw8o4i,
794 gOhIw8o4i_s8s8 = mkldnn_gOhIw8o4i_s8s8,
795 goidhw = mkldnn_goidhw,
796 gOIdhw16i16o = mkldnn_gOIdhw16i16o,
797 gOIdhw16o16i = mkldnn_gOIdhw16o16i,
798 gOidhw4o = mkldnn_gOidhw4o,
799 gOidhw16o = mkldnn_gOidhw16o,
800 gOdhwi16o = mkldnn_gOdhwi16o,
803 ldsnc = mkldnn_ldsnc,
804 ldigo = mkldnn_ldigo,
805 ldgoi = mkldnn_ldgoi,
807 rnn_packed = mkldnn_rnn_packed,
808 wino_fmt = mkldnn_wino_fmt,
809 format_last = mkldnn_format_last,
812 /// A memory descriptor.
814 friend struct memory;
815 /// The underlying C API data structure.
816 mkldnn_memory_desc_t data;
818 /// Constructs a memory descriptor.
820 /// @param adims Data dimensions.
821 /// @param adata_type Data precision/type.
822 /// @param aformat Data layout format.
823 desc(dims adims, data_type adata_type,
825 validate_dims(adims);
827 mkldnn_memory_desc_init(&data, (int)adims.size(),
828 adims.size() == 0 ? nullptr : &adims[0],
829 convert_to_c(adata_type), convert_to_c(aformat)),
830 "could not initialize a memory descriptor");
833 /// Constructs a memory descriptor from a C API data structure.
835 /// @param adata A C API #mkldnn_memory_desc_t structure.
836 desc(const mkldnn_memory_desc_t &adata): data(adata) {}
839 /// A memory primitive descriptor.
840 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
841 friend struct memory;
843 // TODO: make private
846 /// Constructs a memory primitive descriptor.
847 primitive_desc(const desc &adesc, const engine &aengine) {
848 mkldnn_primitive_desc_t result;
850 mkldnn_memory_primitive_desc_create(&result,
851 &adesc.data, aengine.get()),
852 "could not initialize a memory primitive descriptor");
856 /// Returns the memory descriptor.
857 memory::desc desc() {
858 auto memory_d = mkldnn_primitive_desc_query_memory_d(get());
859 return memory::desc(*memory_d); }
861 /// Returns the number of bytes required to allocate the memory described by
862 /// this primitive descriptor, including the padding area.
863 size_t get_size() const {
864 return mkldnn_memory_primitive_desc_get_size(get());
867 bool operator==(const primitive_desc &other) const {
868 return (0 == mkldnn_memory_primitive_desc_equal(get(),
869 other.get())) ? false : true;
872 bool operator!=(const primitive_desc &other) const {
873 return !operator==(other);
876 engine get_engine() { return engine::query(*this); }
879 /// Constructs a memory primitive from a generic primitive.
881 /// @param aprimitive The primitive to treat as memory.
882 memory(const primitive &aprimitive): primitive(aprimitive) {}
883 /// Constructs a memory primitive.
885 /// @param adesc Memory primitive descriptor.
886 memory(const primitive_desc &adesc) {
887 mkldnn_primitive_t result;
889 mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
890 "could not create a memory primitive");
892 auto _malloc = [](size_t size, int alignment) {
895 ptr = _aligned_malloc(size, alignment);
896 int rc = ((ptr)? 0 : errno);
898 int rc = ::posix_memalign(&ptr, alignment, size);
900 return (rc == 0) ? (char*)ptr : nullptr;
902 auto _free = [](char* p) {
904 _aligned_free((void*)p);
909 _handle.reset(_malloc(adesc.get_size(), 4096), _free);
910 set_data_handle(_handle.get());
913 memory(const primitive_desc &adesc, void *ahandle) {
914 mkldnn_primitive_t result;
916 mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
917 "could not create a memory primitive");
919 set_data_handle(ahandle);
922 /// Returns the descriptor of the memory primitive.
923 primitive_desc get_primitive_desc() const {
924 primitive_desc adesc;
925 const_mkldnn_primitive_desc_t cdesc;
926 error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(),
928 "could not get primitive descriptor from a memory primitive");
929 /* FIXME: no const_cast should be here */
930 adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
934 /// Returns a handle of the data contained in the memory primitive. On
935 /// the CPU engine, this is a pointer to the allocated memory.
936 inline void *get_data_handle() const {
938 error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle),
939 "could not get native handle");
943 inline void set_data_handle(void *handle) const {
944 error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle),
945 "could not set native handle");
948 // Must go away or be private:
949 static mkldnn_data_type_t convert_to_c(data_type adata_type) {
950 return static_cast<mkldnn_data_type_t>(adata_type);
952 static mkldnn_memory_format_t convert_to_c(format aformat) {
953 return static_cast<mkldnn_memory_format_t>(aformat);
957 inline memory::desc zero_md() {
958 auto zero = mkldnn_memory_desc_t();
959 zero.primitive_kind = mkldnn_memory;
960 return memory::desc(zero);
963 inline memory null_memory(engine eng) {
964 mkldnn::memory::desc zero = zero_md();
965 return memory({zero, eng}, nullptr);
968 inline void check_num_parameters(const const_mkldnn_primitive_desc_t
969 &aprimitive_desc, int n_inputs, int n_outputs,
970 const std::string &prim_name) {
971 const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
972 aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
973 const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
974 aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
975 if (n_outputs_expected > n_outputs ) {
976 std::string message = "could not create " + prim_name +
977 " primitive, not enought output parameters";
978 throw error(mkldnn_invalid_arguments, message, nullptr);
980 if (n_inputs_expected > n_inputs ) {
981 std::string message = "could not create " + prim_name +
982 " primitive, not enought input parameters";
983 throw error(mkldnn_invalid_arguments, message, nullptr);
988 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
989 const_mkldnn_primitive_desc_t aprimitive_pd;
990 mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
991 const mkldnn_memory_desc_t *aprimitive_md = mkldnn_primitive_desc_query_memory_d(
994 return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
997 inline bool operator==(mkldnn_data_type_t a, memory::data_type b) {
998 return a == memory::convert_to_c(b);
1000 inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) {
1003 inline bool operator==(memory::data_type a, mkldnn_data_type_t b) {
1006 inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) {
1010 inline bool operator==(mkldnn_memory_format_t a, memory::format b) {
1011 return a == memory::convert_to_c(b);
1013 inline bool operator!=(mkldnn_memory_format_t a, memory::format b) {
1016 inline bool operator==(memory::format a, mkldnn_memory_format_t b) {
1019 inline bool operator!=(memory::format a, mkldnn_memory_format_t b) {
1025 /// @addtogroup cpp_api_reorder Reorder
1026 /// A primitive to copy data between memory formats.
1028 /// @sa @ref c_api_reorder in @ref c_api
1031 struct reorder : public primitive {
1032 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1033 primitive_desc(const memory::primitive_desc &input,
1034 const memory::primitive_desc &output) {
1035 mkldnn_primitive_desc_t result;
1036 error::wrap_c_api(mkldnn_reorder_primitive_desc_create(
1037 &result, input.get(), output.get()),
1038 "could not create a reorder primitive descriptor");
1042 primitive_desc(const memory::primitive_desc &input,
1043 const memory::primitive_desc &output,
1044 const primitive_attr &aattr) {
1045 mkldnn_primitive_desc_t result;
1046 error::wrap_c_api(mkldnn_reorder_primitive_desc_create_v2(
1047 &result, input.get(), output.get(), aattr.get()),
1048 "could not create a reorder primitive descriptor");
1052 engine get_engine() { return engine::query(*this); }
1055 reorder(const primitive_desc &aprimitive_desc,
1056 const primitive::at &input, const memory &output) {
1057 mkldnn_primitive_t result;
1058 mkldnn_primitive_at_t inputs[] = { input.data };
1059 const_mkldnn_primitive_t outputs[] = { output.get() };
1060 error::wrap_c_api(mkldnn_primitive_create(&result,
1061 aprimitive_desc.get(), inputs, outputs),
1062 "could not create a reorder primitive");
1066 reorder(const primitive::at &input, const memory &output) {
1067 auto input_mpd = memory(input).get_primitive_desc();
1068 auto output_mpd = output.get_primitive_desc();
1070 auto reorder_d = primitive_desc(input_mpd, output_mpd);
1072 mkldnn_primitive_t result;
1073 mkldnn_primitive_at_t inputs[] = { input.data };
1074 const_mkldnn_primitive_t outputs[] = { output.get() };
1075 error::wrap_c_api(mkldnn_primitive_create(&result,
1076 reorder_d.get(), inputs, outputs),
1077 "could not create a reorder primitive");
1084 /// @addtogroup cpp_api_view View
1085 /// A primitive to create a view on a memory.
1087 /// @sa mkldnn_view_primitive_desc_create in @ref c_api
1090 struct view : public primitive {
1091 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1092 primitive_desc(const memory::primitive_desc &input, memory::dims dims,
1093 memory::dims offsets) {
1094 mkldnn_primitive_desc_t result;
1096 error::wrap_c_api(mkldnn_view_primitive_desc_create(
1097 &result, input.get(), &dims[0], &offsets[0]),
1098 "could not create a view primitive descriptor");
1102 memory::primitive_desc dst_primitive_desc() const {
1103 memory::primitive_desc adesc;
1104 mkldnn_primitive_desc_t cdesc;
1105 const_mkldnn_primitive_desc_t const_cdesc =
1106 mkldnn_primitive_desc_query_pd(get(),
1107 mkldnn::convert_to_c(dst_pd), 0);
1108 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc,
1110 "could not clone a dst primitive descriptor");
1115 engine get_engine() { return engine::query(*this); }
1118 view(const primitive_desc &view_pd, primitive::at input) {
1119 mkldnn_primitive_t result;
1120 mkldnn_primitive_at_t inputs[] = { input.data };
1121 error::wrap_c_api(mkldnn_primitive_create(&result,
1122 view_pd.get(), inputs, nullptr),
1123 "could not create a view primitive");
1127 view(memory input, memory::dims dims, memory::dims offsets) {
1128 mkldnn_primitive_t result;
1129 primitive_desc view_pd(input.get_primitive_desc(), dims,
1131 mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
1132 error::wrap_c_api(mkldnn_primitive_create(&result,
1133 view_pd.get(), inputs, nullptr),
1134 "could not create a view primitive");
1141 /// @addtogroup cpp_api_concat Concat
1142 /// A primitive to concatenate data by arbitrary dimension.
1144 /// @sa @ref c_api_concat in @ref c_api
1147 struct concat : public primitive {
1148 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1149 std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1150 std::vector<memory::primitive_desc> inputs) {
1151 std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1152 c_api_inputs.reserve(inputs.size());
1153 auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1154 std::transform(inputs.begin(), inputs.end(),
1155 std::back_inserter(c_api_inputs), convert_to_c);
1156 return c_api_inputs;
1159 primitive_desc(const memory::desc &output, int concat_dimension,
1160 std::vector<memory::primitive_desc> inputs) {
1161 mkldnn_primitive_desc_t result;
1163 auto c_api_inputs = cpp_to_c(inputs);
1165 error::wrap_c_api(mkldnn_concat_primitive_desc_create(
1166 &result, &output.data, (int)c_api_inputs.size(),
1167 concat_dimension, &c_api_inputs[0]),
1168 "could not create a concat primitive descriptor");
1172 primitive_desc(int concat_dimension,
1173 std::vector<memory::primitive_desc> inputs) {
1174 mkldnn_primitive_desc_t result;
1176 auto c_api_inputs = cpp_to_c(inputs);
1178 error::wrap_c_api(mkldnn_concat_primitive_desc_create(
1179 &result, nullptr, (int)c_api_inputs.size(),
1180 concat_dimension, &c_api_inputs[0]),
1181 "could not create a concat primitive descriptor");
1185 memory::primitive_desc dst_primitive_desc() const {
1186 memory::primitive_desc adesc;
1187 mkldnn_primitive_desc_t cdesc;
1188 const_mkldnn_primitive_desc_t const_cdesc =
1189 mkldnn_primitive_desc_query_pd(get(),
1190 mkldnn::convert_to_c(dst_pd), 0);
1191 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1192 "could not clone a dst primitive descriptor");
1197 engine get_engine() { return engine::query(*this); }
1200 concat(const primitive_desc &concat_pd,
1201 std::vector<primitive::at> &inputs, const memory &output) {
1202 mkldnn_primitive_t result;
1204 std::vector<mkldnn_primitive_at_t> p_inputs;
1205 for (size_t i = 0; i < inputs.size(); i++)
1206 p_inputs.push_back(inputs[i].data);
1207 const_mkldnn_primitive_t outputs[] = { output.get() };
1209 error::wrap_c_api(mkldnn_primitive_create(&result,
1210 concat_pd.get(), &p_inputs[0], outputs),
1211 "could not create a concat primitive");
1218 /// @addtogroup cpp_api_sum Sum
1219 /// A primitive to sum data.
1221 /// @sa @ref c_api_sum in @ref c_api
1224 struct sum : public primitive {
1225 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1226 std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1227 std::vector<memory::primitive_desc> inputs) {
1228 std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1229 c_api_inputs.reserve(inputs.size());
1230 auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1231 std::transform(inputs.begin(), inputs.end(),
1232 std::back_inserter(c_api_inputs), convert_to_c);
1233 return c_api_inputs;
1236 primitive_desc(const memory::desc &output,
1237 const std::vector<float> &scales,
1238 std::vector<memory::primitive_desc> inputs) {
1239 mkldnn_primitive_desc_t result;
1241 auto c_api_inputs = cpp_to_c(inputs);
1244 scales.size() == inputs.size() ? mkldnn_success
1245 : mkldnn_invalid_arguments,
1246 "number of scales not equal to number of inputs");
1248 error::wrap_c_api(mkldnn_sum_primitive_desc_create(
1249 &result, &output.data, (int)c_api_inputs.size(),
1250 &scales[0], &c_api_inputs[0]),
1251 "could not create a sum primitive descriptor");
1255 primitive_desc(const std::vector<float> &scales,
1256 std::vector<memory::primitive_desc> inputs) {
1257 mkldnn_primitive_desc_t result;
1259 auto c_api_inputs = cpp_to_c(inputs);
1262 scales.size() == inputs.size() ? mkldnn_success
1263 : mkldnn_invalid_arguments,
1264 "number of scales not equal to number of inputs");
1266 error::wrap_c_api(mkldnn_sum_primitive_desc_create(
1267 &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1269 "could not create a sum primitive descriptor");
1273 memory::primitive_desc dst_primitive_desc() const {
1274 memory::primitive_desc adesc;
1275 mkldnn_primitive_desc_t cdesc;
1276 const_mkldnn_primitive_desc_t const_cdesc =
1277 mkldnn_primitive_desc_query_pd(get(),
1278 mkldnn::convert_to_c(dst_pd), 0);
1279 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc,
1281 "could not clone a dst primitive descriptor");
1286 engine get_engine() { return engine::query(*this); }
1289 sum(const primitive_desc &sum_pd,
1290 std::vector<primitive::at> &inputs, const memory &output) {
1291 mkldnn_primitive_t result;
1293 std::vector<mkldnn_primitive_at_t> p_inputs;
1294 for (size_t i = 0; i < inputs.size(); i++)
1295 p_inputs.push_back(inputs[i].data);
1296 const_mkldnn_primitive_t outputs[] = { output.get() };
1298 error::wrap_c_api(mkldnn_primitive_create(&result,
1299 sum_pd.get(), &p_inputs[0], outputs),
1300 "could not create a sum primitive");
1309 /// @addtogroup cpp_api_primitives Primitives
1312 /// @addtogroup cpp_api_primitive_descriptors Primitive descriptors
1315 /// A base class for all primitive descriptors.
1316 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1317 primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr,
1318 const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1319 mkldnn_primitive_desc_iterator_t iterator = nullptr;
1320 mkldnn_status_t status = mkldnn_primitive_desc_iterator_create_v2(
1321 &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1323 error::wrap_c_api(status,
1324 "could not create a primitive descriptor iterator");
1325 pd_iterator.reset(iterator);
1329 engine get_engine() { return engine::query(*this); }
1331 primitive_attr get_primitive_attr() const {
1332 const_mkldnn_primitive_attr_t const_cattr;
1333 error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr),
1334 "could not get attributes");
1335 mkldnn_primitive_attr_t cattr;
1336 error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
1337 "could not clone attributes");
1339 primitive_attr attr;
1344 /// Returns the implementation name.
1345 const char *impl_info_str() const {
1347 error::wrap_c_api(mkldnn_primitive_desc_query(get(),
1348 mkldnn_query_impl_info_str, 0, &res),
1349 "could not query implementation info string");
1353 /// Advances to the next implementation for the given op descriptor.
1356 /// - @c true on success
1357 /// - @c false if the last implementation has been reached, and
1358 /// the primitive descriptor itself is kept unchanged
1360 mkldnn_status_t status = mkldnn_primitive_desc_iterator_next(
1362 if (status == mkldnn_iterator_ends) return false;
1363 error::wrap_c_api(status, "primitive descriptor iterator next failed");
1369 /// Queries and returns requested memory primitive descriptor.
1370 memory::primitive_desc query_mpd(query what, int idx = 0) const {
1371 std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
1372 weights_pd, diff_weights_pd, dst_pd, diff_dst_pd, workspace_pd};
1373 if (!std::any_of(valid_w.cbegin(), valid_w.cend(),
1374 [=](query q) { return what == q; }))
1375 throw error(mkldnn_invalid_arguments, "invalid memory query");
1377 const_mkldnn_primitive_desc_t const_cdesc
1378 = mkldnn_primitive_desc_query_pd(get(),
1379 mkldnn::convert_to_c(what), idx);
1381 // TODO: is there a better way to inform about this?
1382 if (const_cdesc == nullptr)
1383 throw error(mkldnn_not_required, "queried memory is not required");
1385 mkldnn_primitive_desc_t cdesc;
1386 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1387 "could not clone a memory primitive descriptor");
1389 memory::primitive_desc ret;
1394 // register specialized queries, e.g. src_primitive_desc()
1395 # define REG_QUERY_MPD(name, what, idx) \
1396 memory::primitive_desc name ## _primitive_desc() const \
1397 { return query_mpd(what ## _pd, idx); }
1400 handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1402 mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
1404 error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error,
1405 "could not fetch a primitive descriptor from the iterator");
1412 /// @addtogroup cpp_api_convolution Convolution
1413 /// A primitive to compute convolution using different algorithms.
1415 /// @sa @ref c_api_convolution in @ref c_api
1418 struct convolution_forward: public primitive {
1420 mkldnn_convolution_desc_t data;
1421 desc(prop_kind aprop_kind, algorithm aalgorithm,
1422 const memory::desc &src_desc,
1423 const memory::desc &weights_desc,
1424 const memory::desc &bias_desc,
1425 const memory::desc &dst_desc,
1426 const memory::dims strides,
1427 const memory::dims padding_l,
1428 const memory::dims padding_r,
1429 const padding_kind apadding_kind) {
1430 memory::validate_dims(strides);
1431 memory::validate_dims(padding_l);
1432 memory::validate_dims(padding_r);
1433 error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
1434 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1435 &src_desc.data, &weights_desc.data, &bias_desc.data,
1436 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1437 mkldnn::convert_to_c(apadding_kind)),
1438 "could not create a convolution forward descriptor");
1440 desc(prop_kind aprop_kind, algorithm aalgorithm,
1441 const memory::desc &src_desc,
1442 const memory::desc &weights_desc,
1443 const memory::desc &dst_desc,
1444 const memory::dims strides,
1445 const memory::dims padding_l,
1446 const memory::dims padding_r,
1447 const padding_kind apadding_kind) {
1448 memory::validate_dims(strides);
1449 memory::validate_dims(padding_l);
1450 memory::validate_dims(padding_r);
1451 error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
1452 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1453 &src_desc.data, &weights_desc.data, nullptr,
1454 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1455 mkldnn::convert_to_c(apadding_kind)),
1456 "could not create a convolution forward descriptor");
1458 desc(prop_kind aprop_kind, algorithm aalgorithm,
1459 const memory::desc &src_desc,
1460 const memory::desc &weights_desc,
1461 const memory::desc &bias_desc,
1462 const memory::desc &dst_desc,
1463 const memory::dims strides,
1464 const memory::dims dilates,
1465 const memory::dims padding_l,
1466 const memory::dims padding_r,
1467 const padding_kind apadding_kind) {
1468 memory::validate_dims(strides);
1469 memory::validate_dims(dilates);
1470 memory::validate_dims(padding_l);
1471 memory::validate_dims(padding_r);
1473 mkldnn_dilated_convolution_forward_desc_init(&data,
1474 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1475 &src_desc.data, &weights_desc.data, &bias_desc.data,
1476 &dst_desc.data, &strides[0], &dilates[0],
1477 &padding_l[0], &padding_r[0],
1478 mkldnn::convert_to_c(apadding_kind)),
1479 "could not create a dilated convolution forward descriptor");
1481 desc(prop_kind aprop_kind, algorithm aalgorithm,
1482 const memory::desc &src_desc,
1483 const memory::desc &weights_desc,
1484 const memory::desc &dst_desc,
1485 const memory::dims strides,
1486 const memory::dims dilates,
1487 const memory::dims padding_l,
1488 const memory::dims padding_r,
1489 const padding_kind apadding_kind) {
1490 memory::validate_dims(strides);
1491 memory::validate_dims(dilates);
1492 memory::validate_dims(padding_l);
1493 memory::validate_dims(padding_r);
1495 mkldnn_dilated_convolution_forward_desc_init(&data,
1496 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1497 &src_desc.data, &weights_desc.data, nullptr,
1498 &dst_desc.data, &strides[0], &dilates[0],
1499 &padding_l[0], &padding_r[0],
1500 mkldnn::convert_to_c(apadding_kind)),
1501 "could not create a dilated convolution forward descriptor");
1505 struct primitive_desc : public mkldnn::primitive_desc {
1506 primitive_desc(const desc &desc, const engine &e)
1507 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1509 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1510 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1512 REG_QUERY_MPD(src, src, 0);
1513 REG_QUERY_MPD(weights, weights, 0);
1514 REG_QUERY_MPD(bias, weights, 1);
1515 REG_QUERY_MPD(dst, dst, 0);
1518 convolution_forward(const primitive_desc &aprimitive_desc,
1519 const primitive::at &src, const primitive::at &weights,
1520 const primitive::at &bias, const memory &dst) {
1521 mkldnn_primitive_t result;
1522 mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1524 const_mkldnn_primitive_t outputs[] = { dst.get() };
1525 error::wrap_c_api(mkldnn_primitive_create(&result,
1526 aprimitive_desc.get(), inputs, outputs),
1527 "could not create a convolution forward bias primitive");
1531 convolution_forward(const primitive_desc &aprimitive_desc,
1532 const primitive::at &src, const primitive::at &weights,
1533 const memory &dst) {
1534 mkldnn_primitive_t result;
1535 mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1536 const_mkldnn_primitive_t outputs[] = { dst.get() };
1537 check_num_parameters(aprimitive_desc.get(), 2, 1,
1538 "convolution forward");
1539 error::wrap_c_api(mkldnn_primitive_create(&result,
1540 aprimitive_desc.get(), inputs, outputs),
1541 "could not create a convolution forward primitive");
1546 struct convolution_backward_data : public primitive {
1548 mkldnn_convolution_desc_t data;
1549 desc(algorithm aalgorithm,
1550 const memory::desc &diff_src_desc,
1551 const memory::desc &weights_desc,
1552 const memory::desc &diff_dst_desc,
1553 const memory::dims strides,
1554 const memory::dims padding_l,
1555 const memory::dims padding_r,
1556 const padding_kind apadding_kind) {
1557 memory::validate_dims(strides);
1558 memory::validate_dims(padding_l);
1559 memory::validate_dims(padding_r);
1560 error::wrap_c_api(mkldnn_convolution_backward_data_desc_init(
1561 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1562 &weights_desc.data, &diff_dst_desc.data,
1563 &strides[0], &padding_l[0], &padding_r[0],
1564 mkldnn::convert_to_c(apadding_kind)),
1565 "could not create a convolution backward data descriptor");
1567 desc(algorithm aalgorithm,
1568 const memory::desc &diff_src_desc,
1569 const memory::desc &weights_desc,
1570 const memory::desc &diff_dst_desc,
1571 const memory::dims strides,
1572 const memory::dims dilates,
1573 const memory::dims padding_l,
1574 const memory::dims padding_r,
1575 const padding_kind apadding_kind) {
1576 memory::validate_dims(strides);
1577 memory::validate_dims(dilates);
1578 memory::validate_dims(padding_l);
1579 memory::validate_dims(padding_r);
1581 mkldnn_dilated_convolution_backward_data_desc_init(
1582 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1583 &weights_desc.data, &diff_dst_desc.data,
1584 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1585 mkldnn::convert_to_c(apadding_kind)),
1586 "could not create a convolution backward data descriptor");
1590 struct primitive_desc : public mkldnn::primitive_desc {
1591 primitive_desc(const desc &desc, const engine &e,
1592 const convolution_forward::primitive_desc &hint_fwd_pd)
1593 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1595 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1596 const convolution_forward::primitive_desc &hint_fwd_pd)
1597 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1599 REG_QUERY_MPD(diff_src, diff_src, 0);
1600 REG_QUERY_MPD(weights, weights, 0);
1601 REG_QUERY_MPD(diff_dst, diff_dst, 0);
1604 convolution_backward_data(const primitive_desc &aprimitive_desc,
1605 const primitive::at &diff_dst, const primitive::at &weights,
1606 const memory &diff_src) {
1607 mkldnn_primitive_t result;
1608 mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1609 const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1610 check_num_parameters(aprimitive_desc.get(), 2, 1,
1611 "convolution backward data");
1612 error::wrap_c_api(mkldnn_primitive_create(&result,
1613 aprimitive_desc.get(), inputs, outputs),
1614 "could not create a convolution backward data primitive");
1619 struct convolution_backward_weights : public primitive {
1621 mkldnn_convolution_desc_t data;
1622 desc(algorithm aalgorithm,
1623 const memory::desc &src_desc,
1624 const memory::desc &diff_weights_desc,
1625 const memory::desc &diff_bias_desc,
1626 const memory::desc &diff_dst_desc,
1627 const memory::dims strides,
1628 const memory::dims padding_l,
1629 const memory::dims padding_r,
1630 const padding_kind apadding_kind) {
1631 memory::validate_dims(strides);
1632 memory::validate_dims(padding_l);
1633 memory::validate_dims(padding_r);
1634 error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
1635 &data, convert_to_c(aalgorithm), &src_desc.data,
1636 &diff_weights_desc.data, &diff_bias_desc.data,
1637 &diff_dst_desc.data,
1638 &strides[0], &padding_l[0], &padding_r[0],
1639 mkldnn::convert_to_c(apadding_kind)),
1640 "could not create a convolution backward weights descriptor");
1642 desc(algorithm aalgorithm,
1643 const memory::desc &src_desc,
1644 const memory::desc &diff_weights_desc,
1645 const memory::desc &diff_dst_desc,
1646 const memory::dims strides,
1647 const memory::dims padding_l,
1648 const memory::dims padding_r,
1649 const padding_kind apadding_kind) {
1650 memory::validate_dims(strides);
1651 memory::validate_dims(padding_l);
1652 memory::validate_dims(padding_r);
1653 error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
1654 &data, convert_to_c(aalgorithm), &src_desc.data,
1655 &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1656 &strides[0], &padding_l[0], &padding_r[0],
1657 mkldnn::convert_to_c(apadding_kind)),
1658 "could not create a convolution backward weights descriptor");
1660 desc(algorithm aalgorithm,
1661 const memory::desc &src_desc,
1662 const memory::desc &diff_weights_desc,
1663 const memory::desc &diff_bias_desc,
1664 const memory::desc &diff_dst_desc,
1665 const memory::dims strides,
1666 const memory::dims dilates,
1667 const memory::dims padding_l,
1668 const memory::dims padding_r,
1669 const padding_kind apadding_kind) {
1670 memory::validate_dims(strides);
1671 memory::validate_dims(dilates);
1672 memory::validate_dims(padding_l);
1673 memory::validate_dims(padding_r);
1674 error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
1675 &data, convert_to_c(aalgorithm), &src_desc.data,
1676 &diff_weights_desc.data, &diff_bias_desc.data,
1677 &diff_dst_desc.data,
1678 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1679 mkldnn::convert_to_c(apadding_kind)),
1680 "could not create a convolution backward weights descriptor");
1682 desc(algorithm aalgorithm,
1683 const memory::desc &src_desc,
1684 const memory::desc &diff_weights_desc,
1685 const memory::desc &diff_dst_desc,
1686 const memory::dims strides,
1687 const memory::dims dilates,
1688 const memory::dims padding_l,
1689 const memory::dims padding_r,
1690 const padding_kind apadding_kind) {
1691 memory::validate_dims(strides);
1692 memory::validate_dims(dilates);
1693 memory::validate_dims(padding_l);
1694 memory::validate_dims(padding_r);
1695 error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
1696 &data, convert_to_c(aalgorithm), &src_desc.data,
1697 &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1698 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1699 mkldnn::convert_to_c(apadding_kind)),
1700 "could not create a convolution backward weights descriptor");
1705 struct primitive_desc : public mkldnn::primitive_desc {
1706 primitive_desc(const desc &desc, const engine &e,
1707 const convolution_forward::primitive_desc &hint_fwd_pd)
1708 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1710 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1711 const convolution_forward::primitive_desc &hint_fwd_pd)
1712 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1714 REG_QUERY_MPD(src, src, 0);
1715 REG_QUERY_MPD(diff_weights, diff_weights, 0);
1716 REG_QUERY_MPD(diff_bias, diff_weights, 1);
1717 REG_QUERY_MPD(diff_dst, diff_dst, 0);
1720 convolution_backward_weights(const primitive_desc &aprimitive_desc,
1721 const primitive::at &src, const primitive::at &diff_dst,
1722 const memory &diff_weights, const memory &diff_bias) {
1723 mkldnn_primitive_t result;
1724 mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1725 const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1727 check_num_parameters(aprimitive_desc.get(), 2, 2,
1728 "convolution backward weights");
1729 error::wrap_c_api(mkldnn_primitive_create(&result,
1730 aprimitive_desc.get(), inputs, outputs),
1731 "could not create a convolution backward weights primitive");
1734 convolution_backward_weights(const primitive_desc &aprimitive_desc,
1735 const primitive::at &src, const primitive::at &diff_dst,
1736 const memory &diff_weights) {
1737 mkldnn_primitive_t result;
1738 mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1739 const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1740 check_num_parameters(aprimitive_desc.get(), 2, 1,
1741 "convolution backward weights");
1742 error::wrap_c_api(mkldnn_primitive_create(&result,
1743 aprimitive_desc.get(), inputs, outputs),
1744 "could not create a convolution backward weights primitive");
1751 /// @addtogroup cpp_api_deconvolution Deconvolution
1752 /// A primitive to compute deconvolution using different algorithms.
1754 /// @sa @ref c_api_deconvolution in @ref c_api
1757 struct deconvolution_forward: public primitive {
1759 mkldnn_deconvolution_desc_t data;
1760 desc(prop_kind aprop_kind, algorithm aalgorithm,
1761 const memory::desc &src_desc,
1762 const memory::desc &weights_desc,
1763 const memory::desc &bias_desc,
1764 const memory::desc &dst_desc,
1765 const memory::dims strides,
1766 const memory::dims padding_l,
1767 const memory::dims padding_r,
1768 const padding_kind apadding_kind) {
1769 memory::validate_dims(strides);
1770 memory::validate_dims(padding_l);
1771 memory::validate_dims(padding_r);
1772 error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
1773 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1774 &src_desc.data, &weights_desc.data, &bias_desc.data,
1775 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1776 mkldnn::convert_to_c(apadding_kind)),
1777 "could not create a deconvolution forward descriptor");
1779 desc(prop_kind aprop_kind, algorithm aalgorithm,
1780 const memory::desc &src_desc,
1781 const memory::desc &weights_desc,
1782 const memory::desc &dst_desc,
1783 const memory::dims strides,
1784 const memory::dims padding_l,
1785 const memory::dims padding_r,
1786 const padding_kind apadding_kind) {
1787 memory::validate_dims(strides);
1788 memory::validate_dims(padding_l);
1789 memory::validate_dims(padding_r);
1790 error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
1791 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1792 &src_desc.data, &weights_desc.data, nullptr,
1793 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1794 mkldnn::convert_to_c(apadding_kind)),
1795 "could not create a deconvolution forward descriptor");
1797 desc(prop_kind aprop_kind, algorithm aalgorithm,
1798 const memory::desc &src_desc,
1799 const memory::desc &weights_desc,
1800 const memory::desc &bias_desc,
1801 const memory::desc &dst_desc,
1802 const memory::dims strides,
1803 const memory::dims dilates,
1804 const memory::dims padding_l,
1805 const memory::dims padding_r,
1806 const padding_kind apadding_kind) {
1807 memory::validate_dims(strides);
1808 memory::validate_dims(dilates);
1809 memory::validate_dims(padding_l);
1810 memory::validate_dims(padding_r);
1811 error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
1812 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1813 &src_desc.data, &weights_desc.data, &bias_desc.data,
1814 &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1815 &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1816 "could not create a dilated deconvolution forward descriptor");
1818 desc(prop_kind aprop_kind, algorithm aalgorithm,
1819 const memory::desc &src_desc,
1820 const memory::desc &weights_desc,
1821 const memory::desc &dst_desc,
1822 const memory::dims strides,
1823 const memory::dims dilates,
1824 const memory::dims padding_l,
1825 const memory::dims padding_r,
1826 const padding_kind apadding_kind) {
1827 memory::validate_dims(strides);
1828 memory::validate_dims(dilates);
1829 memory::validate_dims(padding_l);
1830 memory::validate_dims(padding_r);
1831 error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
1832 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1833 &src_desc.data, &weights_desc.data, nullptr,
1834 &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1835 &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1836 "could not create a dilated deconvolution forward descriptor");
1840 struct primitive_desc : public mkldnn::primitive_desc {
1841 primitive_desc(const desc &desc, const engine &e)
1842 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1844 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1845 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1847 REG_QUERY_MPD(src, src, 0);
1848 REG_QUERY_MPD(weights, weights, 0);
1849 REG_QUERY_MPD(bias, weights, 1);
1850 REG_QUERY_MPD(dst, dst, 0);
1853 deconvolution_forward(const primitive_desc &aprimitive_desc,
1854 const primitive::at &src, const primitive::at &weights,
1855 const primitive::at &bias, const memory &dst) {
1856 mkldnn_primitive_t result;
1857 mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1859 const_mkldnn_primitive_t outputs[] = { dst.get() };
1860 check_num_parameters(aprimitive_desc.get(), 3, 1,
1861 "deconvolution forward");
1862 error::wrap_c_api(mkldnn_primitive_create(&result,
1863 aprimitive_desc.get(), inputs, outputs),
1864 "could not create a deconvolution forward bias primitive");
1868 deconvolution_forward(const primitive_desc &aprimitive_desc,
1869 const primitive::at &src, const primitive::at &weights,
1870 const memory &dst) {
1871 mkldnn_primitive_t result;
1872 mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1873 const_mkldnn_primitive_t outputs[] = { dst.get() };
1874 check_num_parameters(aprimitive_desc.get(), 2, 1,
1875 "deconvolution forward");
1876 error::wrap_c_api(mkldnn_primitive_create(&result,
1877 aprimitive_desc.get(), inputs, outputs),
1878 "could not create a deconvolution forward primitive");
1883 struct deconvolution_backward_data : public primitive {
1885 mkldnn_deconvolution_desc_t data;
1886 desc(algorithm aalgorithm,
1887 const memory::desc &diff_src_desc,
1888 const memory::desc &weights_desc,
1889 const memory::desc &diff_dst_desc,
1890 const memory::dims strides,
1891 const memory::dims padding_l,
1892 const memory::dims padding_r,
1893 const padding_kind apadding_kind) {
1894 memory::validate_dims(strides);
1895 memory::validate_dims(padding_l);
1896 memory::validate_dims(padding_r);
1897 error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init(
1898 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1899 &weights_desc.data, &diff_dst_desc.data,
1900 &strides[0], &padding_l[0], &padding_r[0],
1901 mkldnn::convert_to_c(apadding_kind)),
1902 "could not create a deconvolution backward data descriptor");
1904 desc(algorithm aalgorithm,
1905 const memory::desc &diff_src_desc,
1906 const memory::desc &weights_desc,
1907 const memory::desc &diff_dst_desc,
1908 const memory::dims strides,
1909 const memory::dims dilates,
1910 const memory::dims padding_l,
1911 const memory::dims padding_r,
1912 const padding_kind apadding_kind) {
1913 memory::validate_dims(strides);
1914 memory::validate_dims(dilates);
1915 memory::validate_dims(padding_l);
1916 memory::validate_dims(padding_r);
1917 error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init(
1918 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1919 &weights_desc.data, &diff_dst_desc.data,
1920 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1921 mkldnn::convert_to_c(apadding_kind)),
1922 "could not create a dilated deconvolution backward data descriptor");
1926 struct primitive_desc : public mkldnn::primitive_desc {
1927 primitive_desc(const desc &desc, const engine &e,
1928 const deconvolution_forward::primitive_desc &hint_fwd_pd)
1929 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1931 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1932 const deconvolution_forward::primitive_desc &hint_fwd_pd)
1933 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1935 REG_QUERY_MPD(diff_src, diff_src, 0);
1936 REG_QUERY_MPD(weights, weights, 0);
1937 REG_QUERY_MPD(diff_dst, diff_dst, 0);
1940 deconvolution_backward_data(const primitive_desc &aprimitive_desc,
1941 const primitive::at &diff_dst, const primitive::at &weights,
1942 const memory &diff_src) {
1943 mkldnn_primitive_t result;
1944 mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1945 const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1946 check_num_parameters(aprimitive_desc.get(), 2, 1,
1947 "deconvolution backward data");
1948 error::wrap_c_api(mkldnn_primitive_create(&result,
1949 aprimitive_desc.get(), inputs, outputs),
1950 "could not create a deconvolution backward data primitive");
/// Deconvolution backward propagation with respect to weights and,
/// optionally, bias.
struct deconvolution_backward_weights : public primitive {
    /// Operation descriptor.
    ///
    /// The overloads differ in two ways:
    /// - Whether a diff_bias descriptor is supplied. Without one, nullptr is
    ///   passed to the C initializer.
    /// - Whether dilations are supplied. The dilated overloads call
    ///   mkldnn_dilated_deconvolution_backward_weights_desc_init.
    struct desc {
        mkldnn_deconvolution_desc_t data;
        desc(algorithm aalgorithm,
                const memory::desc &src_desc,
                const memory::desc &diff_weights_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_desc,
                const memory::dims strides,
                const memory::dims padding_l,
                const memory::dims padding_r,
                const padding_kind apadding_kind) {
            // validate_dims is defined elsewhere.
            // NOTE(review): &strides[0] etc. require non-empty vectors;
            // confirm validate_dims rejects empty dims.
            memory::validate_dims(strides);
            memory::validate_dims(padding_l);
            memory::validate_dims(padding_r);
            error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
                        &data, convert_to_c(aalgorithm), &src_desc.data,
                        &diff_weights_desc.data, &diff_bias_desc.data,
                        &diff_dst_desc.data,
                        &strides[0], &padding_l[0], &padding_r[0],
                        mkldnn::convert_to_c(apadding_kind)),
                    "could not create a deconvolution backward weights descriptor");
        }
        // Variant without a bias gradient.
        desc(algorithm aalgorithm,
                const memory::desc &src_desc,
                const memory::desc &diff_weights_desc,
                const memory::desc &diff_dst_desc,
                const memory::dims strides,
                const memory::dims padding_l,
                const memory::dims padding_r,
                const padding_kind apadding_kind) {
            memory::validate_dims(strides);
            memory::validate_dims(padding_l);
            memory::validate_dims(padding_r);
            error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
                        &data, convert_to_c(aalgorithm), &src_desc.data,
                        &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
                        &strides[0], &padding_l[0], &padding_r[0],
                        mkldnn::convert_to_c(apadding_kind)),
                    "could not create a deconvolution backward weights descriptor");
        }
        // Dilated variant with a bias gradient.
        desc(algorithm aalgorithm,
                const memory::desc &src_desc,
                const memory::desc &diff_weights_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_desc,
                const memory::dims strides,
                const memory::dims dilates,
                const memory::dims padding_l,
                const memory::dims padding_r,
                const padding_kind apadding_kind) {
            memory::validate_dims(strides);
            memory::validate_dims(dilates);
            memory::validate_dims(padding_l);
            memory::validate_dims(padding_r);
            error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
                        &data, convert_to_c(aalgorithm), &src_desc.data,
                        &diff_weights_desc.data, &diff_bias_desc.data,
                        &diff_dst_desc.data,
                        &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
                        mkldnn::convert_to_c(apadding_kind)),
                    "could not create a dilated deconvolution backward weights descriptor");
        }
        // Dilated variant without a bias gradient.
        desc(algorithm aalgorithm,
                const memory::desc &src_desc,
                const memory::desc &diff_weights_desc,
                const memory::desc &diff_dst_desc,
                const memory::dims strides,
                const memory::dims dilates,
                const memory::dims padding_l,
                const memory::dims padding_r,
                const padding_kind apadding_kind) {
            memory::validate_dims(strides);
            memory::validate_dims(dilates);
            memory::validate_dims(padding_l);
            memory::validate_dims(padding_r);
            error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
                        &data, convert_to_c(aalgorithm), &src_desc.data,
                        &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
                        &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
                        mkldnn::convert_to_c(apadding_kind)),
                    "could not create a dilated deconvolution backward weights descriptor");
/// Primitive descriptor for deconvolution backward weights. The forward
/// primitive descriptor is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const deconvolution_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const deconvolution_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(diff_weights, diff_weights, 0);
    // The bias gradient is queried as the second (index 1) diff_weights
    // entry, not as a separate query kind.
    REG_QUERY_MPD(diff_bias, diff_weights, 1);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
/// Creates the primitive with a bias gradient.
/// Inputs: src, diff_dst. Outputs: diff_weights, diff_bias.
deconvolution_backward_weights(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &diff_dst,
        const memory &diff_weights, const memory &diff_bias) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
            diff_bias.get() };
    // Expected counts: 2 inputs, 2 outputs.
    check_num_parameters(aprimitive_desc.get(), 2, 2,
        "deconvolution backward weights");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a deconvolution backward weights primitive");
/// Creates the primitive without a bias gradient.
/// Inputs: src, diff_dst. Output: diff_weights.
deconvolution_backward_weights(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &diff_dst,
        const memory &diff_weights) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
    // Expected counts: 2 inputs, 1 output.
    check_num_parameters(aprimitive_desc.get(), 2, 1,
        "deconvolution backward weights");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a deconvolution backward weights primitive");
2086 /// @addtogroup cpp_api_roi_pooling ROIPooling
2089 struct roi_pooling_forward : public primitive {
2091 mkldnn_roi_pooling_desc_t data;
2092 std::vector<mkldnn_memory_desc_t> c_api_inputs;
2094 desc(prop_kind aprop_kind, algorithm aalgorithm, std::vector<memory::desc> inputs,
2095 const memory::desc &dst_desc, int pooled_h, int pooled_w, double spatial_scale) {
2097 for(size_t i = 0; i < inputs.size(); i++) {
2098 c_api_inputs.push_back(inputs[i].data);
2101 error::wrap_c_api(mkldnn_roi_pooling_forward_desc_init(&data,
2102 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &c_api_inputs[0],
2103 c_api_inputs.size(),
2104 &dst_desc.data, pooled_h, pooled_w, spatial_scale),
2105 "could not create a roi pooling forward descriptor");
2109 struct primitive_desc : public handle<mkldnn_primitive_desc_t>{
2110 primitive_desc(const desc &adesc, const engine &aengine) {
2111 mkldnn_primitive_desc_t result;
2112 error::wrap_c_api(mkldnn_primitive_desc_create(
2113 &result, &adesc.data, aengine.get(), nullptr),
2114 "could not create a roi pooling forward primitive descriptor");
2119 roi_pooling_forward(const primitive_desc &aprimitive_desc,
2120 std::vector<primitive::at> &inputs, const memory &dst) {
2121 mkldnn_primitive_t result;
2123 std::vector<mkldnn_primitive_at_t> p_inputs;
2124 for (size_t i = 0; i < inputs.size(); i++) {
2125 p_inputs.push_back(inputs[i].data);
2128 const_mkldnn_primitive_t outputs[] = { dst.get() };
2129 error::wrap_c_api(mkldnn_primitive_create(&result,
2130 aprimitive_desc.get(), &p_inputs[0], outputs),
2131 "could not create a roi pooling forward primitive");
2138 /// @addtogroup cpp_api_lrn LRN
2139 /// A primitive to perform local response normalization (LRN) across or within
2142 /// @sa @ref c_api_lrn in @ref c_api
/// Local response normalization (LRN) forward primitive.
struct lrn_forward : public primitive {
    /// Operation descriptor for LRN forward propagation.
    struct desc {
        mkldnn_lrn_desc_t data;
        desc(prop_kind aprop_kind, algorithm aalgorithm,
            const memory::desc &src_desc,
            int local_size, float alpha, float beta, float k)
        {
            error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
                mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
                &src_desc.data, local_size, alpha, beta, k),
                "could not create a lrn forward descriptor");
        }
        // Same as above with k fixed to 1.0.
        desc(prop_kind aprop_kind, algorithm aalgorithm,
            const memory::desc &src_desc,
            int local_size, float alpha, float beta)
        {
            error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
                mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
                &src_desc.data, local_size, alpha, beta, float(1.0)),
                "could not create a lrn forward descriptor");
/// Primitive descriptor for LRN forward. It is created without a hint; the
/// attribute is optional (nullptr when omitted).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(dst, dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);
/// Creates an LRN forward primitive that also writes a workspace.
/// Note that the parameter order is (workspace, dst), but the output array
/// order is {dst, workspace}.
lrn_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &workspace,
        const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get(),
            workspace.get() };
    // Expected counts: 1 input, 2 outputs.
    check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a lrn forward primitive");
/// Creates an LRN forward primitive without a workspace.
/// Input: src. Output: dst.
lrn_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    // Expected counts: 1 input, 1 output.
    check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a lrn forward primitive");
/// Local response normalization (LRN) backward primitive.
struct lrn_backward : public primitive {
    /// Operation descriptor for LRN backward propagation.
    /// The C++ signature lists data_desc first, but the C initializer
    /// receives diff_data_desc before data_desc.
    struct desc {
        mkldnn_lrn_desc_t data;
        desc(algorithm aalgorithm,
            const memory::desc &data_desc,
            const memory::desc &diff_data_desc,
            int local_size, float alpha, float beta, float k)
        {
            error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
                convert_to_c(aalgorithm), &diff_data_desc.data,
                &data_desc.data, local_size, alpha, beta, k),
                "could not create a lrn backward descriptor");
        }
        // Same as above with k fixed to 1.0.
        desc(algorithm aalgorithm,
            const memory::desc &data_desc,
            const memory::desc &diff_data_desc,
            int local_size, float alpha, float beta)
        {
            error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
                convert_to_c(aalgorithm), &diff_data_desc.data,
                &data_desc.data, local_size, alpha, beta, float(1.0)),
                "could not create a lrn backward descriptor");
/// Primitive descriptor for LRN backward. The LRN forward primitive
/// descriptor is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const lrn_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const lrn_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    REG_QUERY_MPD(diff_src, diff_src, 0);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);
/// Creates an LRN backward primitive that consumes the forward workspace.
/// Inputs: src, diff_dst, workspace. Output: diff_src.
lrn_backward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &diff_dst,
        const primitive::at &workspace, const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
            workspace.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    // Expected counts: 3 inputs, 1 output.
    check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a lrn backward primitive");
/// Creates an LRN backward primitive without a workspace input.
/// Inputs: src, diff_dst. Output: diff_src.
lrn_backward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &diff_dst,
        const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    // Expected counts: 2 inputs, 1 output.
    check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a lrn backward primitive");
2276 /// @addtogroup cpp_api_pooling Pooling
2277 /// A primitive to perform max or average pooling.
2279 /// @sa @ref c_api_pooling in @ref c_api
2282 struct pooling_forward : public primitive {
2284 mkldnn_pooling_desc_t data;
2285 desc(prop_kind aprop_kind, algorithm aalgorithm,
2286 const memory::desc &src_desc,
2287 const memory::desc &dst_desc,
2288 const memory::dims strides,
2289 const memory::dims kernel,
2290 const memory::dims padding_l,
2291 const memory::dims padding_r,
2292 const padding_kind apadding_kind) {
2293 memory::validate_dims(strides);
2294 memory::validate_dims(kernel);
2295 memory::validate_dims(padding_l);
2296 memory::validate_dims(padding_r);
2297 error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data,
2298 mkldnn::convert_to_c(aprop_kind),
2299 convert_to_c(aalgorithm),
2300 &src_desc.data, &dst_desc.data,
2301 &strides[0], &kernel[0],
2302 &padding_l[0], &padding_r[0],
2303 mkldnn::convert_to_c(apadding_kind)),
2304 "could not init a forward pooling descriptor");
/// Primitive descriptor for pooling forward. It is created without a hint;
/// the attribute is optional (nullptr when omitted).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(dst, dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);
/// Creates a pooling forward primitive without a workspace.
/// Input: src. Output: dst.
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
        const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    // The second (workspace) slot is left null; only one output is expected.
    const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
    check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a pooling forward primitive");
/// Creates a pooling forward primitive that also writes a workspace.
/// Input: src. Outputs: dst, workspace.
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
        const memory &dst, const memory &workspace) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
    // Expected counts: 1 input, 2 outputs.
    check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a pooling forward primitive");
/// Pooling backward primitive.
struct pooling_backward : public primitive {
    /// Operation descriptor for pooling backward propagation. Spatial
    /// parameters are read-only and are passed to the C API as &v[0].
    struct desc {
        mkldnn_pooling_desc_t data;
        desc(algorithm aalgorithm,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc,
                const memory::dims &strides,
                const memory::dims &kernel,
                const memory::dims &padding_l,
                const memory::dims &padding_r,
                const padding_kind apadding_kind) {
            memory::validate_dims(strides);
            memory::validate_dims(kernel);
            memory::validate_dims(padding_l);
            memory::validate_dims(padding_r);
            error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data,
                        convert_to_c(aalgorithm),
                        &diff_src_desc.data, &diff_dst_desc.data,
                        &strides[0], &kernel[0],
                        &padding_l[0], &padding_r[0],
                        mkldnn::convert_to_c(apadding_kind)),
                    "could not init a backward pooling descriptor");
/// Primitive descriptor for pooling backward. The pooling forward primitive
/// descriptor is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const pooling_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const pooling_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    REG_QUERY_MPD(diff_src, diff_src, 0);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);
/// Creates a pooling backward primitive without a workspace input.
/// Input: diff_dst. Output: diff_src.
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
        const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    // Expected counts: 1 input, 1 output.
    check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a pooling backward primitive");
/// Creates a pooling backward primitive that consumes the forward workspace.
/// Inputs: diff_dst, workspace. Output: diff_src.
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
        const primitive::at &workspace, const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    // Expected counts: 2 inputs, 1 output.
    check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a pooling backward primitive");
2411 /// @addtogroup cpp_api_eltwise Eltwise
2412 /// A primitive to compute element-wise operations like parametric rectifier
2413 /// linear unit (ReLU).
2415 /// @sa @ref c_api_eltwise in @ref c_api
2418 struct eltwise_forward : public primitive {
2420 mkldnn_eltwise_desc_t data;
2421 template <typename T>
2422 desc(prop_kind aprop_kind, algorithm alg_kind,
2423 const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2424 error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data,
2425 mkldnn::convert_to_c(aprop_kind),
2426 mkldnn::convert_to_c(alg_kind), &src_desc.data,
2427 static_cast<float>(alpha), static_cast<float>(beta)),
2428 "could not create a eltwise forward descriptor");
/// Primitive descriptor for element-wise forward. It is created without a
/// hint; the attribute is optional (nullptr when omitted).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(dst, dst, 0);
/// Creates an element-wise forward primitive.
/// Input: src. Output: dst.
eltwise_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    // Expected counts: 1 input, 1 output.
    check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a eltwise forward primitive");
2456 struct eltwise_backward : public primitive {
2458 mkldnn_eltwise_desc_t data;
2460 template <typename T>
2461 desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2462 const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2463 error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data,
2464 mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2465 &data_desc.data, static_cast<float>(alpha),
2466 static_cast<float>(beta)),
2467 "could not create a eltwise backward descriptor");
/// Primitive descriptor for element-wise backward. The element-wise forward
/// primitive descriptor is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const eltwise_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const eltwise_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(diff_src, diff_src, 0);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
/// Creates an element-wise backward primitive.
/// Inputs: src, diff_dst. Output: diff_src.
eltwise_backward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &diff_dst,
        const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    // Expected counts: 2 inputs, 1 output.
    check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a eltwise backward primitive");
2501 /// @addtogroup cpp_api_depthwise Depthwise
/// Depthwise forward primitive.
struct depthwise_forward : public primitive {
    /// Operation descriptor. The bias descriptor is optional; the overload
    /// without it passes nullptr to the C initializer.
    struct desc {
        mkldnn_depthwise_desc_t data;

        desc(prop_kind aprop_kind, algorithm alg_kind,
                const memory::desc &src_desc, const memory::desc &dst_desc, const memory::desc &weights_desc,
                const memory::desc &bias_desc) {
            error::wrap_c_api(mkldnn_depthwise_forward_desc_init(&data,
                        mkldnn::convert_to_c(aprop_kind),
                        mkldnn::convert_to_c(alg_kind),
                        &src_desc.data, &dst_desc.data,
                        &weights_desc.data, &bias_desc.data),
                    "could not create a depthwise forward descriptor");
        }

        desc(prop_kind aprop_kind, algorithm alg_kind,
                const memory::desc &src_desc, const memory::desc &dst_desc, const memory::desc &weights_desc) {
            error::wrap_c_api(mkldnn_depthwise_forward_desc_init(&data,
                        mkldnn::convert_to_c(aprop_kind),
                        mkldnn::convert_to_c(alg_kind),
                        &src_desc.data, &dst_desc.data,
                        &weights_desc.data, nullptr),
                    "could not create a depthwise forward descriptor");
/// Primitive descriptor for depthwise forward. It is created without a
/// hint; the attribute is optional (nullptr when omitted).
/// NOTE(review): only src/dst queries are registered here; weights/bias
/// descriptors are not queryable through this wrapper.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(dst, dst, 0);
/// Creates a depthwise forward primitive with bias.
/// Inputs: src, weights, bias. Output: dst.
/// NOTE(review): unlike most constructors in this file, no
/// check_num_parameters call precedes creation; confirm that is intended.
depthwise_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &weights,
        const primitive::at &bias, const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, weights.data, bias.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a depthwise forward primitive");
/// Creates a depthwise forward primitive without bias.
/// Inputs: src, weights. Output: dst. As in the bias overload, no
/// check_num_parameters call is made.
depthwise_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &weights,
        const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a depthwise forward primitive");
2568 /// @addtogroup cpp_api_softmax Softmax
2569 /// A primitive to perform softmax.
2571 /// @sa @ref c_api_softmax in @ref c_api
/// Softmax forward primitive.
struct softmax_forward : public primitive {
    /// Operation descriptor. @p softmax_axis is forwarded unchanged to the
    /// C initializer.
    struct desc {
        mkldnn_softmax_desc_t data;
        desc(prop_kind aprop_kind, const memory::desc &data_desc,
             int softmax_axis) {
            error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data,
                    mkldnn::convert_to_c(aprop_kind), &data_desc.data,
                    softmax_axis),
                "could not create a softmax forward descriptor");
/// Primitive descriptor for softmax forward. It is created without a hint;
/// the attribute is optional (nullptr when omitted).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(dst, dst, 0);
/// Creates a softmax forward primitive.
/// Input: src. Output: dst.
softmax_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    // Expected counts: 1 input, 1 output.
    check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a softmax forward primitive");
/// Softmax backward primitive.
struct softmax_backward : public primitive {
    /// Operation descriptor built from the gradient descriptor, the data
    /// descriptor, and the softmax axis.
    struct desc {
        mkldnn_softmax_desc_t data;
        desc(const memory::desc &diff_desc, const memory::desc &data_desc,
                int softmax_axis) {
            error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data,
                        &diff_desc.data, &data_desc.data, softmax_axis),
                    "could not init a backward softmax descriptor");
/// Primitive descriptor for softmax backward. The softmax forward primitive
/// descriptor is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const softmax_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const softmax_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    REG_QUERY_MPD(dst, dst, 0);
    REG_QUERY_MPD(diff_src, diff_src, 0);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);
/// Creates a softmax backward primitive.
/// Inputs: dst (the forward result), diff_dst. Output: diff_src.
/// NOTE(review): no check_num_parameters call here, unlike softmax_forward.
softmax_backward(const primitive_desc &aprimitive_desc,
        const primitive::at &dst, const primitive::at &diff_dst,
        const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a softmax backward primitive");
2651 /// @addtogroup cpp_api_batch_norm Batch normalization
2652 /// A primitive to perform batch normalization.
2654 /// @sa @ref c_api_batch_normalization in @ref c_api
/// Batch normalization forward primitive.
struct batch_normalization_forward : public primitive {
    /// Operation descriptor. @p epsilon is converted to float for the C API;
    /// @p flags is forwarded unchanged.
    struct desc {
        mkldnn_batch_normalization_desc_t data;
        template <typename T>
        desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
                unsigned flags) {
            error::wrap_c_api(
                    mkldnn_batch_normalization_forward_desc_init(&data,
                        mkldnn::convert_to_c(aprop_kind), &src_desc.data,
                        static_cast<float>(epsilon), flags),
                "could not create a batch normalization forward descriptor");
/// Primitive descriptor for batch normalization forward. It is created
/// without a hint; the attribute is optional (nullptr when omitted).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(weights, weights, 0);
    REG_QUERY_MPD(dst, dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);

    /// Memory primitive descriptor of the mean.
    memory::primitive_desc mean_primitive_desc() const
    { return stat_primitive_desc(mean); }
    /// Memory primitive descriptor of the variance.
    memory::primitive_desc variance_primitive_desc() const
    { return stat_primitive_desc(var); }

private:
    // Slot indices of mean and variance next to src (when they are inputs)
    // or dst (when they are outputs).
    enum { mean = 1, var = 2, };
    // Reads the operation descriptor back from this primitive descriptor.
    // With use_global_stats, the statistics are inputs, so they are queried
    // from src_pd; otherwise they are outputs, queried from dst_pd.
    memory::primitive_desc stat_primitive_desc(int kind) const {
        mkldnn_batch_normalization_desc_t *p;
        error::wrap_c_api(mkldnn_primitive_desc_query(
                get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
                "could not get a batch-normalization descriptor");
        return query_mpd(p->flags & use_global_stats ? src_pd : dst_pd, kind);
/// Creates the primitive with the statistics and weights as inputs.
/// Inputs: src, mean, variance, weights. Output: dst.
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &mean,
        const primitive::at &variance, const primitive::at &weights,
        const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data,
        mean.data, variance.data, weights.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    // Expected counts: 4 inputs, 1 output.
    check_num_parameters(aprimitive_desc.get(), 4, 1,
        "batch normalization forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a batch normalization forward primitive");
/// Creates the primitive with the statistics as inputs and no weights.
/// Inputs: src, mean, variance. Output: dst.
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &mean,
        const primitive::at &variance, const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data,
        mean.data, variance.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    // Expected counts: 3 inputs, 1 output.
    check_num_parameters(aprimitive_desc.get(), 3, 1,
        "batch normalization forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a batch normalization forward primitive");
/// Creates the primitive with weights as input and the statistics as
/// outputs. Inputs: src, weights. Outputs: dst, mean, variance.
///
/// @warning batch_normalization_forward has two constructors with very
///     similar signatures:
///     - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
///     - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
///     The only way to distinguish between them is to explicitly
///     cast all input parameters to their type; that is, to
///     const primitive::at &.
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &weights,
        const memory &dst, const memory &mean, const memory &variance) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
    const_mkldnn_primitive_t outputs[] = { dst.get(),
        mean.get(), variance.get() };
    // Expected counts: 2 inputs, 3 outputs.
    check_num_parameters(aprimitive_desc.get(), 2, 3,
        "batch normalization forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a batch normalization forward primitive");
/// Creates the primitive with weights as input and the statistics plus a
/// workspace as outputs. Inputs: src, weights. Outputs: dst, mean,
/// variance, workspace.
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &weights,
        const memory &dst, const memory &mean, const memory &variance,
        const memory &workspace) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
    const_mkldnn_primitive_t outputs[] = { dst.get(),
        mean.get(), variance.get(), workspace.get() };
    // Expected counts: 2 inputs, 4 outputs.
    check_num_parameters(aprimitive_desc.get(), 2, 4,
        "batch normalization forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a batch normalization forward primitive");
/// Creates the primitive without weights, with the statistics as outputs.
/// Input: src. Outputs: dst, mean, variance.
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &dst, const memory &mean,
        const memory &variance) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get(),
        mean.get(), variance.get() };
    // Expected counts: 1 input, 3 outputs.
    check_num_parameters(aprimitive_desc.get(), 1, 3,
        "batch normalization forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a batch normalization forward primitive");
/// Creates the primitive without weights, with the statistics and a
/// workspace as outputs. Input: src. Outputs: dst, mean, variance, workspace.
///
/// @warning batch_normalization_forward has two constructors with very
///     similar signatures:
///     - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
///     - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
///     The only way to distinguish between them is to explicitly
///     cast all input parameters to their type; that is, to
///     const primitive::at &.
/// @note To make users' experience a little better, this constructor
///     checks whether parameters match the corresponding primitive
///     descriptor, and if not, calls the other (proper) constructor.
///     In practice it does not call that constructor; it rearranges the
///     arrays in place as shown below.
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &dst, const memory &mean,
        const memory &variance, const memory &workspace) {
    mkldnn_primitive_t result;
    // Sized for the larger of the two layouts; inputs[1] starts
    // value-initialized and is filled only in the remapping branch.
    mkldnn_primitive_at_t inputs[2] = { src.data };
    const_mkldnn_primitive_t outputs[4] = { dst.get(),
        mean.get(), variance.get(), workspace.get() };

    if (1) { // check whether this is the `wrong` constructor
        const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
                aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
        const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
                aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
        if (n_inputs_expected == 2 && n_outputs_expected == 3) {
            // shift parameters, get rid of workspace, and add weights...
            // The caller meant (src, weights, dst, mean, variance), so dst
            // is actually the weights memory.
            auto _weights = dst;
            inputs[1] = {_weights.get(), 0};

            auto _dst = mean, _mean = variance, _variance = workspace;
            outputs[0] = _dst.get();
            outputs[1] = _mean.get();
            outputs[2] = _variance.get();
            outputs[3] = nullptr;
        }
    }
    // NOTE(review): no check_num_parameters call here; mismatches other
    // than the 2-in/3-out case are left to mkldnn_primitive_create.
    error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a batch normalization forward primitive");
2824 batch_normalization_forward(const primitive_desc &aprimitive_desc,
2825 const primitive::at &src, const primitive::at &weights,
2826 const memory &dst) {
2827 mkldnn_primitive_t result;
2828 mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2829 const_mkldnn_primitive_t outputs[] = { dst.get() };
2830 check_num_parameters(aprimitive_desc.get(), 2, 1,
2831 "batch normalization forward");
2832 error::wrap_c_api(mkldnn_primitive_create(&result,
2833 aprimitive_desc.get(), inputs, outputs),
2834 "could not create a batch normalization forward primitive");
2838 batch_normalization_forward(const primitive_desc &aprimitive_desc,
2839 const primitive::at &src, const memory &dst) {
2840 mkldnn_primitive_t result;
2841 mkldnn_primitive_at_t inputs[] = { src.data };
2842 const_mkldnn_primitive_t outputs[] = { dst.get() };
2843 check_num_parameters(aprimitive_desc.get(), 1, 1,
2844 "batch normalization forward");
2845 error::wrap_c_api(mkldnn_primitive_create(&result,
2846 aprimitive_desc.get(), inputs, outputs),
2847 "could not create a batch normalization forward primitive");
/// Backward batch normalization primitive.
///
/// Several constructors are provided; they differ in which optional
/// arguments (weights, workspace, diff_weights) are wired. Each one asks
/// check_num_parameters() to validate the expected input/output counts
/// before creating the primitive.
struct batch_normalization_backward : public primitive {
    /// Operation descriptor for backward batch normalization.
    struct desc {
        mkldnn_batch_normalization_desc_t data;
        /// Initializes the underlying C descriptor.
        ///
        /// @param aprop_kind Propagation kind (the constructors of this
        ///        primitive are annotated for backward and backward_data).
        /// @param diff_data_desc Forwarded as the diff data memory descriptor.
        /// @param data_desc Forwarded as the data memory descriptor.
        /// @param epsilon Converted with static_cast<float> before being
        ///        forwarded, so any type convertible to float is accepted.
        /// @param flags Forwarded unchanged to the C API.
        template <typename T>
        desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
                const memory::desc &data_desc, T epsilon, unsigned flags) {
            error::wrap_c_api(
                    mkldnn_batch_normalization_backward_desc_init(&data,
                        mkldnn::convert_to_c(aprop_kind),
                        &diff_data_desc.data, &data_desc.data,
                        static_cast<float>(epsilon), flags),
                    "could not create a batch normalization backward descriptor");
        }
    };

    /// Primitive descriptor for backward batch normalization. A forward
    /// primitive descriptor must be supplied as a hint.
    struct primitive_desc : public mkldnn::primitive_desc {
        primitive_desc(const desc &desc, const engine &e,
                const batch_normalization_forward::primitive_desc &hint_fwd_pd)
            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

        /// Same as above, additionally forwarding primitive attributes.
        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
                const batch_normalization_forward::primitive_desc &hint_fwd_pd)
            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

        // Memory primitive descriptor queries. Mean and variance are
        // queried as `src` memories with indices 1 and 2.
        REG_QUERY_MPD(src, src, 0);
        REG_QUERY_MPD(mean, src, 1);
        REG_QUERY_MPD(variance, src, 2);
        REG_QUERY_MPD(weights, weights, 0);
        REG_QUERY_MPD(dst, dst, 0);
        REG_QUERY_MPD(diff_dst, diff_dst, 0);
        REG_QUERY_MPD(workspace, workspace, 0);

        REG_QUERY_MPD(diff_src, diff_src, 0);
        REG_QUERY_MPD(diff_weights, diff_weights, 0);
    };

    // Prop_kind == backward
    /// Inputs (5): src, mean, variance, diff_dst, weights.
    /// Outputs (2): diff_src, diff_weights.
    batch_normalization_backward(const primitive_desc &aprimitive_desc,
            const primitive::at &src, const primitive::at &mean,
            const primitive::at &variance, const primitive::at &diff_dst,
            const primitive::at &weights, const memory &diff_src,
            const memory &diff_weights) {
        mkldnn_primitive_t result;
        mkldnn_primitive_at_t inputs[] = { src.data,
            mean.data, variance.data, diff_dst.data, weights.data };
        const_mkldnn_primitive_t outputs[] = { diff_src.get(),
            diff_weights.get() };
        check_num_parameters(aprimitive_desc.get(), 5, 2,
            "batch normalization backward");
        error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a batch normalization backward primitive");
        reset(result);
    }

    // Prop_kind == backward (+ws)
    /// Inputs (6): src, mean, variance, diff_dst, weights, workspace.
    /// Outputs (2): diff_src, diff_weights.
    batch_normalization_backward(const primitive_desc &aprimitive_desc,
            const primitive::at &src, const primitive::at &mean,
            const primitive::at &variance, const primitive::at &diff_dst,
            const primitive::at &weights, const primitive::at &workspace,
            const memory &diff_src, const memory &diff_weights) {
        mkldnn_primitive_t result;
        mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
            diff_dst.data, weights.data, workspace.data };
        const_mkldnn_primitive_t outputs[] = { diff_src.get(),
            diff_weights.get() };
        check_num_parameters(aprimitive_desc.get(), 6, 2,
            "batch normalization backward");
        error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a batch normalization backward primitive");
        reset(result);
    }

    // Prop_kind == backward_data (+ws or +weights)
    /// @warning This constructor works for backward_data propagation
    ///          - w/ weights but w/o workspace, or
    ///          - w/ workspace but w/o weights
    /// Inputs (5): src, mean, variance, diff_dst, and either weights or
    /// workspace (whichever the primitive descriptor expects).
    /// Output (1): diff_src.
    batch_normalization_backward(const primitive_desc &aprimitive_desc,
            const primitive::at &src, const primitive::at &mean,
            const primitive::at &variance,const primitive::at &diff_dst,
            const primitive::at &weights_or_workspace, const memory &diff_src) {
        mkldnn_primitive_t result;
        mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
            diff_dst.data, weights_or_workspace.data };
        const_mkldnn_primitive_t outputs[] = { diff_src.get() };
        check_num_parameters(aprimitive_desc.get(), 5, 1,
            "batch normalization backward");
        error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a batch normalization backward primitive");
        reset(result);
    }

    // Prop_kind == backward_data
    /// Inputs (4): src, mean, variance, diff_dst.
    /// Output (1): diff_src.
    batch_normalization_backward(const primitive_desc &aprimitive_desc,
            const primitive::at &src, const primitive::at &mean,
            const primitive::at &variance, const primitive::at &diff_dst,
            const memory &diff_src) {
        mkldnn_primitive_t result;
        mkldnn_primitive_at_t inputs[] = { src.data,
            mean.data, variance.data, diff_dst.data };
        const_mkldnn_primitive_t outputs[] = { diff_src.get() };
        check_num_parameters(aprimitive_desc.get(), 4, 1,
            "batch normalization backward");
        error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a batch normalization backward primitive");
2966 /// @addtogroup cpp_api_inner_product Inner Product
2967 /// A primitive to compute an inner product.
2969 /// @sa @ref c_api_inner_product in @ref c_api
2972 struct inner_product_forward: public primitive {
2974 mkldnn_inner_product_desc_t data;
2975 desc(prop_kind aprop_kind, const memory::desc &src_desc,
2976 const memory::desc &weights_desc,
2977 const memory::desc &bias_desc,
2978 const memory::desc &dst_desc) {
2980 mkldnn_inner_product_forward_desc_init(&data,
2981 mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2982 &weights_desc.data, &bias_desc.data, &dst_desc.data),
2983 "could not create a inner product forward descriptor");
2986 desc(prop_kind aprop_kind, const memory::desc &src_desc,
2987 const memory::desc &weights_desc,
2988 const memory::desc &dst_desc) {
2990 mkldnn_inner_product_forward_desc_init(&data,
2991 mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2992 &weights_desc.data, nullptr, &dst_desc.data),
2993 "could not create a inner product forward descriptor");
2997 struct primitive_desc : public mkldnn::primitive_desc {
2998 primitive_desc(const desc &desc, const engine &e)
2999 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3001 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3002 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3004 REG_QUERY_MPD(src, src, 0);
3005 REG_QUERY_MPD(weights, weights, 0);
3006 REG_QUERY_MPD(bias, weights, 1);
3007 REG_QUERY_MPD(dst, dst, 0);
3010 inner_product_forward(const primitive_desc &aprimitive_desc,
3011 const primitive::at &src, const primitive::at weights,
3012 const primitive::at &bias, const memory &dst) {
3013 mkldnn_primitive_t result;
3014 mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
3016 const_mkldnn_primitive_t outputs[] = { dst.get() };
3017 check_num_parameters(aprimitive_desc.get(), 3, 1,
3018 "inner product forward");
3019 error::wrap_c_api(mkldnn_primitive_create(&result,
3020 aprimitive_desc.get(), inputs, outputs),
3021 "could not create a inner product forward primitive");
3025 inner_product_forward(const primitive_desc &aprimitive_desc,
3026 const primitive::at &src, const primitive::at weights,
3027 const memory &dst) {
3028 mkldnn_primitive_t result;
3029 mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3030 const_mkldnn_primitive_t outputs[] = { dst.get() };
3031 check_num_parameters(aprimitive_desc.get(), 2, 1,
3032 "inner product forward");
3033 error::wrap_c_api(mkldnn_primitive_create(&result,
3034 aprimitive_desc.get(), inputs, outputs),
3035 "could not create a inner product forward primitive");
3040 struct inner_product_backward_data: public primitive {
3042 mkldnn_inner_product_desc_t data;
3043 desc(const memory::desc &diff_src_desc,
3044 const memory::desc &weights_desc,
3045 const memory::desc &diff_dst_desc) {
3047 mkldnn_inner_product_backward_data_desc_init(&data,
3048 &diff_src_desc.data, &weights_desc.data,
3049 &diff_dst_desc.data),
3050 "could not create a inner product backward data descriptor");
3054 struct primitive_desc : public mkldnn::primitive_desc {
3055 primitive_desc(const desc &desc, const engine &e,
3056 const inner_product_forward::primitive_desc &hint_fwd_pd)
3057 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3059 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3060 const inner_product_forward::primitive_desc &hint_fwd_pd)
3061 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3063 REG_QUERY_MPD(diff_src, diff_src, 0);
3064 REG_QUERY_MPD(weights, weights, 0);
3065 REG_QUERY_MPD(diff_dst, diff_dst, 0);
3068 inner_product_backward_data(const primitive_desc &aprimitive_desc,
3069 const primitive::at &diff_dst, const primitive::at weights,
3070 const memory &diff_src) {
3071 mkldnn_primitive_t result;
3072 mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
3073 const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3074 check_num_parameters(aprimitive_desc.get(), 2, 1,
3075 "inner product backward data");
3076 error::wrap_c_api(mkldnn_primitive_create(&result,
3077 aprimitive_desc.get(), inputs, outputs),
3078 "could not create a inner product backward data primitive");
3083 struct inner_product_backward_weights: public primitive {
3085 mkldnn_inner_product_desc_t data;
3086 desc(const memory::desc &src_desc,
3087 const memory::desc &diff_weights_desc,
3088 const memory::desc &diff_bias_desc,
3089 const memory::desc &diff_dst_desc) {
3091 mkldnn_inner_product_backward_weights_desc_init(
3092 &data, &src_desc.data, &diff_weights_desc.data,
3093 &diff_bias_desc.data, &diff_dst_desc.data),
3094 "could not create a inner product backward weights descriptor");
3096 desc(const memory::desc &src_desc,
3097 const memory::desc &diff_weights_desc,
3098 const memory::desc &diff_dst_desc) {
3100 mkldnn_inner_product_backward_weights_desc_init(
3101 &data, &src_desc.data, &diff_weights_desc.data,
3102 nullptr, &diff_dst_desc.data),
3103 "could not create a inner product backward weights descriptor");
3107 struct primitive_desc : public mkldnn::primitive_desc {
3108 primitive_desc(const desc &desc, const engine &e,
3109 const inner_product_forward::primitive_desc &hint_fwd_pd)
3110 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3112 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3113 const inner_product_forward::primitive_desc &hint_fwd_pd)
3114 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3116 REG_QUERY_MPD(src, src, 0);
3117 REG_QUERY_MPD(diff_weights, diff_weights, 0);
3118 REG_QUERY_MPD(diff_bias, diff_weights, 1);
3119 REG_QUERY_MPD(diff_dst, diff_dst, 0);
3122 inner_product_backward_weights(const primitive_desc &aprimitive_desc,
3123 const primitive::at &src, const primitive::at diff_dst,
3124 const memory &diff_weights) {
3125 mkldnn_primitive_t result;
3126 mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3127 const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
3128 check_num_parameters(aprimitive_desc.get(), 2, 1,
3129 "inner product backward weights");
3130 error::wrap_c_api(mkldnn_primitive_create(&result,
3131 aprimitive_desc.get(), inputs, outputs),
3132 "could not create a inner product backward weights primitive");
3136 inner_product_backward_weights(const primitive_desc &aprimitive_desc,
3137 const primitive::at &src, const primitive::at diff_dst,
3138 const memory &diff_weights, const memory &diff_bias) {
3139 mkldnn_primitive_t result;
3140 mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3141 const_mkldnn_primitive_t outputs[] =
3142 { diff_weights.get(), diff_bias.get()};
3143 check_num_parameters(aprimitive_desc.get(), 2, 2,
3144 "inner product backward weights");
3145 error::wrap_c_api(mkldnn_primitive_create(&result,
3146 aprimitive_desc.get(), inputs, outputs),
3147 "could not create a inner product backward weights primitive");
3154 /// @addtogroup cpp_api_rnn RNN
/// A primitive to compute common recurrent layers.
3157 /// @sa @ref c_api_rnn in @ref c_api
3162 mkldnn_rnn_cell_desc_t c_rnn_cell_;
3164 desc(algorithm kind, algorithm activation_f) {
3165 error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_,
3166 mkldnn::convert_to_c(kind),
3167 mkldnn::convert_to_c(activation_f), 0U, 0, 0),
3168 "could not init an rnn cell descriptor");
3170 desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {}
3172 operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
3174 algorithm get_cell_kind() const
3175 { return algorithm(c_rnn_cell_.cell_kind); }
3176 algorithm get_activation() const
3177 { return algorithm(c_rnn_cell_.activation_kind); }
3179 float get_alpha() const { return c_rnn_cell_.alpha; }
3180 void set_alpha(float alpha) {
3181 c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
3182 c_rnn_cell_.alpha = alpha;
3185 float get_clipping() const { return c_rnn_cell_.clipping; }
3186 void set_clipping(float clipping) {
3187 c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
3188 c_rnn_cell_.clipping = clipping;
3191 int get_gates_count() const {
3192 return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
3194 int get_state_count() const {
3195 return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
/// Forward RNN primitive.
struct rnn_forward : public primitive {
    /// Operation descriptor for forward RNN.
    struct desc {
        mkldnn_rnn_desc_t data;
        /// Initializes the underlying C descriptor.
        ///
        /// @param cell Cell descriptor, taken by value. It is passed to the
        ///        C API through rnn_cell::desc's implicit conversion to
        ///        `const mkldnn_rnn_cell_desc_t *`.
        /// @param direction Converted with mkldnn::convert_to_c().
        /// All memory descriptors are forwarded to the C API by address.
        desc(prop_kind aprop_kind, rnn_cell::desc cell,
                const rnn_direction direction,
                const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc
            ) {
            error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data,
                        mkldnn::convert_to_c(aprop_kind), cell,
                        mkldnn::convert_to_c(direction),
                        &src_layer_desc.data, &src_iter_desc.data,
                        &weights_layer_desc.data, &weights_iter_desc.data,
                        &bias_desc.data,
                        &dst_layer_desc.data, &dst_iter_desc.data),
                    "could not create an RNN forward descriptor");
        }

    };

    /// Primitive descriptor for forward RNN.
    struct primitive_desc : public mkldnn::primitive_desc {
        primitive_desc(const desc &desc, const engine &e)
            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

        /// Same as above, additionally forwarding primitive attributes.
        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

        // Layer and iteration tensors share a query kind and are told apart
        // by index (e.g. src_layer is src #0, src_iter is src #1, bias is
        // weights #2).
        REG_QUERY_MPD(src_layer, src, 0);
        REG_QUERY_MPD(src_iter, src, 1);
        REG_QUERY_MPD(weights_layer, weights, 0);
        REG_QUERY_MPD(weights_iter, weights, 1);
        REG_QUERY_MPD(bias, weights, 2);
        REG_QUERY_MPD(dst_layer, dst, 0);
        REG_QUERY_MPD(dst_iter, dst, 1);
        REG_QUERY_MPD(workspace, workspace, 0);
    };

    /// Constructs a forward RNN primitive.
    ///
    /// Optional arguments are dropped when is_null_memory() reports them as
    /// null: src_iter and bias among the inputs, dst_iter and workspace
    /// among the outputs. The remaining arguments are packed contiguously,
    /// in declaration order, into the arrays passed to
    /// mkldnn_primitive_create(). Unlike most constructors in this file, no
    /// check_num_parameters() call is made here.
    rnn_forward(const primitive_desc &aprimitive_desc,
            const primitive::at &src_layer, const primitive::at &src_iter,
            const primitive::at &weights_layer,
            const primitive::at &weights_iter, const primitive::at &bias,
            const memory &dst_layer, const memory &dst_iter,
            const memory &workspace) {
        mkldnn_primitive_t result;
        // Sized for the maximum argument counts (5 inputs, 3 outputs); only
        // the first idx entries are filled.
        mkldnn_primitive_at_t inputs[5];
        const_mkldnn_primitive_t outputs[3];
        int idx=0;
        inputs[idx++] = src_layer.data;
        if (!is_null_memory(src_iter.data.primitive))
            inputs[idx++] = src_iter.data;
        inputs[idx++] = weights_layer.data;
        inputs[idx++] = weights_iter.data;
        if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;

        // Restart the cursor to pack the outputs.
        idx=0;
        outputs[idx++] = dst_layer.get();
        if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
        if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();

        error::wrap_c_api(mkldnn_primitive_create(&result,
                    aprimitive_desc.get(), inputs, outputs),
                "could not create an RNN forward primitive");
/// Backward RNN primitive.
struct rnn_backward : public primitive {
    /// Operation descriptor for backward RNN.
    struct desc {
        mkldnn_rnn_desc_t data;
        /// Initializes the underlying C descriptor.
        ///
        /// @param cell Cell descriptor, taken by value. It is passed to the
        ///        C API through rnn_cell::desc's implicit conversion to
        ///        `const mkldnn_rnn_cell_desc_t *`.
        /// @param direction Converted with mkldnn::convert_to_c().
        /// All forward and diff memory descriptors are forwarded to the C
        /// API by address.
        desc(prop_kind aprop_kind, rnn_cell::desc cell,
                const rnn_direction direction,
                const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const memory::desc &diff_src_layer_desc,
                const memory::desc &diff_src_iter_desc,
                const memory::desc &diff_weights_layer_desc,
                const memory::desc &diff_weights_iter_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_layer_desc,
                const memory::desc &diff_dst_iter_desc) {
            error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data,
                    mkldnn::convert_to_c(aprop_kind), cell,
                    mkldnn::convert_to_c(direction),
                    &src_layer_desc.data, &src_iter_desc.data,
                    &weights_layer_desc.data, &weights_iter_desc.data,
                    &bias_desc.data,
                    &dst_layer_desc.data, &dst_iter_desc.data,
                    &diff_src_layer_desc.data, &diff_src_iter_desc.data,
                    &diff_weights_layer_desc.data,
                    &diff_weights_iter_desc.data, &diff_bias_desc.data,
                    &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
                "could not create an RNN backward descriptor");
        }

    };

    /// Primitive descriptor for backward RNN; requires the forward
    /// primitive descriptor as a hint.
    struct primitive_desc : public mkldnn::primitive_desc {
        primitive_desc(const desc &desc, const engine &e,
                const rnn_forward::primitive_desc &hint_fwd_pd)
            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

        /// Same as above, additionally forwarding primitive attributes.
        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
                const rnn_forward::primitive_desc &hint_fwd_pd)
            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

        // Forward-pass tensors (layer/iteration variants told apart by index).
        REG_QUERY_MPD(src_layer, src, 0);
        REG_QUERY_MPD(src_iter, src, 1);
        REG_QUERY_MPD(weights_layer, weights, 0);
        REG_QUERY_MPD(weights_iter, weights, 1);
        REG_QUERY_MPD(bias, weights, 2);
        REG_QUERY_MPD(dst_layer, dst, 0);
        REG_QUERY_MPD(dst_iter, dst, 1);
        REG_QUERY_MPD(workspace, workspace, 0);

        // Gradient tensors, indexed the same way.
        REG_QUERY_MPD(diff_src_layer, diff_src, 0);
        REG_QUERY_MPD(diff_src_iter, diff_src, 1);
        REG_QUERY_MPD(diff_weights_layer, diff_weights, 0);
        REG_QUERY_MPD(diff_weights_iter, diff_weights, 1);
        REG_QUERY_MPD(diff_bias, diff_weights, 2);
        REG_QUERY_MPD(diff_dst_layer, diff_dst, 0);
        REG_QUERY_MPD(diff_dst_iter, diff_dst, 1);
    };

    // With last iteration (with and without input src_iter)
    /// Constructs a backward RNN primitive.
    ///
    /// Optional arguments are dropped when is_null_memory() reports them as
    /// null:
    /// - inputs: src_iter, bias, dst_iter, diff_dst_iter;
    /// - outputs: diff_src_iter, diff_bias.
    /// The remaining arguments are packed contiguously, in the order used
    /// below, into the arrays passed to mkldnn_primitive_create(). Unlike
    /// most constructors in this file, no check_num_parameters() call is
    /// made here.
    rnn_backward(const primitive_desc &aprimitive_desc,
            const primitive::at &src_layer,
            const primitive::at &src_iter,
            const primitive::at &weights_layer,
            const primitive::at &weights_iter,
            const primitive::at &bias,
            const primitive::at &dst_layer,
            const primitive::at &dst_iter,
            const memory &diff_src_layer,
            const memory &diff_src_iter,
            const memory &diff_weights_layer,
            const memory &diff_weights_iter,
            const memory &diff_bias,
            const primitive::at &diff_dst_layer,
            const primitive::at &diff_dst_iter,
            const primitive::at &workspace) {
        mkldnn_primitive_t result;
        // Sized for the maximum argument counts (10 inputs, 5 outputs);
        // only the first idx entries are filled.
        mkldnn_primitive_at_t inputs[10];
        const_mkldnn_primitive_t outputs[5];
        int idx=0;
        inputs[idx++] = src_layer.data;
        if (!is_null_memory(src_iter.data.primitive))
            inputs[idx++] = src_iter.data;
        inputs[idx++] = weights_layer.data;
        inputs[idx++] = weights_iter.data;
        if (!is_null_memory(bias.data.primitive))
            inputs[idx++] = bias.data;
        inputs[idx++] = dst_layer.data;
        if (!is_null_memory(dst_iter.data.primitive))
            inputs[idx++] = dst_iter.data;
        inputs[idx++] = diff_dst_layer.data;
        if (!is_null_memory(diff_dst_iter.data.primitive))
            inputs[idx++] = diff_dst_iter.data;
        inputs[idx++] = workspace.data;

        // Restart the cursor to pack the outputs.
        idx = 0;
        outputs[idx++] = diff_src_layer.get();
        if (!is_null_memory(diff_src_iter.get()))
            outputs[idx++] = diff_src_iter.get();
        outputs[idx++] = diff_weights_layer.get();
        outputs[idx++] = diff_weights_iter.get();
        if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
        error::wrap_c_api(mkldnn_primitive_create(&result,
                    aprimitive_desc.get(), inputs, outputs),
                "could not create an RNN backward primitive");
3385 /// @addtogroup cpp_api_shuffle Shuffle
3386 /// A primitive to shuffle data along the axis.
3388 /// @sa @ref c_api_shuffle in @ref c_api
3391 struct shuffle_forward : public primitive {
3393 mkldnn_shuffle_desc_t data;
3394 desc(prop_kind aprop_kind, const memory::desc &data_desc,
3395 int axis, int group_size) {
3396 error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data,
3397 mkldnn::convert_to_c(aprop_kind), &data_desc.data,
3399 "could not create a shuffle forward descriptor");
3403 struct primitive_desc : public mkldnn::primitive_desc {
3404 primitive_desc(const desc &desc, const engine &e)
3405 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3407 REG_QUERY_MPD(src, src, 0);
3408 REG_QUERY_MPD(dst, dst, 0);
3411 shuffle_forward(const primitive_desc &aprimitive_desc,
3412 const primitive::at &src, const memory &dst) {
3413 mkldnn_primitive_t result;
3414 mkldnn_primitive_at_t inputs[] = { src.data };
3415 const_mkldnn_primitive_t outputs[] = { dst.get() };
3416 check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle forward");
3417 error::wrap_c_api(mkldnn_primitive_create(&result,
3418 aprimitive_desc.get(), inputs, outputs),
3419 "could not create a shuffle forward primitive");
3424 struct shuffle_backward : public primitive {
3426 mkldnn_shuffle_desc_t data;
3427 desc(const memory::desc &diff_data_desc, int axis, int group_size) {
3428 error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data,
3429 &diff_data_desc.data, axis, group_size),
3430 "could not create a shuffle backward descriptor");
3434 struct primitive_desc : public mkldnn::primitive_desc {
3435 primitive_desc(const desc &desc, const engine &e,
3436 const shuffle_forward::primitive_desc &hint_fwd_pd)
3437 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3439 REG_QUERY_MPD(diff_src, diff_src, 0);
3440 REG_QUERY_MPD(diff_dst, diff_dst, 0);
3443 shuffle_backward(const primitive_desc &aprimitive_desc,
3444 const primitive::at &diff_dst, const memory &diff_src) {
3445 mkldnn_primitive_t result;
3446 mkldnn_primitive_at_t inputs[] = { diff_dst.data};
3447 const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3448 check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle backward");
3449 error::wrap_c_api(mkldnn_primitive_create(&result,
3450 aprimitive_desc.get(), inputs, outputs),
3451 "could not create a shuffle backward primitive");
3458 /// @addtogroup cpp_api_binary_convolution Binary convolution
3459 /// A primitive to compute binary convolution using different algorithms.
3461 /// @sa @ref c_api_binary_convolution in @ref c_api
3464 struct binary_convolution_forward: public primitive {
3466 mkldnn_binary_convolution_desc_t data;
3467 desc(prop_kind aprop_kind, algorithm aalgorithm,
3468 const memory::desc &src_desc,
3469 const memory::desc &weights_desc,
3470 const memory::desc &dst_desc,
3471 const memory::dims strides,
3472 const memory::dims dilates,
3473 const memory::dims padding_l,
3474 const memory::dims padding_r,
3475 const float pad_value) {
3476 memory::validate_dims(strides);
3477 memory::validate_dims(dilates);
3478 memory::validate_dims(padding_l);
3479 memory::validate_dims(padding_r);
3481 mkldnn_dilated_binary_convolution_forward_desc_init(&data,
3482 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
3483 &src_desc.data, &weights_desc.data, &dst_desc.data,
3484 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
3486 "could not create a dilated binary convolution forward descriptor");
3490 struct primitive_desc : public mkldnn::primitive_desc {
3491 primitive_desc(const desc &desc, const engine &e)
3492 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3494 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3495 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3497 REG_QUERY_MPD(src, src, 0);
3498 REG_QUERY_MPD(weights, weights, 0);
3499 REG_QUERY_MPD(dst, dst, 0);
3502 binary_convolution_forward(const primitive_desc &aprimitive_desc,
3503 const primitive::at &src, const primitive::at &weights, const memory &dst) {
3504 mkldnn_primitive_t result;
3505 mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3506 const_mkldnn_primitive_t outputs[] = { dst.get() };
3507 check_num_parameters(aprimitive_desc.get(), 2, 1,
3508 "binary convolution forward");
3509 error::wrap_c_api(mkldnn_primitive_create(&result,
3510 aprimitive_desc.get(), inputs, outputs),
3511 "could not create a binary convolution forward primitive");
3518 /// @addtogroup cpp_api_binarization Binarization
3521 struct binarization_forward : public primitive {
3523 mkldnn_binarization_desc_t data;
3525 desc(prop_kind aprop_kind, algorithm alg_kind,
3526 const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &output_mask_desc,
3527 const memory::desc &dst_desc) {
3528 error::wrap_c_api(mkldnn_binarization_forward_desc_init(&data,
3529 mkldnn::convert_to_c(aprop_kind),
3530 mkldnn::convert_to_c(alg_kind),
3531 &src_desc.data, &dst_desc.data,
3532 &weights_desc.data, &output_mask_desc.data),
3533 "could not create a binarization forward descriptor");
3537 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
3538 primitive_desc(const desc &adesc, const engine &aengine) {
3539 mkldnn_primitive_desc_t result;
3540 error::wrap_c_api(mkldnn_primitive_desc_create(
3541 &result, &adesc.data, aengine.get(), nullptr),
3542 "could not create a binarization forward primitive descriptor");
3546 engine get_engine() { return engine::query(*this); }
3549 binarization_forward(const primitive_desc &aprimitive_desc,
3550 const primitive::at &src, const primitive::at &weights, const primitive::at &output_mask,
3551 const memory &dst) {
3552 mkldnn_primitive_t result;
3553 mkldnn_primitive_at_t inputs[] = { src.data, weights.data, output_mask.data};
3554 const_mkldnn_primitive_t outputs[] = { dst.get() };
3555 error::wrap_c_api(mkldnn_primitive_create(&result, aprimitive_desc.get(), inputs, outputs),
3556 "could not create a binarization forward primitive");
3565 /// @addtogroup cpp_api_stream Stream
3566 /// Execution stream operations.
3568 /// @sa @ref c_api_stream in @ref c_api
3571 #ifndef DOXYGEN_SHOULD_SKIP_THIS
3572 template <> struct handle_traits<mkldnn_stream_t> {
3573 static constexpr auto destructor = &mkldnn_stream_destroy;
3577 struct stream: public handle<mkldnn_stream_t> {
3578 using handle::handle;
3580 enum kind { any = mkldnn_stream_kind_t::mkldnn_any_stream,
3581 eager = mkldnn_stream_kind_t::mkldnn_eager,
3582 lazy = mkldnn_stream_kind_t::mkldnn_lazy };
3584 static mkldnn_stream_kind_t convert_to_c(kind akind) {
3585 return static_cast<mkldnn_stream_kind_t>(akind);
3587 /// Constructs a stream.
3588 stream(kind akind) {
3589 mkldnn_stream_t astream;
3590 error::wrap_c_api(mkldnn_stream_create(&astream,
3591 convert_to_c(akind)),
3592 "could not create a stream");
3596 /// Submits a vector of primitives to a stream for computations.
3598 /// @param primitives The vector of primitives to submit.
3599 /// @returns The stream.
3600 stream &submit(std::vector<primitive> primitives) {
3601 // TODO: find a proper way to convert vector<primitive> to
3602 // vector<mkldnn_primitive_t>
3603 if (primitives.size() == 0) return *this;
3604 std::vector<mkldnn_primitive_t> c_api_primitives;
3605 c_api_primitives.reserve(primitives.size());
3606 auto convert_to_c = [](primitive p) { return p.get(); };
3607 std::transform(primitives.begin(), primitives.end(),
3608 std::back_inserter(c_api_primitives), convert_to_c);
3610 mkldnn_primitive_t c_api_error_primitive;
3612 mkldnn_stream_submit(get(),
3613 c_api_primitives.size(), &c_api_primitives[0],
3614 &c_api_error_primitive),
3615 "could not submit primitives to a stream",
3616 &c_api_error_primitive);
3621 /// Waits for all computations submitted to the stream to complete.
3623 /// @param block Specifies whether the operation should wait indefinitely or
3624 /// return immediately.
3625 /// @returns @c true if all computations completed.
3626 /// @returns @c false if not all computations completed.
3627 bool wait(bool block = true) {
3628 mkldnn_primitive_t c_api_error_primitive;
3629 mkldnn_status_t status = mkldnn_stream_wait(get(),
3630 block, &c_api_error_primitive);
3631 if (status != mkldnn_success
3632 && status != mkldnn_try_again)
3633 error::wrap_c_api(status, "could not wait on a stream",
3634 &c_api_error_primitive);
3635 return (status == mkldnn_success);
3639 mkldnn_primitive_t c_api_error_primitive;
3641 mkldnn_stream_rerun(get(), &c_api_error_primitive),
3642 "could not rerun a stream", &c_api_error_primitive);
3647 #undef REG_QUERY_MPD
3653 } // namespace mkldnn