1 /*******************************************************************************
2 * Copyright 2016-2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
33 /// @addtogroup cpp_api C++ API
36 /// @addtogroup cpp_api_utils Utils
/// Traits class that provides the destructor for an Intel(R) MKL-DNN C
/// handle type @p T.
///
/// The primary template is intentionally empty. A C handle type used with
/// handle<> must provide a specialization exposing a static `destructor`
/// member, which handle::reset() installs as the shared_ptr deleter for
/// non-weak wrappers.
///
/// Declared with `struct` so that it matches the class-key used by all of
/// its explicit specializations (avoids MSVC C4099 / clang
/// -Wmismatched-tags).
template <typename T> struct handle_traits {};
// NOTE(review): this copy of the class is incomplete. The embedded line
// numbers jump (e.g. 55->57, 59->61, 67->71, 72->76), so short lines such as
// access specifiers, the constructor body, the body of the copy-assignment
// operator and closing braces are missing. Restore them from upstream before
// compiling; the comments below describe only the visible lines.
42 /// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base
43 /// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and
44 /// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class
45 /// can be passed by value. This class enables wrapping:
46 /// - Newly constructed handles.
47 /// @n In this case, the constructed handle uses reference counting provided
48 /// by @p std::shared_ptr with a proper deleter function specified through
49 /// the @p handle_traits class.
50 /// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for
51 /// example, through #mkldnn_primitive_get_output()).
52 /// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a
53 /// deleter because it is assumed that the handle wrapper for the original
54 /// object deletes the handle (this model is similar to @p std::weak_ptr).
55 template <typename T, typename traits=handle_traits<T>> class handle {
// Holds the C handle. The pointee type is T with the pointer stripped, so T
// is expected to be a C pointer handle type. Copies of a wrapper share this
// shared_ptr.
57 std::shared_ptr<typename std::remove_pointer<T>::type> _data;
// Construction and assignment from a const rvalue handle are deleted.
// NOTE(review): because binding a const rvalue reference is preferred over
// `const handle &` for rvalue arguments, these deleted overloads are also
// selected when moving from a non-const handle (e.g.
// `handle h(std::move(other));` does not compile) -- confirm this is intended.
58 handle(const handle &&) = delete;
59 handle &operator=(const handle &&other) = delete;
// Compare the wrapped pointer against a raw C handle value.
61 bool operator==(const T other) const { return other == _data.get(); }
62 bool operator!=(const T other) const { return !(*this == other); }
64 /// Constructs a C handle wrapper.
65 /// @param t The C handle to wrap.
66 /// @param weak A flag to specify whether to construct a weak wrapper.
67 handle(T t = 0, bool weak = false): _data(0) {
/// Copy construction shares the same C handle (the shared_ptr is copied);
/// no new C object is created.
71 handle(const handle &other): _data(other._data) {}
72 handle &operator=(const handle &other) {
76 /// Resets the value of a C handle.
77 /// @param t The new value of the C handle.
78 /// @param weak A flag to specify whether the wrapper should be weak.
79 void reset(T t, bool weak = false) {
// Weak: install a deleter that does nothing except return a zero value of
// the destructor's return type, so @p t is never destroyed through this
// wrapper. Strong: traits::destructor releases @p t when the last copy of the
// wrapper goes away.
80 if (weak) _data.reset(t, [](T) { return decltype(traits::destructor(0))(0); });
81 else _data.reset(t, traits::destructor);
84 /// Returns the value of the underlying C handle.
85 T get() const { return _data.get(); }
/// Two wrappers compare equal when they wrap the same C handle pointer.
87 bool operator==(const handle &other) const { return other._data.get() == _data.get(); }
88 bool operator!=(const handle &other) const { return !(*this == other); }
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93 static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
96 template <> struct handle_traits<mkldnn_primitive_t> {
97 static constexpr auto destructor = &mkldnn_primitive_destroy;
100 template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
101 static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
// NOTE(review): incomplete copy. The embedded line numbers show missing lines
// here: the access specifiers, the header of the kind enum (113), lines
// inside the enumerator list (116, 120, 128, 131), and the header/terminator
// of the nested `at` struct (138, 151). Restore from upstream. Do not rebuild
// the kind enum from this text, since enumerators are missing.
105 /// Base class for all computational primitives.
106 class primitive: public handle<mkldnn_primitive_t> {
108 friend struct stream;
109 friend class primitive_at;
// Inherit handle's constructors (C handle plus optional weak flag).
110 using handle::handle;
112 /// A proxy to C primitive kind enum
// Each enumerator has the value of the C mkldnn_primitive_kind_t constant
// of the same (mkldnn_-prefixed) name.
114 undefined_primitive = mkldnn_undefined_primitive,
115 memory = mkldnn_memory,
117 reorder = mkldnn_reorder,
118 concat = mkldnn_concat,
119 concat_inplace = mkldnn_concat_inplace,
121 convolution = mkldnn_convolution,
122 deconvolution = mkldnn_deconvolution,
123 shuffle = mkldnn_shuffle,
124 eltwise = mkldnn_eltwise,
125 depthwise = mkldnn_depthwise,
126 softmax = mkldnn_softmax,
127 pooling = mkldnn_pooling,
129 batch_normalization = mkldnn_batch_normalization,
130 inner_product = mkldnn_inner_product,
132 binary_convolution = mkldnn_binary_convolution,
133 binarization = mkldnn_binarization,
134 deformable_convolution = mkldnn_deformable_convolution,
137 /// A wrapper structure to specify a particular output of a primitive.
139 /// The underlying C API structure.
140 mkldnn_primitive_at_t data;
141 /// Constructs a wrapper specifying @p aprimitive output with index @p
144 /// @param aprimitive The target primitive.
145 /// @param at The output index.
// The C struct is built by mkldnn_primitive_at() from the raw C handle of
// @p aprimitive and the output index.
147 at(const primitive &aprimitive, size_t at = 0)
148 : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
149 /// Returns the specified output.
150 inline operator primitive() const;
153 /// Returns the descriptor of the underlying C API primitive.
154 inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
155 // TODO: use the C++ API wrapper structure.
158 inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) {
159 return static_cast<mkldnn_primitive_kind_t>(akind);
// NOTE(review): incomplete copy. The embedded line numbers show missing lines,
// e.g. a member at 167, the start of the constructor's mem-initializer list
// at 179-180, the constructor body at 182, and the braces/`else` of
// wrap_c_api at 195, 199 and 201-203. Restore from upstream before compiling.
// NOTE(review): no what() override is visible in this copy. If upstream also
// lacks one, std::exception::what() will not report the error message.
161 /// Intel(R) MKL-DNN exception class.
163 /// This class captures the status returned by the failed C API function, error
164 /// message, and, optionally, handle of the primitive that caused the error.
165 struct error: public std::exception {
// Status code returned by the failing C API call.
166 mkldnn_status_t status;
// Wrapper around the offending primitive, built from aerror_primitive
// (which defaults to 0) with the weak flag set.
168 primitive error_primitive;
170 /// Constructs an error instance.
172 /// @param astatus The error status returned by the C API.
173 /// @param amessage The error message.
174 /// @param aerror_primitive (optional) A C handle of the primitive that
175 /// caused the error.
177 error(mkldnn_status_t astatus, std::string amessage,
178 mkldnn_primitive_t aerror_primitive = 0)
// Weak wrapper: destroying the exception never destroys the primitive.
181 , error_primitive(aerror_primitive, true)
184 /// A convenience function for wrapping calls to the C API. Checks the
185 /// return status and throws an #error in case of failure.
187 /// @param status The error status returned by the C API.
188 /// @param message The error message.
189 /// @param error_primitive (optional) A C handle of the primitive that
190 /// caused the error.
192 static void wrap_c_api(mkldnn_status_t status,
193 const std::string &message,
194 mkldnn_primitive_t *error_primitive = 0)
196 if (status != mkldnn_success) {
// error_primitive is an optional pointer to a C handle. It is dereferenced
// only when the caller supplied one.
197 if (nullptr != error_primitive)
198 throw error(status, message, *error_primitive);
200 throw error(status, message, nullptr);
205 inline primitive::at::operator primitive() const {
206 const_mkldnn_primitive_t output;
208 mkldnn_primitive_get_output(data.primitive,
209 data.output_index, &output),
210 "could not get an output primitive");
211 return primitive(const_cast<mkldnn_primitive_t>(output), true);
214 const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const {
215 const_mkldnn_primitive_desc_t pd;
216 error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd),
217 "could not get primitive descriptor by primitive");
222 /// @addtogroup cpp_api_enums Common data types and enumerations
223 /// A proxy to @ref c_api_types in @ref c_api.
228 round_nearest = mkldnn_round_nearest,
229 round_down = mkldnn_round_down,
232 inline mkldnn_round_mode_t convert_to_c(round_mode mode) {
233 return static_cast<mkldnn_round_mode_t>(mode);
237 zero = mkldnn_padding_zero
240 inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) {
241 return static_cast<mkldnn_padding_kind_t>(kind);
245 forward_training = mkldnn_forward_training,
246 forward_scoring = mkldnn_forward_scoring,
247 forward_inference = mkldnn_forward_inference,
248 forward = mkldnn_forward,
249 backward = mkldnn_backward,
250 backward_data = mkldnn_backward_data,
251 backward_weights = mkldnn_backward_weights,
252 backward_bias = mkldnn_backward_bias
255 inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) {
256 return static_cast<mkldnn_prop_kind_t>(kind);
260 algorithm_undef = mkldnn_alg_kind_undef,
261 convolution_auto = mkldnn_convolution_auto,
262 convolution_direct = mkldnn_convolution_direct,
263 convolution_winograd = mkldnn_convolution_winograd,
264 deconvolution_direct = mkldnn_deconvolution_direct,
265 deconvolution_winograd = mkldnn_deconvolution_winograd,
266 eltwise_relu = mkldnn_eltwise_relu,
267 eltwise_tanh = mkldnn_eltwise_tanh,
268 eltwise_elu = mkldnn_eltwise_elu,
269 eltwise_square = mkldnn_eltwise_square,
270 eltwise_abs = mkldnn_eltwise_abs,
271 eltwise_sqrt = mkldnn_eltwise_sqrt,
272 eltwise_linear = mkldnn_eltwise_linear,
273 eltwise_bounded_relu = mkldnn_eltwise_bounded_relu,
274 eltwise_soft_relu = mkldnn_eltwise_soft_relu,
275 eltwise_logistic = mkldnn_eltwise_logistic,
276 eltwise_clamp = mkldnn_eltwise_clamp,
277 eltwise_exp = mkldnn_eltwise_exp,
278 eltwise_not = mkldnn_eltwise_not,
279 depthwise_scale_shift = mkldnn_depthwise_scale_shift,
280 depthwise_prelu = mkldnn_depthwise_prelu,
281 lrn_across_channels = mkldnn_lrn_across_channels,
282 lrn_within_channel = mkldnn_lrn_within_channel,
283 pooling_max = mkldnn_pooling_max,
284 pooling_avg = mkldnn_pooling_avg,
285 pooling_avg_include_padding = mkldnn_pooling_avg_include_padding,
286 pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding,
287 vanilla_rnn = mkldnn_vanilla_rnn,
288 vanilla_lstm = mkldnn_vanilla_lstm,
289 vanilla_gru = mkldnn_vanilla_gru,
290 gru_linear_before_reset = mkldnn_gru_linear_before_reset,
291 roi_pooling_max = mkldnn_roi_pooling_max,
292 roi_pooling_bilinear = mkldnn_roi_pooling_bilinear,
293 binary_convolution_direct = mkldnn_binary_convolution_direct,
294 binarization_depthwise = mkldnn_binarization_depthwise,
295 deformable_convolution_direct = mkldnn_deformable_convolution_direct,
298 inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
299 return static_cast<mkldnn_alg_kind_t>(aalgorithm);
302 enum batch_normalization_flag {
303 use_global_stats = mkldnn_use_global_stats,
304 use_scale_shift = mkldnn_use_scaleshift,
305 fuse_bn_relu = mkldnn_fuse_bn_relu
308 inline mkldnn_batch_normalization_flag_t convert_to_c(
309 batch_normalization_flag aflag) {
310 return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
314 unidirectional_left2right = mkldnn_unidirectional_left2right,
315 unidirectional_right2left = mkldnn_unidirectional_right2left,
316 unidirectional = mkldnn_unidirectional,
317 bidirectional_concat = mkldnn_bidirectional_concat,
318 bidirectional_sum = mkldnn_bidirectional_sum,
321 inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) {
322 return static_cast<mkldnn_rnn_direction_t>(adir);
326 undef = mkldnn_query_undef,
328 eengine = mkldnn_query_engine,
329 primitive_kind = mkldnn_query_primitive_kind,
331 num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32,
332 num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32,
334 time_estimate_f64 = mkldnn_query_time_estimate_f64,
335 memory_consumption_s64 = mkldnn_query_memory_consumption_s64,
337 impl_info_str = mkldnn_query_impl_info_str,
339 op_d = mkldnn_query_op_d,
340 memory_d = mkldnn_query_memory_d,
341 convolution_d = mkldnn_query_convolution_d,
342 deconvolution_d = mkldnn_query_deconvolution_d,
343 shuffle_d = mkldnn_query_shuffle_d,
344 eltwise_d = mkldnn_query_eltwise_d,
345 depthwise_d = mkldnn_query_depthwise_d,
346 softmax_d = mkldnn_query_softmax_d,
347 pooling_d = mkldnn_query_pooling_d,
348 lrn_d = mkldnn_query_lrn_d,
349 batch_normalization_d = mkldnn_query_batch_normalization_d,
350 inner_product_d = mkldnn_query_inner_product_d,
351 rnn_d = mkldnn_query_rnn_d,
352 binary_convolution_d = mkldnn_query_binary_convolution_d,
353 binarization_d = mkldnn_query_binarization_d,
354 deformable_convolution_d = mkldnn_query_deformable_convolution_d,
356 input_pd = mkldnn_query_input_pd,
357 output_pd = mkldnn_query_output_pd,
358 src_pd = mkldnn_query_src_pd,
359 diff_src_pd = mkldnn_query_diff_src_pd,
360 weights_pd = mkldnn_query_weights_pd,
361 diff_weights_pd = mkldnn_query_diff_weights_pd,
362 dst_pd = mkldnn_query_dst_pd,
363 diff_dst_pd = mkldnn_query_diff_dst_pd,
364 workspace_pd = mkldnn_query_workspace_pd,
367 inline mkldnn_query_t convert_to_c(query aquery) {
368 return static_cast<mkldnn_query_t>(aquery);
373 /// @addtogroup cpp_api_attr Attributes
374 /// An extension for controlling primitive behavior.
376 /// @sa @ref c_api_attributes in @ref c_api
379 #ifndef DOXYGEN_SHOULD_SKIP_THIS
380 template <> struct handle_traits<mkldnn_post_ops_t> {
381 static constexpr auto destructor = &mkldnn_post_ops_destroy;
// NOTE(review): incomplete copy. Short lines are missing here: the
// constructor header and the line that stores `result` (386, 390-392), the
// call opener in kind() (396), the trailing argument line of
// mkldnn_post_ops_get_kind (400), the `beta` parameter line of
// append_eltwise (414), and closing braces. Restore from upstream before
// compiling.
/// C++ wrapper for a post-operation sequence (#mkldnn_post_ops_t). Installed
/// on attributes via primitive_attr::set_post_ops().
385 struct post_ops: public handle<mkldnn_post_ops_t> {
// Constructor fragment: creates a C post-ops sequence and throws #error if
// creation fails.
387 mkldnn_post_ops_t result;
388 error::wrap_c_api(mkldnn_post_ops_create(&result),
389 "could not create post operation sequence");
/// Returns the number of post operations in the sequence.
393 int len() const { return mkldnn_post_ops_len(get()); }
/// Returns the primitive kind of the post operation at @p index.
/// The visible lines map `index < len()` to mkldnn_success and anything else
/// to mkldnn_invalid_arguments with "post_ops index is out of range". The
/// call that consumes this status is on a missing line.
/// NOTE(review): the visible check does not reject negative indices.
395 primitive::kind kind(int index) const {
397 index < len() ? mkldnn_success : mkldnn_invalid_arguments,
398 "post_ops index is out of range");
399 return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
/// Appends a sum post operation with the given @p scale (default 1).
/// @throws error on failure.
403 void append_sum(float scale = 1.) {
404 error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale),
405 "could not append sum");
/// Reads the scale of the sum post operation at @p index into @p scale.
/// @throws error on failure.
408 void get_params_sum(int index, float &scale) const {
409 error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale),
410 "could not get sum params");
/// Appends an eltwise post operation. @p alg is converted with
/// convert_to_c(); @p scale, @p alpha and beta are forwarded unchanged.
/// @throws error on failure.
413 void append_eltwise(float scale, algorithm alg, float alpha,
415 error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale,
416 convert_to_c(alg), alpha, beta),
417 "could not append eltwise");
/// Reads the parameters of the eltwise post operation at @p index. The C
/// algorithm kind is cast back to the C++ algorithm enum.
/// @throws error on failure.
420 void get_params_eltwise(int index, float &scale, algorithm &alg,
421 float &alpha, float &beta) const {
422 mkldnn_alg_kind_t c_alg;
423 error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index,
424 &scale, &c_alg, &alpha, &beta),
425 "could not get eltwise params");
426 alg = static_cast<algorithm>(c_alg);
/// Appends a depthwise post operation. The weights/biases pointers are
/// passed straight to the C API.
/// NOTE(review): whether the C API copies these buffers or keeps the pointers
/// (requiring them to outlive the attribute) is not visible here -- confirm.
/// @throws error on failure.
429 void append_depthwise(algorithm alg, const float* weights_data,
430 const float* biases_data) {
431 error::wrap_c_api(mkldnn_post_ops_append_depthwise(get(),
432 convert_to_c(alg), weights_data, biases_data),
433 "could not append depthwise");
/// Reads the algorithm and data pointers of the depthwise post operation at
/// @p index. @throws error on failure.
436 void get_params_depthwise(int index, algorithm &alg,
437 const float** weights_data, const float** biases_data) const {
438 mkldnn_alg_kind_t c_alg;
439 error::wrap_c_api(mkldnn_post_ops_get_params_depthwise(get(), index,
440 &c_alg, weights_data, biases_data),
441 "could not get depthwise params");
442 alg = static_cast<algorithm>(c_alg);
/// Appends a depthwise-convolution post operation. in_h/in_w, ker_h/ker_w,
/// str_h/str_w, in_dt and the weights/biases pointers are forwarded unchanged
/// to mkldnn_post_ops_append_dw_conv(). @throws error on failure.
445 void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, mkldnn_data_type_t in_dt,
446 const float* weights_data, const float* biases_data) {
447 error::wrap_c_api(mkldnn_post_ops_append_dw_conv(get(),
448 in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt, weights_data, biases_data),
449 "could not append dw conv");
/// Reads back every parameter of the depthwise-convolution post operation at
/// @p index. @throws error on failure.
452 void get_params_dw_conv(int index, int &in_h, int &in_w, int &ker_h, int &ker_w, int &str_h, int &str_w, mkldnn_data_type_t& in_dt,
453 const float** weights_data, const float** biases_data) const {
454 error::wrap_c_api(mkldnn_post_ops_get_params_dw_conv(get(), index,
455 &in_h, &in_w, &ker_h, &ker_w, &str_h, &str_w, &in_dt, weights_data, biases_data),
456 "could not get dw conv params");
/// Appends a binarization post operation with algorithm @p alg and the given
/// weights/output-mask pointers. @throws error on failure.
459 void append_binarization(algorithm alg, const float* weights_data, const float* output_mask) {
460 error::wrap_c_api(mkldnn_post_ops_append_binarization(get(), convert_to_c(alg), weights_data, output_mask),
461 "could not append binarization");
/// Reads the algorithm and data pointers of the binarization post operation
/// at @p index. @throws error on failure.
464 void get_params_binarization(int index, algorithm &alg, const float** weights_data, const float** output_mask) const {
465 mkldnn_alg_kind_t c_alg;
466 error::wrap_c_api(mkldnn_post_ops_get_params_binarization(get(), index, &c_alg, weights_data, output_mask),
467 "could not get binarization params");
468 alg = static_cast<algorithm>(c_alg);
472 #ifndef DOXYGEN_SHOULD_SKIP_THIS
473 template <> struct handle_traits<mkldnn_primitive_attr_t> {
474 static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
// NOTE(review): incomplete copy. Short lines are missing: the constructor
// header and the line storing `result` (479, 483-485), the declarations of
// `count` and `c_mask` used at 504 (500-501), the lines after the resize at
// 507-508, the `post_ops result;` declaration and `return result;` of
// get_post_ops (521, 526), and method braces. Restore from upstream before
// compiling.
/// C++ wrapper for primitive attributes (#mkldnn_primitive_attr_t).
478 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
// Constructor fragment: creates a C attribute object and throws #error if
// creation fails.
480 mkldnn_primitive_attr_t result;
481 error::wrap_c_api(mkldnn_primitive_attr_create(&result),
482 "could not create a primitive attr");
/// Returns the rounding mode for integer outputs. @throws error on failure.
486 round_mode get_int_output_round_mode() const {
487 mkldnn_round_mode_t result;
488 error::wrap_c_api(mkldnn_primitive_attr_get_int_output_round_mode(
489 get(), &result), "could not get int output round mode");
490 return round_mode(result);
/// Sets the rounding mode for integer outputs. @throws error on failure.
493 void set_int_output_round_mode(round_mode mode) {
494 error::wrap_c_api(mkldnn_primitive_attr_set_int_output_round_mode(
495 get(), mkldnn::convert_to_c(mode)),
496 "could not set int output round mode");
/// Copies the output scales into @p scales, resized to the count reported by
/// the C API.
/// NOTE(review): @p mask is not assigned on any visible line. The assignment
/// from c_mask is presumably on a missing line -- confirm.
/// @throws error on failure.
499 void get_output_scales(int &mask, std::vector<float> &scales) const
502 const float *c_scales;
503 error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(),
504 &count, &c_mask, &c_scales),
505 "could not get int output scales");
506 scales.resize(count);
509 for (int c = 0; c < count; ++c)
510 scales[c] = c_scales[c];
/// Sets the output scales. @p mask is forwarded as-is to the C API.
/// NOTE(review): `&scales[0]` is undefined behaviour when @p scales is empty;
/// `scales.data()` is valid for empty vectors.
/// @throws error on failure.
513 void set_output_scales(int mask, const std::vector<float> &scales)
515 error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(),
516 (int)scales.size(), mask, &scales[0]),
517 "could not set int output scales");
/// Returns the attribute's post-op sequence as a weak wrapper (the `true`
/// flag), so the returned object never destroys it. Presumably it must not
/// outlive this attribute -- confirm ownership in the C API.
/// @throws error on failure.
520 const post_ops get_post_ops() const {
522 const_mkldnn_post_ops_t c_result;
523 error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result),
524 "could not get post operation sequence");
525 result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
/// Installs @p ops via mkldnn_primitive_attr_set_post_ops(). @p ops is taken
/// by value, which copies only the wrapper's shared_ptr.
/// @throws error on failure.
529 void set_post_ops(post_ops ops) {
530 error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()),
531 "could not set post operation sequence");
/// Forwards @p scale and @p shift to
/// mkldnn_primitive_attr_set_rnn_data_qparams(). @throws error on failure.
534 void set_rnn_data_qparams(const float scale, const float shift)
536 error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(),
537 scale, shift), "could not set rnn data int scale/shift");
/// Forwards @p mask and @p scales to
/// mkldnn_primitive_attr_set_rnn_weights_qparams().
/// NOTE(review): `&scales[0]` is undefined behaviour when @p scales is empty;
/// `scales.data()` is valid for empty vectors.
/// @throws error on failure.
540 void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
542 error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(),
543 (int)scales.size(), mask, &scales[0]),
544 "could not set rnn weights int scales");
550 /// @addtogroup cpp_api_engine Engine
551 /// Engine operations.
553 /// @sa @ref c_api_engine in @ref c_api
556 #ifndef DOXYGEN_SHOULD_SKIP_THIS
557 template <> struct handle_traits<mkldnn_engine_t> {
558 static constexpr auto destructor = &mkldnn_engine_destroy;
// NOTE(review): incomplete copy. Short lines are missing: the header of the
// kind enum (568), lines 571-574 after `any`, the `error::wrap_c_api(`
// openers at 591/603/613, and closing braces. The visible body of
// engine(kind, size_t) creates `aengine` but never stores it in the wrapper,
// so that line is missing too. Restore from upstream. Do not rebuild the kind
// enum from this text.
562 /// An execution engine.
563 struct engine: public handle<mkldnn_engine_t> {
564 friend class primitive;
565 // gcc bug??? using handle::handle;
567 /// Kinds of engines.
569 /// An unspecified engine
570 any = mkldnn_any_engine,
575 /// Returns the number of engines of a certain kind.
577 /// @param akind The kind of engines to count.
579 static size_t get_count(kind akind) {
580 return mkldnn_engine_get_count(convert_to_c(akind));
583 /// Constructs an engine.
585 /// @param akind The kind of engine to construct.
586 /// @param index The index of the engine. Must be less than the value
587 /// returned by #get_count() for this particular kind of engine.
589 engine(kind akind, size_t index) {
590 mkldnn_engine_t aengine;
592 mkldnn_engine_create(&aengine,
593 convert_to_c(akind), index),
594 "could not create an engine");
/// Wraps an existing C engine handle weakly (the `true` flag), so the handle
/// is not destroyed through this wrapper.
598 explicit engine(const mkldnn_engine_t& aengine)
599 : handle(aengine, true) {}
/// Builds a weak wrapper around the engine reported by
/// mkldnn_primitive_desc_query(..., eengine, ...) for @p pd.
/// NOTE(review): this constructor is not explicit, so any
/// handle<mkldnn_primitive_desc_t> converts implicitly to an engine.
601 engine(const handle<mkldnn_primitive_desc_t> &pd) {
602 mkldnn_engine_t engine_q;
604 mkldnn_primitive_desc_query(pd.get(),
605 mkldnn::convert_to_c(eengine), 0, &engine_q),
606 "could not get engine from primitive_desc");
607 reset(engine_q, true);
/// Returns a weak engine wrapper for the engine reported by querying @p pd.
/// @tparam primitive_desc Any type whose get() yields a C primitive
/// descriptor handle accepted by mkldnn_primitive_desc_query().
610 template <class primitive_desc>
611 static engine query(const primitive_desc &pd) {
612 mkldnn_engine_t engine_q;
614 mkldnn_primitive_desc_query(pd.get(),
615 mkldnn::convert_to_c(eengine), 0, &engine_q),
616 "could not get engine from primitive_desc");
// Uses the explicit mkldnn_engine_t constructor, so the result is weak.
618 return engine(engine_q);
// Converts the C++ engine kind to its C API counterpart.
622 static mkldnn_engine_kind_t convert_to_c(kind akind) {
623 return static_cast<mkldnn_engine_kind_t>(akind);
629 /// @addtogroup cpp_api_memory_related Memory and memory related operations
632 /// @addtogroup cpp_api_memory Memory
633 /// A primitive to describe and store data.
635 /// For more information, refer to @ref c_api_memory in @ref c_api.
// NOTE(review): incomplete copy. Many short lines are missing from this
// struct:
//   - the headers of the data_type and format enums and of the nested `desc`
//     struct;
//   - all data_type enumerators except data_undef (656-664);
//   - several format enumerators (669, 671-674, 676-678, 689-692, 709-712,
//     756, 833-834, 838);
//   - the preprocessor lines that select between the _aligned_malloc and
//     posix_memalign branches, and the non-Windows branch of _free;
//   - several `error::wrap_c_api(` openers, the lines that store newly
//     created C handles, and closing braces.
// Restore from upstream before compiling. Do not rebuild the enums from this
// text.
638 /// Memory primitive that describes the data.
639 struct memory: public primitive {
// Buffer allocated by memory(const primitive_desc &). It is released by the
// _free deleter when the last copy of the shared_ptr goes away.
641 std::shared_ptr<char> _handle;
/// Dimension vector; the element type is that of a #mkldnn_dims_t entry.
644 typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
/// Throws #error (mkldnn_invalid_arguments, "invalid dimensions") if @p v has
/// more than TENSOR_MAX_DIMS entries. Takes @p v by value, so the vector is
/// copied on every call.
646 template <typename T> static void validate_dims(std::vector<T> v) {
647 if (v.size() > TENSOR_MAX_DIMS)
648 throw error(mkldnn_invalid_arguments,
649 "invalid dimensions");
652 /// Data type specification. See #mkldnn_data_type_t for a detailed
655 data_undef = mkldnn_data_type_undef,
665 /// Memory format specification. See #mkldnn_memory_format_t
666 /// for a detailed description.
// Each format enumerator equals the C constant of the same (mkldnn_-prefixed)
// name.
668 format_undef = mkldnn_format_undef,
670 blocked = mkldnn_blocked,
675 nCw16c = mkldnn_nCw16c,
679 nCw4c = mkldnn_nCw4c,
680 nCw8c = mkldnn_nCw8c,
681 nChw4c = mkldnn_nChw4c,
682 nChw8c = mkldnn_nChw8c,
683 nChw16c = mkldnn_nChw16c,
684 ncdhw = mkldnn_ncdhw,
685 ndhwc = mkldnn_ndhwc,
686 nCdhw4c = mkldnn_nCdhw4c,
687 nCdhw8c = mkldnn_nCdhw8c,
688 nCdhw16c = mkldnn_nCdhw16c,
693 Owi4o = mkldnn_Owi4o,
694 OIw4i4o = mkldnn_OIw4i4o,
695 Owi8o = mkldnn_Owi8o,
696 OIw8o8i = mkldnn_OIw8o8i,
697 OIw8i8o = mkldnn_OIw8i8o,
698 OIw16i16o = mkldnn_OIw16i16o,
699 OIw16o16i = mkldnn_OIw16o16i,
700 Oiw4o = mkldnn_Oiw4o,
701 Oiw16o = mkldnn_Oiw16o,
702 Owi16o = mkldnn_Owi16o,
703 OIw8i16o2i = mkldnn_OIw8i16o2i,
704 OIw8o16i2o = mkldnn_OIw8o16i2o,
705 IOw8o16i2o = mkldnn_IOw8o16i2o,
706 IOw16o16i = mkldnn_IOw16o16i,
707 OIw4i16o4i = mkldnn_OIw4i16o4i,
708 OIw4i16o4i_s8s8 = mkldnn_OIw4i16o4i_s8s8,
713 hwio_s8s8 = mkldnn_hwio_s8s8,
714 dhwio = mkldnn_dhwio,
715 oidhw = mkldnn_oidhw,
716 OIdhw4i4o = mkldnn_OIdhw4i4o,
717 Odhwi4o = mkldnn_Odhwi4o,
718 OIdhw8i8o = mkldnn_OIdhw8i8o,
719 OIdhw8o8i = mkldnn_OIdhw8o8i,
720 Odhwi8o = mkldnn_Odhwi8o,
721 OIdhw16i16o = mkldnn_OIdhw16i16o,
722 OIdhw16o16i = mkldnn_OIdhw16o16i,
723 Oidhw4o = mkldnn_Oidhw4o,
724 Oidhw16o = mkldnn_Oidhw16o,
725 Odhwi16o = mkldnn_Odhwi16o,
726 oIhw8i = mkldnn_oIhw8i,
727 oIhw16i = mkldnn_oIhw16i,
728 oIdhw8i = mkldnn_oIdhw8i,
729 oIdhw16i = mkldnn_oIdhw16i,
730 OIhw4i4o = mkldnn_OIhw4i4o,
731 OIhw8i8o = mkldnn_OIhw8i8o,
732 OIhw16i16o = mkldnn_OIhw16i16o,
733 OIhw8o8i = mkldnn_OIhw8o8i,
734 OIhw16o16i = mkldnn_OIhw16o16i,
735 IOhw16o16i = mkldnn_IOhw16o16i,
736 OIhw8i16o2i = mkldnn_OIhw8i16o2i,
737 IOhw8i16o2i = mkldnn_IOhw8i16o2i,
738 OIhw8o16i2o = mkldnn_OIhw8o16i2o,
739 IOhw8o16i2o = mkldnn_IOhw8o16i2o,
740 OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
741 OIdhw8o16i2o = mkldnn_OIdhw8o16i2o,
742 IOdhw8o16i2o = mkldnn_IOdhw8o16i2o,
743 OIhw4i16o4i = mkldnn_OIhw4i16o4i,
744 OIhw4i16o4i_s8s8 = mkldnn_OIhw4i16o4i_s8s8,
745 Oihw8o = mkldnn_Oihw8o,
746 Oihw4o = mkldnn_Oihw4o,
747 Oihw16o = mkldnn_Oihw16o,
748 Ohwi8o = mkldnn_Ohwi8o,
749 Ohwi4o = mkldnn_Ohwi4o,
750 Ohwi16o = mkldnn_Ohwi16o,
751 OhIw16o4i = mkldnn_OhIw16o4i,
752 OhIw8o4i = mkldnn_OhIw8o4i,
753 OhIw8o32i = mkldnn_OhIw8o32i,
754 OhIw16o32i = mkldnn_OhIw16o32i,
755 OhIw8o4i_s8s8 = mkldnn_OhIw8o4i_s8s8,
757 gOwi4o = mkldnn_gOwi4o,
758 gOIw4i4o = mkldnn_gOIw4i4o,
759 gOwi8o = mkldnn_gOwi8o,
760 gOIw8o8i = mkldnn_gOIw8o8i,
761 gOIw8i8o = mkldnn_gOIw8i8o,
762 gOIw16i16o = mkldnn_gOIw16i16o,
763 gOIw16o16i = mkldnn_gOIw16o16i,
764 gOiw4o = mkldnn_gOiw4o,
765 gOiw16o = mkldnn_gOiw16o,
766 gOwi16o = mkldnn_gOwi16o,
767 gIOw16o16i = mkldnn_gIOw16o16i,
768 gOIw8i16o2i = mkldnn_gOIw8i16o2i,
769 gOIw8o16i2o = mkldnn_gOIw8o16i2o,
770 gIOw8o16i2o = mkldnn_gIOw8o16i2o,
771 gOIw4i16o4i = mkldnn_gOIw4i16o4i,
772 gOIw4i16o4i_s8s8 = mkldnn_gOIw4i16o4i_s8s8,
773 goihw = mkldnn_goihw,
774 hwigo = mkldnn_hwigo,
775 giohw = mkldnn_giohw,
776 hwigo_s8s8 = mkldnn_hwigo_s8s8,
777 dhwigo_s8s8 = mkldnn_dhwigo_s8s8,
778 dhwigo = mkldnn_dhwigo,
779 dhwio_s8s8 = mkldnn_dhwio_s8s8,
780 gOIdhw4i4o = mkldnn_gOIdhw4i4o,
781 gOdhwi4o = mkldnn_gOdhwi4o,
782 gOIdhw8i8o = mkldnn_gOIdhw8i8o,
783 gOIdhw8o8i = mkldnn_gOIdhw8o8i,
784 gOdhwi8o = mkldnn_gOdhwi8o,
785 gOIhw4i4o = mkldnn_gOIhw4i4o,
786 gOIhw8i8o = mkldnn_gOIhw8i8o,
787 gOIhw16i16o = mkldnn_gOIhw16i16o,
788 gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
789 gIOhw8i16o2i = mkldnn_gIOhw8i16o2i,
790 gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
791 gIOhw8o16i2o = mkldnn_gIOhw8o16i2o,
792 gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
793 gOIdhw8o16i2o = mkldnn_gOIdhw8o16i2o,
794 gIOdhw8o16i2o = mkldnn_gIOdhw8o16i2o,
795 gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
796 gOIhw4i16o4i_s8s8 = mkldnn_gOIhw4i16o4i_s8s8,
797 gOIhw2i8o4i = mkldnn_gOIhw2i8o4i,
798 gOIhw2i8o4i_s8s8 = mkldnn_gOIhw2i8o4i_s8s8,
799 gOihw8o = mkldnn_gOihw8o,
800 gOihw4o = mkldnn_gOihw4o,
801 gOihw16o = mkldnn_gOihw16o,
802 gOhwi4o = mkldnn_gOhwi4o,
803 gOhwi8o = mkldnn_gOhwi8o,
804 gOhwi16o = mkldnn_gOhwi16o,
805 Goihw8g = mkldnn_Goihw8g,
806 Goihw8g_s8s8 = mkldnn_Goihw8g_s8s8,
807 Goiw16g = mkldnn_Goiw16g,
808 Goiw16g_s8s8 = mkldnn_Goiw16g_s8s8,
809 Goihw16g = mkldnn_Goihw16g,
810 Goihw16g_s8s8 = mkldnn_Goihw16g_s8s8,
811 gOIhw4o4i = mkldnn_gOIhw4o4i,
812 gOIhw4o4i_s8s8 = mkldnn_gOIhw4o4i_s8s8,
813 gOIhw8o8i = mkldnn_gOIhw8o8i,
814 gOIhw16o16i = mkldnn_gOIhw16o16i,
815 gIOhw16o16i = mkldnn_gIOhw16o16i,
816 gOhIw16o4i = mkldnn_gOhIw16o4i,
817 gOhIw8o4i = mkldnn_gOhIw8o4i,
818 gOhIw8o4i_s8s8 = mkldnn_gOhIw8o4i_s8s8,
819 goidhw = mkldnn_goidhw,
820 gOIdhw16i16o = mkldnn_gOIdhw16i16o,
821 gOIdhw16o16i = mkldnn_gOIdhw16o16i,
822 gOidhw4o = mkldnn_gOidhw4o,
823 gOidhw16o = mkldnn_gOidhw16o,
824 gOdhwi16o = mkldnn_gOdhwi16o,
825 OdhIw8o4i = mkldnn_OdhIw8o4i,
826 OdhIw8o4i_s8s8 = mkldnn_OdhIw8o4i_s8s8,
827 gOdhIw8o4i = mkldnn_gOdhIw8o4i,
828 gOdhIw8o4i_s8s8 = mkldnn_gOdhIw8o4i_s8s8,
829 Goidhw8g = mkldnn_Goidhw8g,
830 Goidhw8g_s8s8 = mkldnn_Goidhw8g_s8s8,
831 Goidhw16g = mkldnn_Goidhw16g,
832 Goidhw16g_s8s8 = mkldnn_Goidhw16g_s8s8,
835 ldsnc = mkldnn_ldsnc,
836 ldigo = mkldnn_ldigo,
837 ldgoi = mkldnn_ldgoi,
839 rnn_packed = mkldnn_rnn_packed,
840 wino_fmt = mkldnn_wino_fmt,
841 format_last = mkldnn_format_last,
844 /// A memory descriptor.
846 friend struct memory;
847 /// The underlying C API data structure.
848 mkldnn_memory_desc_t data;
850 /// Constructs a memory descriptor.
852 /// @param adims Data dimensions
853 /// @param adata_type Data precision/type.
854 /// @param aformat Data layout format.
// Validates the rank first (validate_dims), then initialises the C struct.
// An empty @p adims is passed to the C API as a null pointer.
855 desc(dims adims, data_type adata_type,
857 validate_dims(adims);
859 mkldnn_memory_desc_init(&data, (int)adims.size(),
860 adims.size() == 0 ? nullptr : &adims[0],
861 convert_to_c(adata_type), convert_to_c(aformat)),
862 "could not initialize a memory descriptor");
865 /// Constructs a memory descriptor from a C API data structure.
867 /// @param adata A C API #mkldnn_memory_desc_t structure.
868 desc(const mkldnn_memory_desc_t &adata): data(adata) {}
871 /// A memory primitive descriptor.
872 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
873 friend struct memory;
875 // TODO: make private
878 /// Constructs a memory primitive descriptor.
/// @param adesc Memory descriptor whose C struct is passed to
///     mkldnn_memory_primitive_desc_create().
/// @param aengine Engine whose C handle is passed to the same call.
879 primitive_desc(const desc &adesc, const engine &aengine) {
880 mkldnn_primitive_desc_t result;
882 mkldnn_memory_primitive_desc_create(&result,
883 &adesc.data, aengine.get()),
884 "could not initialize a memory primitive descriptor");
888 /// Returns the memory primitive descriptor.
// NOTE(review): the pointer returned by the query is dereferenced without a
// null check.
889 memory::desc desc() {
890 auto memory_d = mkldnn_primitive_desc_query_memory_d(get());
891 return memory::desc(*memory_d); }
893 /// Returns the number of bytes required to allocate the memory described
894 /// including the padding area.
895 size_t get_size() const {
896 return mkldnn_memory_primitive_desc_get_size(get());
/// Equal when mkldnn_memory_primitive_desc_equal() returns non-zero.
899 bool operator==(const primitive_desc &other) const {
900 return (0 == mkldnn_memory_primitive_desc_equal(get(),
901 other.get())) ? false : true;
904 bool operator!=(const primitive_desc &other) const {
905 return !operator==(other);
/// Returns the engine of this descriptor via engine::query().
908 engine get_engine() { return engine::query(*this); }
911 /// Constructs a memory primitive from a generic primitive.
913 /// @param aprimitive The primitive to treat as memory.
// No check is made that @p aprimitive really is a memory primitive.
914 memory(const primitive &aprimitive): primitive(aprimitive) {}
915 /// Constructs a memory primitive.
917 /// @param adesc Memory primitive descriptor.
// Allocates adesc.get_size() bytes with 4096-byte alignment, keeps the
// buffer in _handle (freed by _free), and installs it as the data handle.
// NOTE(review): _malloc returns nullptr on failure, and that value is passed
// to set_data_handle() without a check.
918 memory(const primitive_desc &adesc) {
919 mkldnn_primitive_t result;
921 mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
922 "could not create a memory primitive");
// Aligned allocation. Exactly one of the _aligned_malloc / posix_memalign
// branches is meant to be compiled; the selecting preprocessor lines and the
// `ptr` declaration are missing here (hence two `int rc` declarations).
924 auto _malloc = [](size_t size, int alignment) {
927 ptr = _aligned_malloc(size, alignment);
928 int rc = ((ptr)? 0 : errno);
930 int rc = ::posix_memalign(&ptr, alignment, size);
932 return (rc == 0) ? (char*)ptr : nullptr;
// Deleter paired with _malloc. Only the _aligned_free branch is visible.
934 auto _free = [](char* p) {
936 _aligned_free((void*)p);
941 _handle.reset(_malloc(adesc.get_size(), 4096), _free);
942 set_data_handle(_handle.get());
/// Creates a memory primitive whose data handle is the caller-supplied
/// @p ahandle. No allocation is visible here, so the caller presumably keeps
/// ownership of @p ahandle -- confirm.
945 memory(const primitive_desc &adesc, void *ahandle) {
946 mkldnn_primitive_t result;
948 mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
949 "could not create a memory primitive");
951 set_data_handle(ahandle);
954 /// Returns the descriptor of the memory primitive.
// The result wraps the C descriptor weakly (the `true` flag), so it is not
// destroyed through the returned object.
955 primitive_desc get_primitive_desc() const {
956 primitive_desc adesc;
957 const_mkldnn_primitive_desc_t cdesc;
958 error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(),
960 "could not get primitive descriptor from a memory primitive");
961 /* FIXME: no const_cast should be here */
962 adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
966 /// Returns a handle of the data contained in the memory primitive. On
967 /// the CPU engine, this is a pointer to the allocated memory.
968 inline void *get_data_handle() const {
970 error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle),
971 "could not get native handle");
/// Sets the data handle via mkldnn_memory_set_data_handle().
/// @throws error on failure.
975 inline void set_data_handle(void *handle) const {
976 error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle),
977 "could not set native handle");
/// Like set_data_handle(), but calls
/// mkldnn_memory_set_data_handle_no_pads_proc(). Presumably this skips
/// processing of the padded area -- confirm in the C API. Reports the same
/// error message as set_data_handle().
980 inline void set_data_handle_no_pads_proc(void *handle) const {
981 error::wrap_c_api(mkldnn_memory_set_data_handle_no_pads_proc(get(), handle),
982 "could not set native handle");
985 // Must go away or be private:
// Convert the C++ data_type / format enums to their C API counterparts.
986 static mkldnn_data_type_t convert_to_c(data_type adata_type) {
987 return static_cast<mkldnn_data_type_t>(adata_type);
989 static mkldnn_memory_format_t convert_to_c(format aformat) {
990 return static_cast<mkldnn_memory_format_t>(aformat);
994 inline memory::desc zero_md() {
995 auto zero = mkldnn_memory_desc_t();
996 zero.primitive_kind = mkldnn_memory;
997 return memory::desc(zero);
1000 inline memory null_memory(engine eng) {
1001 mkldnn::memory::desc zero = zero_md();
1002 return memory({zero, eng}, nullptr);
1005 inline void check_num_parameters(const const_mkldnn_primitive_desc_t
1006 &aprimitive_desc, int n_inputs, int n_outputs,
1007 const std::string &prim_name) {
1008 const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
1009 aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
1010 const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
1011 aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
1012 if (n_outputs_expected > n_outputs ) {
1013 std::string message = "could not create " + prim_name +
1014 " primitive, not enought output parameters";
1015 throw error(mkldnn_invalid_arguments, message, nullptr);
1017 if (n_inputs_expected > n_inputs ) {
1018 std::string message = "could not create " + prim_name +
1019 " primitive, not enought input parameters";
1020 throw error(mkldnn_invalid_arguments, message, nullptr);
1025 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
1026 const_mkldnn_primitive_desc_t aprimitive_pd;
1027 mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
1028 const mkldnn_memory_desc_t *aprimitive_md = mkldnn_primitive_desc_query_memory_d(
1031 return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
1034 inline bool operator==(mkldnn_data_type_t a, memory::data_type b) {
1035 return a == memory::convert_to_c(b);
1037 inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) {
1040 inline bool operator==(memory::data_type a, mkldnn_data_type_t b) {
1043 inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) {
1047 inline bool operator==(mkldnn_memory_format_t a, memory::format b) {
1048 return a == memory::convert_to_c(b);
1050 inline bool operator!=(mkldnn_memory_format_t a, memory::format b) {
1053 inline bool operator==(memory::format a, mkldnn_memory_format_t b) {
1056 inline bool operator!=(memory::format a, mkldnn_memory_format_t b) {
1062 /// @addtogroup cpp_api_reorder Reorder
1063 /// A primitive to copy data between memory formats.
1065 /// @sa @ref c_api_reorder in @ref c_api
1068 struct reorder : public primitive {
1069 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1070 primitive_desc(const memory::primitive_desc &input,
1071 const memory::primitive_desc &output) {
1072 mkldnn_primitive_desc_t result;
1073 error::wrap_c_api(mkldnn_reorder_primitive_desc_create(
1074 &result, input.get(), output.get()),
1075 "could not create a reorder primitive descriptor");
1079 primitive_desc(const memory::primitive_desc &input,
1080 const memory::primitive_desc &output,
1081 const primitive_attr &aattr) {
1082 mkldnn_primitive_desc_t result;
1083 error::wrap_c_api(mkldnn_reorder_primitive_desc_create_v2(
1084 &result, input.get(), output.get(), aattr.get()),
1085 "could not create a reorder primitive descriptor");
1089 engine get_engine() { return engine::query(*this); }
1092 reorder(const primitive_desc &aprimitive_desc,
1093 const primitive::at &input, const memory &output) {
1094 mkldnn_primitive_t result;
1095 mkldnn_primitive_at_t inputs[] = { input.data };
1096 const_mkldnn_primitive_t outputs[] = { output.get() };
1097 error::wrap_c_api(mkldnn_primitive_create(&result,
1098 aprimitive_desc.get(), inputs, outputs),
1099 "could not create a reorder primitive");
1103 reorder(const primitive::at &input, const memory &output) {
1104 auto input_mpd = memory(input).get_primitive_desc();
1105 auto output_mpd = output.get_primitive_desc();
1107 auto reorder_d = primitive_desc(input_mpd, output_mpd);
1109 mkldnn_primitive_t result;
1110 mkldnn_primitive_at_t inputs[] = { input.data };
1111 const_mkldnn_primitive_t outputs[] = { output.get() };
1112 error::wrap_c_api(mkldnn_primitive_create(&result,
1113 reorder_d.get(), inputs, outputs),
1114 "could not create a reorder primitive");
1121 /// @addtogroup cpp_api_view View
1122 /// A primitive to view on a memory.
1124 /// @sa mkldnn_view_primitive_desc_create in @ref c_api
1127 struct view : public primitive {
1128 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1129 primitive_desc(const memory::primitive_desc &input, memory::dims dims,
1130 memory::dims offsets) {
1131 mkldnn_primitive_desc_t result;
1133 error::wrap_c_api(mkldnn_view_primitive_desc_create(
1134 &result, input.get(), &dims[0], &offsets[0]),
1135 "could not create a view primitive descriptor");
1139 memory::primitive_desc dst_primitive_desc() const {
1140 memory::primitive_desc adesc;
1141 mkldnn_primitive_desc_t cdesc;
1142 const_mkldnn_primitive_desc_t const_cdesc =
1143 mkldnn_primitive_desc_query_pd(get(),
1144 mkldnn::convert_to_c(dst_pd), 0);
1145 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc,
1147 "could not clone a dst primitive descriptor");
1152 engine get_engine() { return engine::query(*this); }
1155 view(const primitive_desc &view_pd, primitive::at input) {
1156 mkldnn_primitive_t result;
1157 mkldnn_primitive_at_t inputs[] = { input.data };
1158 error::wrap_c_api(mkldnn_primitive_create(&result,
1159 view_pd.get(), inputs, nullptr),
1160 "could not create a view primitive");
1164 view(memory input, memory::dims dims, memory::dims offsets) {
1165 mkldnn_primitive_t result;
1166 primitive_desc view_pd(input.get_primitive_desc(), dims,
1168 mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
1169 error::wrap_c_api(mkldnn_primitive_create(&result,
1170 view_pd.get(), inputs, nullptr),
1171 "could not create a view primitive");
1178 /// @addtogroup cpp_api_concat Concat
1179 /// A primitive to concatenate data by arbitrary dimension.
1181 /// @sa @ref c_api_concat in @ref c_api
1184 struct concat : public primitive {
1185 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1186 std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1187 std::vector<memory::primitive_desc> inputs) {
1188 std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1189 c_api_inputs.reserve(inputs.size());
1190 auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1191 std::transform(inputs.begin(), inputs.end(),
1192 std::back_inserter(c_api_inputs), convert_to_c);
1193 return c_api_inputs;
1196 primitive_desc(const memory::desc &output, int concat_dimension,
1197 std::vector<memory::primitive_desc> inputs) {
1198 mkldnn_primitive_desc_t result;
1200 auto c_api_inputs = cpp_to_c(inputs);
1202 error::wrap_c_api(mkldnn_concat_primitive_desc_create(
1203 &result, &output.data, (int)c_api_inputs.size(),
1204 concat_dimension, &c_api_inputs[0]),
1205 "could not create a concat primitive descriptor");
1209 primitive_desc(int concat_dimension,
1210 std::vector<memory::primitive_desc> inputs) {
1211 mkldnn_primitive_desc_t result;
1213 auto c_api_inputs = cpp_to_c(inputs);
1215 error::wrap_c_api(mkldnn_concat_primitive_desc_create(
1216 &result, nullptr, (int)c_api_inputs.size(),
1217 concat_dimension, &c_api_inputs[0]),
1218 "could not create a concat primitive descriptor");
1222 memory::primitive_desc dst_primitive_desc() const {
1223 memory::primitive_desc adesc;
1224 mkldnn_primitive_desc_t cdesc;
1225 const_mkldnn_primitive_desc_t const_cdesc =
1226 mkldnn_primitive_desc_query_pd(get(),
1227 mkldnn::convert_to_c(dst_pd), 0);
1228 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1229 "could not clone a dst primitive descriptor");
1234 engine get_engine() { return engine::query(*this); }
1237 concat(const primitive_desc &concat_pd,
1238 std::vector<primitive::at> &inputs, const memory &output) {
1239 mkldnn_primitive_t result;
1241 std::vector<mkldnn_primitive_at_t> p_inputs;
1242 for (size_t i = 0; i < inputs.size(); i++)
1243 p_inputs.push_back(inputs[i].data);
1244 const_mkldnn_primitive_t outputs[] = { output.get() };
1246 error::wrap_c_api(mkldnn_primitive_create(&result,
1247 concat_pd.get(), &p_inputs[0], outputs),
1248 "could not create a concat primitive");
1255 /// @addtogroup cpp_api_sum Sum
1256 /// A primitive to sum data.
1258 /// @sa @ref c_api_sum in @ref c_api
1261 struct sum : public primitive {
1262 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1263 std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1264 std::vector<memory::primitive_desc> inputs) {
1265 std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1266 c_api_inputs.reserve(inputs.size());
1267 auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1268 std::transform(inputs.begin(), inputs.end(),
1269 std::back_inserter(c_api_inputs), convert_to_c);
1270 return c_api_inputs;
1273 primitive_desc(const memory::desc &output,
1274 const std::vector<float> &scales,
1275 std::vector<memory::primitive_desc> inputs) {
1276 mkldnn_primitive_desc_t result;
1278 auto c_api_inputs = cpp_to_c(inputs);
1281 scales.size() == inputs.size() ? mkldnn_success
1282 : mkldnn_invalid_arguments,
1283 "number of scales not equal to number of inputs");
1285 error::wrap_c_api(mkldnn_sum_primitive_desc_create(
1286 &result, &output.data, (int)c_api_inputs.size(),
1287 &scales[0], &c_api_inputs[0]),
1288 "could not create a sum primitive descriptor");
1292 primitive_desc(const std::vector<float> &scales,
1293 std::vector<memory::primitive_desc> inputs) {
1294 mkldnn_primitive_desc_t result;
1296 auto c_api_inputs = cpp_to_c(inputs);
1299 scales.size() == inputs.size() ? mkldnn_success
1300 : mkldnn_invalid_arguments,
1301 "number of scales not equal to number of inputs");
1303 error::wrap_c_api(mkldnn_sum_primitive_desc_create(
1304 &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1306 "could not create a sum primitive descriptor");
1310 memory::primitive_desc dst_primitive_desc() const {
1311 memory::primitive_desc adesc;
1312 mkldnn_primitive_desc_t cdesc;
1313 const_mkldnn_primitive_desc_t const_cdesc =
1314 mkldnn_primitive_desc_query_pd(get(),
1315 mkldnn::convert_to_c(dst_pd), 0);
1316 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc,
1318 "could not clone a dst primitive descriptor");
1323 engine get_engine() { return engine::query(*this); }
1326 sum(const primitive_desc &sum_pd,
1327 std::vector<primitive::at> &inputs, const memory &output) {
1328 mkldnn_primitive_t result;
1330 std::vector<mkldnn_primitive_at_t> p_inputs;
1331 for (size_t i = 0; i < inputs.size(); i++)
1332 p_inputs.push_back(inputs[i].data);
1333 const_mkldnn_primitive_t outputs[] = { output.get() };
1335 error::wrap_c_api(mkldnn_primitive_create(&result,
1336 sum_pd.get(), &p_inputs[0], outputs),
1337 "could not create a sum primitive");
1346 /// @addtogroup cpp_api_primitives Primitives
1349 /// @addtogroup cpp_api_primitive_descriptors Primitive descriptors
1352 /// A base class for all primitive descriptors.
1353 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1354 primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr,
1355 const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1356 mkldnn_primitive_desc_iterator_t iterator = nullptr;
1357 mkldnn_status_t status = mkldnn_primitive_desc_iterator_create_v2(
1358 &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1360 error::wrap_c_api(status,
1361 "could not create a primitive descriptor iterator");
1362 pd_iterator.reset(iterator);
1366 engine get_engine() { return engine::query(*this); }
1368 primitive_attr get_primitive_attr() const {
1369 const_mkldnn_primitive_attr_t const_cattr;
1370 error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr),
1371 "could not get attributes");
1372 mkldnn_primitive_attr_t cattr;
1373 error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
1374 "could not clone attributes");
1376 primitive_attr attr;
1381 /// Returns implementation name
1382 const char *impl_info_str() const {
1384 error::wrap_c_api(mkldnn_primitive_desc_query(get(),
1385 mkldnn_query_impl_info_str, 0, &res),
1386 "could not query implementation info string");
1390 /// Advances the next implementation for the given op descriptor.
1393 /// - @c true on success
1394 /// - @c false if the last implementation reached, and
1395 /// the primitive descriptor itself is kept unchanged
1397 mkldnn_status_t status = mkldnn_primitive_desc_iterator_next(
1399 if (status == mkldnn_iterator_ends) return false;
1400 error::wrap_c_api(status, "primitive descriptor iterator next failed");
1406 /// Queries and returns requested memory primitive descriptor.
1407 memory::primitive_desc query_mpd(query what, int idx = 0) const {
1408 std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
1409 weights_pd, diff_weights_pd, dst_pd, diff_dst_pd, workspace_pd};
1410 if (!std::any_of(valid_w.cbegin(), valid_w.cend(),
1411 [=](query q) { return what == q; }))
1412 throw error(mkldnn_invalid_arguments, "invalid memory query");
1414 const_mkldnn_primitive_desc_t const_cdesc
1415 = mkldnn_primitive_desc_query_pd(get(),
1416 mkldnn::convert_to_c(what), idx);
1418 // TODO: is there a better way to inform about this?
1419 if (const_cdesc == nullptr)
1420 throw error(mkldnn_not_required, "queried memory is not required");
1422 mkldnn_primitive_desc_t cdesc;
1423 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1424 "could not clone a memory primitive descriptor");
1426 memory::primitive_desc ret;
1431 // register specialized queries, e.g. src_primitive_desc()
1432 # define REG_QUERY_MPD(name, what, idx) \
1433 memory::primitive_desc name ## _primitive_desc() const \
1434 { return query_mpd(what ## _pd, idx); }
1437 handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1439 mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
1441 error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error,
1442 "could not fetch a primitive descriptor from the iterator");
1449 /// @addtogroup cpp_api_convolution Convolution
1450 /// A primitive to compute convolution using different algorithms.
1452 /// @sa @ref c_api_convolution in @ref c_api
1455 struct convolution_forward: public primitive {
1457 mkldnn_convolution_desc_t data;
1458 desc(prop_kind aprop_kind, algorithm aalgorithm,
1459 const memory::desc &src_desc,
1460 const memory::desc &weights_desc,
1461 const memory::desc &bias_desc,
1462 const memory::desc &dst_desc,
1463 const memory::dims strides,
1464 const memory::dims padding_l,
1465 const memory::dims padding_r,
1466 const padding_kind apadding_kind) {
1467 memory::validate_dims(strides);
1468 memory::validate_dims(padding_l);
1469 memory::validate_dims(padding_r);
1470 error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
1471 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1472 &src_desc.data, &weights_desc.data, &bias_desc.data,
1473 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1474 mkldnn::convert_to_c(apadding_kind)),
1475 "could not create a convolution forward descriptor");
1477 desc(prop_kind aprop_kind, algorithm aalgorithm,
1478 const memory::desc &src_desc,
1479 const memory::desc &weights_desc,
1480 const memory::desc &dst_desc,
1481 const memory::dims strides,
1482 const memory::dims padding_l,
1483 const memory::dims padding_r,
1484 const padding_kind apadding_kind) {
1485 memory::validate_dims(strides);
1486 memory::validate_dims(padding_l);
1487 memory::validate_dims(padding_r);
1488 error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
1489 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1490 &src_desc.data, &weights_desc.data, nullptr,
1491 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1492 mkldnn::convert_to_c(apadding_kind)),
1493 "could not create a convolution forward descriptor");
1495 desc(prop_kind aprop_kind, algorithm aalgorithm,
1496 const memory::desc &src_desc,
1497 const memory::desc &weights_desc,
1498 const memory::desc &bias_desc,
1499 const memory::desc &dst_desc,
1500 const memory::dims strides,
1501 const memory::dims dilates,
1502 const memory::dims padding_l,
1503 const memory::dims padding_r,
1504 const padding_kind apadding_kind) {
1505 memory::validate_dims(strides);
1506 memory::validate_dims(dilates);
1507 memory::validate_dims(padding_l);
1508 memory::validate_dims(padding_r);
1510 mkldnn_dilated_convolution_forward_desc_init(&data,
1511 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1512 &src_desc.data, &weights_desc.data, &bias_desc.data,
1513 &dst_desc.data, &strides[0], &dilates[0],
1514 &padding_l[0], &padding_r[0],
1515 mkldnn::convert_to_c(apadding_kind)),
1516 "could not create a dilated convolution forward descriptor");
1518 desc(prop_kind aprop_kind, algorithm aalgorithm,
1519 const memory::desc &src_desc,
1520 const memory::desc &weights_desc,
1521 const memory::desc &dst_desc,
1522 const memory::dims strides,
1523 const memory::dims dilates,
1524 const memory::dims padding_l,
1525 const memory::dims padding_r,
1526 const padding_kind apadding_kind) {
1527 memory::validate_dims(strides);
1528 memory::validate_dims(dilates);
1529 memory::validate_dims(padding_l);
1530 memory::validate_dims(padding_r);
1532 mkldnn_dilated_convolution_forward_desc_init(&data,
1533 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1534 &src_desc.data, &weights_desc.data, nullptr,
1535 &dst_desc.data, &strides[0], &dilates[0],
1536 &padding_l[0], &padding_r[0],
1537 mkldnn::convert_to_c(apadding_kind)),
1538 "could not create a dilated convolution forward descriptor");
1542 struct primitive_desc : public mkldnn::primitive_desc {
1543 primitive_desc(const desc &desc, const engine &e)
1544 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1546 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1547 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1549 REG_QUERY_MPD(src, src, 0);
1550 REG_QUERY_MPD(weights, weights, 0);
1551 REG_QUERY_MPD(bias, weights, 1);
1552 REG_QUERY_MPD(dst, dst, 0);
1555 convolution_forward(const primitive_desc &aprimitive_desc,
1556 const primitive::at &src, const primitive::at &weights,
1557 const primitive::at &bias, const memory &dst) {
1558 mkldnn_primitive_t result;
1559 mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1561 const_mkldnn_primitive_t outputs[] = { dst.get() };
1562 error::wrap_c_api(mkldnn_primitive_create(&result,
1563 aprimitive_desc.get(), inputs, outputs),
1564 "could not create a convolution forward bias primitive");
1568 convolution_forward(const primitive_desc &aprimitive_desc,
1569 const primitive::at &src, const primitive::at &weights,
1570 const memory &dst) {
1571 mkldnn_primitive_t result;
1572 mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1573 const_mkldnn_primitive_t outputs[] = { dst.get() };
1574 check_num_parameters(aprimitive_desc.get(), 2, 1,
1575 "convolution forward");
1576 error::wrap_c_api(mkldnn_primitive_create(&result,
1577 aprimitive_desc.get(), inputs, outputs),
1578 "could not create a convolution forward primitive");
1583 struct convolution_backward_data : public primitive {
1585 mkldnn_convolution_desc_t data;
1586 desc(algorithm aalgorithm,
1587 const memory::desc &diff_src_desc,
1588 const memory::desc &weights_desc,
1589 const memory::desc &diff_dst_desc,
1590 const memory::dims strides,
1591 const memory::dims padding_l,
1592 const memory::dims padding_r,
1593 const padding_kind apadding_kind) {
1594 memory::validate_dims(strides);
1595 memory::validate_dims(padding_l);
1596 memory::validate_dims(padding_r);
1597 error::wrap_c_api(mkldnn_convolution_backward_data_desc_init(
1598 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1599 &weights_desc.data, &diff_dst_desc.data,
1600 &strides[0], &padding_l[0], &padding_r[0],
1601 mkldnn::convert_to_c(apadding_kind)),
1602 "could not create a convolution backward data descriptor");
1604 desc(algorithm aalgorithm,
1605 const memory::desc &diff_src_desc,
1606 const memory::desc &weights_desc,
1607 const memory::desc &diff_dst_desc,
1608 const memory::dims strides,
1609 const memory::dims dilates,
1610 const memory::dims padding_l,
1611 const memory::dims padding_r,
1612 const padding_kind apadding_kind) {
1613 memory::validate_dims(strides);
1614 memory::validate_dims(dilates);
1615 memory::validate_dims(padding_l);
1616 memory::validate_dims(padding_r);
1618 mkldnn_dilated_convolution_backward_data_desc_init(
1619 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1620 &weights_desc.data, &diff_dst_desc.data,
1621 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1622 mkldnn::convert_to_c(apadding_kind)),
1623 "could not create a convolution backward data descriptor");
1627 struct primitive_desc : public mkldnn::primitive_desc {
1628 primitive_desc(const desc &desc, const engine &e,
1629 const convolution_forward::primitive_desc &hint_fwd_pd)
1630 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1632 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1633 const convolution_forward::primitive_desc &hint_fwd_pd)
1634 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1636 REG_QUERY_MPD(diff_src, diff_src, 0);
1637 REG_QUERY_MPD(weights, weights, 0);
1638 REG_QUERY_MPD(diff_dst, diff_dst, 0);
1641 convolution_backward_data(const primitive_desc &aprimitive_desc,
1642 const primitive::at &diff_dst, const primitive::at &weights,
1643 const memory &diff_src) {
1644 mkldnn_primitive_t result;
1645 mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1646 const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1647 check_num_parameters(aprimitive_desc.get(), 2, 1,
1648 "convolution backward data");
1649 error::wrap_c_api(mkldnn_primitive_create(&result,
1650 aprimitive_desc.get(), inputs, outputs),
1651 "could not create a convolution backward data primitive");
1656 struct convolution_backward_weights : public primitive {
1658 mkldnn_convolution_desc_t data;
1659 desc(algorithm aalgorithm,
1660 const memory::desc &src_desc,
1661 const memory::desc &diff_weights_desc,
1662 const memory::desc &diff_bias_desc,
1663 const memory::desc &diff_dst_desc,
1664 const memory::dims strides,
1665 const memory::dims padding_l,
1666 const memory::dims padding_r,
1667 const padding_kind apadding_kind) {
1668 memory::validate_dims(strides);
1669 memory::validate_dims(padding_l);
1670 memory::validate_dims(padding_r);
1671 error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
1672 &data, convert_to_c(aalgorithm), &src_desc.data,
1673 &diff_weights_desc.data, &diff_bias_desc.data,
1674 &diff_dst_desc.data,
1675 &strides[0], &padding_l[0], &padding_r[0],
1676 mkldnn::convert_to_c(apadding_kind)),
1677 "could not create a convolution backward weights descriptor");
1679 desc(algorithm aalgorithm,
1680 const memory::desc &src_desc,
1681 const memory::desc &diff_weights_desc,
1682 const memory::desc &diff_dst_desc,
1683 const memory::dims strides,
1684 const memory::dims padding_l,
1685 const memory::dims padding_r,
1686 const padding_kind apadding_kind) {
1687 memory::validate_dims(strides);
1688 memory::validate_dims(padding_l);
1689 memory::validate_dims(padding_r);
1690 error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
1691 &data, convert_to_c(aalgorithm), &src_desc.data,
1692 &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1693 &strides[0], &padding_l[0], &padding_r[0],
1694 mkldnn::convert_to_c(apadding_kind)),
1695 "could not create a convolution backward weights descriptor");
1697 desc(algorithm aalgorithm,
1698 const memory::desc &src_desc,
1699 const memory::desc &diff_weights_desc,
1700 const memory::desc &diff_bias_desc,
1701 const memory::desc &diff_dst_desc,
1702 const memory::dims strides,
1703 const memory::dims dilates,
1704 const memory::dims padding_l,
1705 const memory::dims padding_r,
1706 const padding_kind apadding_kind) {
1707 memory::validate_dims(strides);
1708 memory::validate_dims(dilates);
1709 memory::validate_dims(padding_l);
1710 memory::validate_dims(padding_r);
1711 error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
1712 &data, convert_to_c(aalgorithm), &src_desc.data,
1713 &diff_weights_desc.data, &diff_bias_desc.data,
1714 &diff_dst_desc.data,
1715 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1716 mkldnn::convert_to_c(apadding_kind)),
1717 "could not create a convolution backward weights descriptor");
1719 desc(algorithm aalgorithm,
1720 const memory::desc &src_desc,
1721 const memory::desc &diff_weights_desc,
1722 const memory::desc &diff_dst_desc,
1723 const memory::dims strides,
1724 const memory::dims dilates,
1725 const memory::dims padding_l,
1726 const memory::dims padding_r,
1727 const padding_kind apadding_kind) {
1728 memory::validate_dims(strides);
1729 memory::validate_dims(dilates);
1730 memory::validate_dims(padding_l);
1731 memory::validate_dims(padding_r);
1732 error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
1733 &data, convert_to_c(aalgorithm), &src_desc.data,
1734 &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1735 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1736 mkldnn::convert_to_c(apadding_kind)),
1737 "could not create a convolution backward weights descriptor");
1742 struct primitive_desc : public mkldnn::primitive_desc {
1743 primitive_desc(const desc &desc, const engine &e,
1744 const convolution_forward::primitive_desc &hint_fwd_pd)
1745 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1747 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1748 const convolution_forward::primitive_desc &hint_fwd_pd)
1749 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1751 REG_QUERY_MPD(src, src, 0);
1752 REG_QUERY_MPD(diff_weights, diff_weights, 0);
1753 REG_QUERY_MPD(diff_bias, diff_weights, 1);
1754 REG_QUERY_MPD(diff_dst, diff_dst, 0);
1757 convolution_backward_weights(const primitive_desc &aprimitive_desc,
1758 const primitive::at &src, const primitive::at &diff_dst,
1759 const memory &diff_weights, const memory &diff_bias) {
1760 mkldnn_primitive_t result;
1761 mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1762 const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1764 check_num_parameters(aprimitive_desc.get(), 2, 2,
1765 "convolution backward weights");
1766 error::wrap_c_api(mkldnn_primitive_create(&result,
1767 aprimitive_desc.get(), inputs, outputs),
1768 "could not create a convolution backward weights primitive");
1771 convolution_backward_weights(const primitive_desc &aprimitive_desc,
1772 const primitive::at &src, const primitive::at &diff_dst,
1773 const memory &diff_weights) {
1774 mkldnn_primitive_t result;
1775 mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1776 const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1777 check_num_parameters(aprimitive_desc.get(), 2, 1,
1778 "convolution backward weights");
1779 error::wrap_c_api(mkldnn_primitive_create(&result,
1780 aprimitive_desc.get(), inputs, outputs),
1781 "could not create a convolution backward weights primitive");
1788 /// @addtogroup cpp_api_deconvolution Deconvolution
1789 /// A primitive to compute deconvolution using different algorithms.
1791 /// @sa @ref c_api_deconvolution in @ref c_api
1794 struct deconvolution_forward: public primitive {
1796 mkldnn_deconvolution_desc_t data;
1797 desc(prop_kind aprop_kind, algorithm aalgorithm,
1798 const memory::desc &src_desc,
1799 const memory::desc &weights_desc,
1800 const memory::desc &bias_desc,
1801 const memory::desc &dst_desc,
1802 const memory::dims strides,
1803 const memory::dims padding_l,
1804 const memory::dims padding_r,
1805 const padding_kind apadding_kind) {
1806 memory::validate_dims(strides);
1807 memory::validate_dims(padding_l);
1808 memory::validate_dims(padding_r);
1809 error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
1810 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1811 &src_desc.data, &weights_desc.data, &bias_desc.data,
1812 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1813 mkldnn::convert_to_c(apadding_kind)),
1814 "could not create a deconvolution forward descriptor");
1816 desc(prop_kind aprop_kind, algorithm aalgorithm,
1817 const memory::desc &src_desc,
1818 const memory::desc &weights_desc,
1819 const memory::desc &dst_desc,
1820 const memory::dims strides,
1821 const memory::dims padding_l,
1822 const memory::dims padding_r,
1823 const padding_kind apadding_kind) {
1824 memory::validate_dims(strides);
1825 memory::validate_dims(padding_l);
1826 memory::validate_dims(padding_r);
1827 error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
1828 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1829 &src_desc.data, &weights_desc.data, nullptr,
1830 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1831 mkldnn::convert_to_c(apadding_kind)),
1832 "could not create a deconvolution forward descriptor");
1834 desc(prop_kind aprop_kind, algorithm aalgorithm,
1835 const memory::desc &src_desc,
1836 const memory::desc &weights_desc,
1837 const memory::desc &bias_desc,
1838 const memory::desc &dst_desc,
1839 const memory::dims strides,
1840 const memory::dims dilates,
1841 const memory::dims padding_l,
1842 const memory::dims padding_r,
1843 const padding_kind apadding_kind) {
1844 memory::validate_dims(strides);
1845 memory::validate_dims(dilates);
1846 memory::validate_dims(padding_l);
1847 memory::validate_dims(padding_r);
1848 error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
1849 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1850 &src_desc.data, &weights_desc.data, &bias_desc.data,
1851 &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1852 &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1853 "could not create a dilated deconvolution forward descriptor");
1855 desc(prop_kind aprop_kind, algorithm aalgorithm,
1856 const memory::desc &src_desc,
1857 const memory::desc &weights_desc,
1858 const memory::desc &dst_desc,
1859 const memory::dims strides,
1860 const memory::dims dilates,
1861 const memory::dims padding_l,
1862 const memory::dims padding_r,
1863 const padding_kind apadding_kind) {
1864 memory::validate_dims(strides);
1865 memory::validate_dims(dilates);
1866 memory::validate_dims(padding_l);
1867 memory::validate_dims(padding_r);
1868 error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
1869 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1870 &src_desc.data, &weights_desc.data, nullptr,
1871 &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1872 &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1873 "could not create a dilated deconvolution forward descriptor");
1877 struct primitive_desc : public mkldnn::primitive_desc {
1878 primitive_desc(const desc &desc, const engine &e)
1879 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1881 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1882 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1884 REG_QUERY_MPD(src, src, 0);
1885 REG_QUERY_MPD(weights, weights, 0);
1886 REG_QUERY_MPD(bias, weights, 1);
1887 REG_QUERY_MPD(dst, dst, 0);
1890 deconvolution_forward(const primitive_desc &aprimitive_desc,
1891 const primitive::at &src, const primitive::at &weights,
1892 const primitive::at &bias, const memory &dst) {
1893 mkldnn_primitive_t result;
1894 mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1896 const_mkldnn_primitive_t outputs[] = { dst.get() };
1897 check_num_parameters(aprimitive_desc.get(), 3, 1,
1898 "deconvolution forward");
1899 error::wrap_c_api(mkldnn_primitive_create(&result,
1900 aprimitive_desc.get(), inputs, outputs),
1901 "could not create a deconvolution forward bias primitive");
1905 deconvolution_forward(const primitive_desc &aprimitive_desc,
1906 const primitive::at &src, const primitive::at &weights,
1907 const memory &dst) {
1908 mkldnn_primitive_t result;
1909 mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1910 const_mkldnn_primitive_t outputs[] = { dst.get() };
1911 check_num_parameters(aprimitive_desc.get(), 2, 1,
1912 "deconvolution forward");
1913 error::wrap_c_api(mkldnn_primitive_create(&result,
1914 aprimitive_desc.get(), inputs, outputs),
1915 "could not create a deconvolution forward primitive");
1920 struct deconvolution_backward_data : public primitive {
1922 mkldnn_deconvolution_desc_t data;
1923 desc(algorithm aalgorithm,
1924 const memory::desc &diff_src_desc,
1925 const memory::desc &weights_desc,
1926 const memory::desc &diff_dst_desc,
1927 const memory::dims strides,
1928 const memory::dims padding_l,
1929 const memory::dims padding_r,
1930 const padding_kind apadding_kind) {
1931 memory::validate_dims(strides);
1932 memory::validate_dims(padding_l);
1933 memory::validate_dims(padding_r);
1934 error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init(
1935 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1936 &weights_desc.data, &diff_dst_desc.data,
1937 &strides[0], &padding_l[0], &padding_r[0],
1938 mkldnn::convert_to_c(apadding_kind)),
1939 "could not create a deconvolution backward data descriptor");
1941 desc(algorithm aalgorithm,
1942 const memory::desc &diff_src_desc,
1943 const memory::desc &weights_desc,
1944 const memory::desc &diff_dst_desc,
1945 const memory::dims strides,
1946 const memory::dims dilates,
1947 const memory::dims padding_l,
1948 const memory::dims padding_r,
1949 const padding_kind apadding_kind) {
1950 memory::validate_dims(strides);
1951 memory::validate_dims(dilates);
1952 memory::validate_dims(padding_l);
1953 memory::validate_dims(padding_r);
1954 error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init(
1955 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1956 &weights_desc.data, &diff_dst_desc.data,
1957 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1958 mkldnn::convert_to_c(apadding_kind)),
1959 "could not create a dilated deconvolution backward data descriptor");
/// Primitive descriptor for deconvolution backward data. The forward
/// primitive descriptor is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const deconvolution_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const deconvolution_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    // Memory-descriptor accessors generated by REG_QUERY_MPD (macro defined
    // elsewhere); arguments are (accessor name, query kind, index).
    REG_QUERY_MPD(diff_src, diff_src, 0);
    REG_QUERY_MPD(weights, weights, 0);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
/// Creates a deconvolution backward data primitive.
///
/// @param aprimitive_desc Primitive descriptor.
/// @param diff_dst Destination gradient input.
/// @param weights Weights input.
/// @param diff_src Source gradient memory (the single output).
deconvolution_backward_data(const primitive_desc &aprimitive_desc,
        const primitive::at &diff_dst, const primitive::at &weights,
        const memory &diff_src) {
    mkldnn_primitive_t result;
    // Input order handed to the C API: diff_dst, then weights.
    mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    check_num_parameters(aprimitive_desc.get(), 2, 1,
            "deconvolution backward data");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a deconvolution backward data primitive");
/// Deconvolution backward propagation with respect to weights and,
/// optionally, bias.
struct deconvolution_backward_weights : public primitive {
1994 mkldnn_deconvolution_desc_t data;
1995 desc(algorithm aalgorithm,
1996 const memory::desc &src_desc,
1997 const memory::desc &diff_weights_desc,
1998 const memory::desc &diff_bias_desc,
1999 const memory::desc &diff_dst_desc,
2000 const memory::dims strides,
2001 const memory::dims padding_l,
2002 const memory::dims padding_r,
2003 const padding_kind apadding_kind) {
2004 memory::validate_dims(strides);
2005 memory::validate_dims(padding_l);
2006 memory::validate_dims(padding_r);
2007 error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
2008 &data, convert_to_c(aalgorithm), &src_desc.data,
2009 &diff_weights_desc.data, &diff_bias_desc.data,
2010 &diff_dst_desc.data,
2011 &strides[0], &padding_l[0], &padding_r[0],
2012 mkldnn::convert_to_c(apadding_kind)),
2013 "could not create a deconvolution backward weights descriptor");
2015 desc(algorithm aalgorithm,
2016 const memory::desc &src_desc,
2017 const memory::desc &diff_weights_desc,
2018 const memory::desc &diff_dst_desc,
2019 const memory::dims strides,
2020 const memory::dims padding_l,
2021 const memory::dims padding_r,
2022 const padding_kind apadding_kind) {
2023 memory::validate_dims(strides);
2024 memory::validate_dims(padding_l);
2025 memory::validate_dims(padding_r);
2026 error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
2027 &data, convert_to_c(aalgorithm), &src_desc.data,
2028 &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
2029 &strides[0], &padding_l[0], &padding_r[0],
2030 mkldnn::convert_to_c(apadding_kind)),
2031 "could not create a deconvolution backward weights descriptor");
2033 desc(algorithm aalgorithm,
2034 const memory::desc &src_desc,
2035 const memory::desc &diff_weights_desc,
2036 const memory::desc &diff_bias_desc,
2037 const memory::desc &diff_dst_desc,
2038 const memory::dims strides,
2039 const memory::dims dilates,
2040 const memory::dims padding_l,
2041 const memory::dims padding_r,
2042 const padding_kind apadding_kind) {
2043 memory::validate_dims(strides);
2044 memory::validate_dims(dilates);
2045 memory::validate_dims(padding_l);
2046 memory::validate_dims(padding_r);
2047 error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
2048 &data, convert_to_c(aalgorithm), &src_desc.data,
2049 &diff_weights_desc.data, &diff_bias_desc.data,
2050 &diff_dst_desc.data,
2051 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2052 mkldnn::convert_to_c(apadding_kind)),
2053 "could not create a dilated deconvolution backward weights descriptor");
2055 desc(algorithm aalgorithm,
2056 const memory::desc &src_desc,
2057 const memory::desc &diff_weights_desc,
2058 const memory::desc &diff_dst_desc,
2059 const memory::dims strides,
2060 const memory::dims dilates,
2061 const memory::dims padding_l,
2062 const memory::dims padding_r,
2063 const padding_kind apadding_kind) {
2064 memory::validate_dims(strides);
2065 memory::validate_dims(dilates);
2066 memory::validate_dims(padding_l);
2067 memory::validate_dims(padding_r);
2068 error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
2069 &data, convert_to_c(aalgorithm), &src_desc.data,
2070 &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
2071 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2072 mkldnn::convert_to_c(apadding_kind)),
2073 "could not create a dilated deconvolution backward weights descriptor");
/// Primitive descriptor for deconvolution backward weights. The forward
/// primitive descriptor is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const deconvolution_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const deconvolution_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(diff_weights, diff_weights, 0);
    // The bias gradient is queried as the second (index 1) diff_weights
    // memory descriptor.
    REG_QUERY_MPD(diff_bias, diff_weights, 1);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
/// Creates a deconvolution backward weights primitive that also computes
/// the bias gradient.
///
/// @param aprimitive_desc Primitive descriptor.
/// @param src Source input.
/// @param diff_dst Destination gradient input.
/// @param diff_weights Weights gradient memory (first output).
/// @param diff_bias Bias gradient memory.
deconvolution_backward_weights(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &diff_dst,
        const memory &diff_weights, const memory &diff_bias) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
// Arity passed for the check: 2 inputs, 2 outputs.
check_num_parameters(aprimitive_desc.get(), 2, 2,
        "deconvolution backward weights");
error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a deconvolution backward weights primitive");
/// Creates a deconvolution backward weights primitive without a bias
/// gradient output.
///
/// @param aprimitive_desc Primitive descriptor.
/// @param src Source input.
/// @param diff_dst Destination gradient input.
/// @param diff_weights Weights gradient memory (the single output).
deconvolution_backward_weights(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &diff_dst,
        const memory &diff_weights) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
    check_num_parameters(aprimitive_desc.get(), 2, 1,
            "deconvolution backward weights");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a deconvolution backward weights primitive");
2123 /// @addtogroup cpp_api_roi_pooling ROIPooling
/// ROI pooling forward primitive; takes a variable number of inputs.
struct roi_pooling_forward : public primitive {
2128 mkldnn_roi_pooling_desc_t data;
2129 std::vector<mkldnn_memory_desc_t> c_api_inputs;
2131 desc(prop_kind aprop_kind, algorithm aalgorithm, std::vector<memory::desc> inputs,
2132 const memory::desc &dst_desc, int pooled_h, int pooled_w, double spatial_scale) {
2134 for(size_t i = 0; i < inputs.size(); i++) {
2135 c_api_inputs.push_back(inputs[i].data);
2138 error::wrap_c_api(mkldnn_roi_pooling_forward_desc_init(&data,
2139 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &c_api_inputs[0],
2140 c_api_inputs.size(),
2141 &dst_desc.data, pooled_h, pooled_w, spatial_scale),
2142 "could not create a roi pooling forward descriptor");
/// Primitive descriptor for ROI pooling forward. Unlike the other
/// primitive descriptors in this section, it derives from handle directly
/// and has no primitive-attribute overload.
struct primitive_desc : public handle<mkldnn_primitive_desc_t>{
    primitive_desc(const desc &adesc, const engine &aengine) {
        mkldnn_primitive_desc_t result;
        // No hint primitive descriptor: the last argument is nullptr.
        error::wrap_c_api(mkldnn_primitive_desc_create(
                &result, &adesc.data, aengine.get(), nullptr),
                "could not create a roi pooling forward primitive descriptor");
2156 roi_pooling_forward(const primitive_desc &aprimitive_desc,
2157 std::vector<primitive::at> &inputs, const memory &dst) {
2158 mkldnn_primitive_t result;
2160 std::vector<mkldnn_primitive_at_t> p_inputs;
2161 for (size_t i = 0; i < inputs.size(); i++) {
2162 p_inputs.push_back(inputs[i].data);
2165 const_mkldnn_primitive_t outputs[] = { dst.get() };
2166 error::wrap_c_api(mkldnn_primitive_create(&result,
2167 aprimitive_desc.get(), &p_inputs[0], outputs),
2168 "could not create a roi pooling forward primitive");
2175 /// @addtogroup cpp_api_lrn LRN
2176 /// A primitive to perform local response normalization (LRN) across or within
2179 /// @sa @ref c_api_lrn in @ref c_api
/// Local response normalization (LRN) forward primitive.
struct lrn_forward : public primitive {
/// Underlying C descriptor.
mkldnn_lrn_desc_t data;
/// Initializes an LRN forward descriptor; local_size, alpha, beta and k
/// are passed through to mkldnn_lrn_forward_desc_init unchanged.
desc(prop_kind aprop_kind, algorithm aalgorithm,
        const memory::desc &src_desc,
        int local_size, float alpha, float beta, float k)
// Fill `data` through the C API; failures go through error::wrap_c_api.
error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
            mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
            &src_desc.data, local_size, alpha, beta, k),
        "could not create a lrn forward descriptor");
/// Initializes an LRN forward descriptor with k fixed to 1.0.
desc(prop_kind aprop_kind, algorithm aalgorithm,
        const memory::desc &src_desc,
        int local_size, float alpha, float beta)
// Same C call as the overload taking k, with k = 1.0.
error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
            mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
            &src_desc.data, local_size, alpha, beta, float(1.0)),
        "could not create a lrn forward descriptor");
/// Primitive descriptor for LRN forward (no hint descriptor).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    // Accessors generated by REG_QUERY_MPD (macro defined elsewhere).
    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(dst, dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);
/// Creates an LRN forward primitive that also writes a workspace.
///
/// @note workspace precedes dst in the parameter list, but dst is placed
///     first in the outputs array.
lrn_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &workspace,
        const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get(),
// Arity passed for the check: 1 input, 2 outputs.
check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a lrn forward primitive");
/// Creates an LRN forward primitive without a workspace output.
lrn_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    // Arity passed for the check: 1 input, 1 output.
    check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a lrn forward primitive");
/// Local response normalization (LRN) backward primitive.
struct lrn_backward : public primitive {
/// Underlying C descriptor.
mkldnn_lrn_desc_t data;
/// Initializes an LRN backward descriptor.
desc(algorithm aalgorithm,
        const memory::desc &data_desc,
        const memory::desc &diff_data_desc,
        int local_size, float alpha, float beta, float k)
// Note the argument order: the C call receives diff_data_desc before
// data_desc, the reverse of this constructor's parameter order.
error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
            convert_to_c(aalgorithm), &diff_data_desc.data,
            &data_desc.data, local_size, alpha, beta, k),
        "could not create a lrn backward descriptor");
/// Initializes an LRN backward descriptor with k fixed to 1.0.
desc(algorithm aalgorithm,
        const memory::desc &data_desc,
        const memory::desc &diff_data_desc,
        int local_size, float alpha, float beta)
// diff_data_desc is passed before data_desc; k is fixed to 1.0.
error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
            convert_to_c(aalgorithm), &diff_data_desc.data,
            &data_desc.data, local_size, alpha, beta, float(1.0)),
        "could not create a lrn backward descriptor");
/// Primitive descriptor for LRN backward. The forward primitive descriptor
/// is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const lrn_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const lrn_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    // Accessors generated by REG_QUERY_MPD (macro defined elsewhere).
    REG_QUERY_MPD(diff_src, diff_src, 0);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);
/// Creates an LRN backward primitive that reads a workspace as an extra
/// input.
lrn_backward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &diff_dst,
        const primitive::at &workspace, const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
const_mkldnn_primitive_t outputs[] = { diff_src.get() };
// Arity passed for the check: 3 inputs, 1 output.
check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
error::wrap_c_api(mkldnn_primitive_create(&result,
            aprimitive_desc.get(), inputs, outputs),
        "could not create a lrn backward primitive");
/// Creates an LRN backward primitive without a workspace input.
lrn_backward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &diff_dst,
        const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    // Arity passed for the check: 2 inputs, 1 output.
    check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a lrn backward primitive");
2313 /// @addtogroup cpp_api_pooling Pooling
2314 /// A primitive to perform max or average pooling.
2316 /// @sa @ref c_api_pooling in @ref c_api
/// Pooling forward primitive.
struct pooling_forward : public primitive {
2321 mkldnn_pooling_desc_t data;
2322 desc(prop_kind aprop_kind, algorithm aalgorithm,
2323 const memory::desc &src_desc,
2324 const memory::desc &dst_desc,
2325 const memory::dims strides,
2326 const memory::dims kernel,
2327 const memory::dims padding_l,
2328 const memory::dims padding_r,
2329 const padding_kind apadding_kind) {
2330 memory::validate_dims(strides);
2331 memory::validate_dims(kernel);
2332 memory::validate_dims(padding_l);
2333 memory::validate_dims(padding_r);
2334 error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data,
2335 mkldnn::convert_to_c(aprop_kind),
2336 convert_to_c(aalgorithm),
2337 &src_desc.data, &dst_desc.data,
2338 &strides[0], &kernel[0],
2339 &padding_l[0], &padding_r[0],
2340 mkldnn::convert_to_c(apadding_kind)),
2341 "could not init a forward pooling descriptor");
/// Primitive descriptor for pooling forward (no hint descriptor).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    // Accessors generated by REG_QUERY_MPD (macro defined elsewhere).
    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(dst, dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);
/// Creates a pooling forward primitive without a workspace output.
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
        const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    // The second output slot (used for the workspace by the other
    // overload) is left null.
    const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
    // Arity passed for the check: 1 input, 1 output.
    check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a pooling forward primitive");
/// Creates a pooling forward primitive that also writes a workspace.
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
        const memory &dst, const memory &workspace) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
    // Arity passed for the check: 1 input, 2 outputs.
    check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a pooling forward primitive");
/// Pooling backward primitive.
struct pooling_backward : public primitive {
/// Underlying C descriptor.
mkldnn_pooling_desc_t data;
/// Initializes a pooling backward descriptor. The dims arguments are taken
/// by const reference and checked with memory::validate_dims before use.
desc(algorithm aalgorithm,
        const memory::desc &diff_src_desc,
        const memory::desc &diff_dst_desc,
        const memory::dims &strides,
        const memory::dims &kernel,
        const memory::dims &padding_l,
        const memory::dims &padding_r,
        const padding_kind apadding_kind) {
    memory::validate_dims(strides);
    memory::validate_dims(kernel);
    memory::validate_dims(padding_l);
    memory::validate_dims(padding_r);
    error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data,
                convert_to_c(aalgorithm),
                &diff_src_desc.data, &diff_dst_desc.data,
                &strides[0], &kernel[0],
                &padding_l[0], &padding_r[0],
                mkldnn::convert_to_c(apadding_kind)),
            "could not init a backward pooling descriptor");
/// Primitive descriptor for pooling backward. The forward primitive
/// descriptor is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const pooling_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const pooling_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    // Accessors generated by REG_QUERY_MPD (macro defined elsewhere).
    REG_QUERY_MPD(diff_src, diff_src, 0);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);
/// Creates a pooling backward primitive without a workspace input.
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
        const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    // Arity passed for the check: 1 input, 1 output.
    check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a pooling backward primitive");
/// Creates a pooling backward primitive that reads a workspace as its
/// second input.
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
        const primitive::at &workspace, const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    // Arity passed for the check: 2 inputs, 1 output.
    check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a pooling backward primitive");
2448 /// @addtogroup cpp_api_eltwise Eltwise
2449 /// A primitive to compute element-wise operations like parametric rectifier
2450 /// linear unit (ReLU).
2452 /// @sa @ref c_api_eltwise in @ref c_api
/// Element-wise operation forward primitive.
struct eltwise_forward : public primitive {
2457 mkldnn_eltwise_desc_t data;
2458 template <typename T>
2459 desc(prop_kind aprop_kind, algorithm alg_kind,
2460 const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2461 error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data,
2462 mkldnn::convert_to_c(aprop_kind),
2463 mkldnn::convert_to_c(alg_kind), &src_desc.data,
2464 static_cast<float>(alpha), static_cast<float>(beta)),
2465 "could not create a eltwise forward descriptor");
/// Primitive descriptor for eltwise forward (no hint descriptor).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    // Accessors generated by REG_QUERY_MPD (macro defined elsewhere).
    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(dst, dst, 0);
/// Creates an eltwise forward primitive (1 input, 1 output).
eltwise_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a eltwise forward primitive");
/// Element-wise operation backward primitive.
struct eltwise_backward : public primitive {
2495 mkldnn_eltwise_desc_t data;
2497 template <typename T>
2498 desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2499 const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2500 error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data,
2501 mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2502 &data_desc.data, static_cast<float>(alpha),
2503 static_cast<float>(beta)),
2504 "could not create a eltwise backward descriptor");
/// Primitive descriptor for eltwise backward. The forward primitive
/// descriptor is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const eltwise_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const eltwise_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    // Accessors generated by REG_QUERY_MPD (macro defined elsewhere).
    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(diff_src, diff_src, 0);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
/// Creates an eltwise backward primitive (inputs: src, diff_dst; output:
/// diff_src).
eltwise_backward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &diff_dst,
        const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a eltwise backward primitive");
2538 /// @addtogroup cpp_api_depthwise Depthwise
/// Depthwise forward primitive (src and weights, with an optional bias).
struct depthwise_forward : public primitive {
/// Underlying C descriptor.
mkldnn_depthwise_desc_t data;

/// Initializes a depthwise forward descriptor with a bias. The C call
/// receives src, dst, weights and bias descriptors in that order.
desc(prop_kind aprop_kind, algorithm alg_kind,
        const memory::desc &src_desc, const memory::desc &dst_desc, const memory::desc &weights_desc,
        const memory::desc &bias_desc) {
    error::wrap_c_api(mkldnn_depthwise_forward_desc_init(&data,
                mkldnn::convert_to_c(aprop_kind),
                mkldnn::convert_to_c(alg_kind),
                &src_desc.data, &dst_desc.data,
                &weights_desc.data, &bias_desc.data),
            "could not create a depthwise forward descriptor");
/// Initializes a depthwise forward descriptor without a bias (a null bias
/// descriptor is passed to the C API).
desc(prop_kind aprop_kind, algorithm alg_kind,
        const memory::desc &src_desc, const memory::desc &dst_desc, const memory::desc &weights_desc) {
    error::wrap_c_api(mkldnn_depthwise_forward_desc_init(&data,
                mkldnn::convert_to_c(aprop_kind),
                mkldnn::convert_to_c(alg_kind),
                &src_desc.data, &dst_desc.data,
                &weights_desc.data, nullptr),
            "could not create a depthwise forward descriptor");
/// Primitive descriptor for depthwise forward (no hint descriptor).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    // Accessors generated by REG_QUERY_MPD (macro defined elsewhere).
    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(dst, dst, 0);
/// Creates a depthwise forward primitive with a bias input.
/// @note Unlike most constructors in this header, no check_num_parameters
///     call is made before primitive creation.
depthwise_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &weights,
        const primitive::at &bias, const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, weights.data, bias.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a depthwise forward primitive");
/// Creates a depthwise forward primitive without a bias input.
/// @note No check_num_parameters call is made before primitive creation.
depthwise_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &weights,
        const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a depthwise forward primitive");
2605 /// @addtogroup cpp_api_softmax Softmax
2606 /// A primitive to perform softmax.
2608 /// @sa @ref c_api_softmax in @ref c_api
/// Softmax forward primitive.
struct softmax_forward : public primitive {
/// Underlying C descriptor.
mkldnn_softmax_desc_t data;
/// Initializes a softmax forward descriptor.
desc(prop_kind aprop_kind, const memory::desc &data_desc,
// Fill `data` through the C API; failures go through error::wrap_c_api.
error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data,
            mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2619 "could not create a softmax forward descriptor");
/// Primitive descriptor for softmax forward (no hint descriptor).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    // Accessors generated by REG_QUERY_MPD (macro defined elsewhere).
    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(dst, dst, 0);
/// Creates a softmax forward primitive (1 input, 1 output).
softmax_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a softmax forward primitive");
/// Softmax backward primitive.
struct softmax_backward : public primitive {
/// Underlying C descriptor.
mkldnn_softmax_desc_t data;
/// Initializes a softmax backward descriptor.
desc(const memory::desc &diff_desc, const memory::desc &data_desc,
// The gradient descriptor is passed before the data descriptor.
error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data,
            &diff_desc.data, &data_desc.data, softmax_axis),
        "could not init a backward softmax descriptor");
/// Primitive descriptor for softmax backward. The forward primitive
/// descriptor is passed to the base class as a hint.
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e,
            const softmax_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
            const softmax_forward::primitive_desc &hint_fwd_pd)
        : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}

    // Accessors generated by REG_QUERY_MPD (macro defined elsewhere).
    REG_QUERY_MPD(dst, dst, 0);
    REG_QUERY_MPD(diff_src, diff_src, 0);
    REG_QUERY_MPD(diff_dst, diff_dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);
/// Creates a softmax backward primitive (inputs: dst, diff_dst; output:
/// diff_src).
/// @note Unlike softmax_forward, no check_num_parameters call is made here.
softmax_backward(const primitive_desc &aprimitive_desc,
        const primitive::at &dst, const primitive::at &diff_dst,
        const memory &diff_src) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
    const_mkldnn_primitive_t outputs[] = { diff_src.get() };
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a softmax backward primitive");
2688 /// @addtogroup cpp_api_batch_norm Batch normalization
2689 /// A primitive to perform batch normalization.
2691 /// @sa @ref c_api_batch_normalization in @ref c_api
/// Batch normalization forward primitive.
struct batch_normalization_forward : public primitive {
/// Underlying C descriptor.
mkldnn_batch_normalization_desc_t data;
/// Initializes a batch normalization forward descriptor; @p epsilon may be
/// of any arithmetic type T and is converted to float for the C API.
template <typename T>
desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
// epsilon is narrowed to float here; flags are passed through unchanged.
mkldnn_batch_normalization_forward_desc_init(&data,
        mkldnn::convert_to_c(aprop_kind), &src_desc.data,
        static_cast<float>(epsilon), flags),
        "could not create a batch normalization forward descriptor");
/// Primitive descriptor for batch normalization forward (no hint).
struct primitive_desc : public mkldnn::primitive_desc {
    primitive_desc(const desc &desc, const engine &e)
        : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

    /// Same as above, with explicit primitive attributes.
    primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
        : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

    // Accessors generated by REG_QUERY_MPD (macro defined elsewhere).
    REG_QUERY_MPD(src, src, 0);
    REG_QUERY_MPD(weights, weights, 0);
    REG_QUERY_MPD(dst, dst, 0);
    REG_QUERY_MPD(workspace, workspace, 0);

    /// Memory primitive descriptor of the mean statistic.
    memory::primitive_desc mean_primitive_desc() const
    { return stat_primitive_desc(mean); }
    /// Memory primitive descriptor of the variance statistic.
    memory::primitive_desc variance_primitive_desc() const
    { return stat_primitive_desc(var); }
2726 enum { mean = 1, var = 2, };
2727 memory::primitive_desc stat_primitive_desc(int kind) const {
2728 mkldnn_batch_normalization_desc_t *p;
2729 error::wrap_c_api(mkldnn_primitive_desc_query(
2730 get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
2731 "could not get a batch-normalization descriptor");
2732 return query_mpd(p->flags & use_global_stats ? src_pd : dst_pd, kind);
/// Creates a batch normalization forward primitive whose inputs are src,
/// mean, variance and weights, with dst as the only output (4 in, 1 out).
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &mean,
        const primitive::at &variance, const primitive::at &weights,
        const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data,
        mean.data, variance.data, weights.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    check_num_parameters(aprimitive_desc.get(), 4, 1,
            "batch normalization forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a batch normalization forward primitive");
/// Creates a batch normalization forward primitive whose inputs are src,
/// mean and variance, without weights (3 in, 1 out).
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &mean,
        const primitive::at &variance, const memory &dst) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data,
        mean.data, variance.data };
    const_mkldnn_primitive_t outputs[] = { dst.get() };
    check_num_parameters(aprimitive_desc.get(), 3, 1,
            "batch normalization forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a batch normalization forward primitive");
/// Creates a batch normalization forward primitive that takes src and
/// weights and writes dst, mean and variance (2 in, 3 out).
///
/// @warning batch_normalization_forward has two constructors with very
///     similar signatures:
///      - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
///      - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
///     The only way to distinguish between them is to explicitly
///     cast all input parameters to their type; that is, to
///     const primitive::at &.
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &weights,
        const memory &dst, const memory &mean, const memory &variance) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
    const_mkldnn_primitive_t outputs[] = { dst.get(),
        mean.get(), variance.get() };
    check_num_parameters(aprimitive_desc.get(), 2, 3,
            "batch normalization forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a batch normalization forward primitive");
/// Creates a batch normalization forward primitive that takes src and
/// weights and writes dst, mean, variance and workspace (2 in, 4 out).
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const primitive::at &weights,
        const memory &dst, const memory &mean, const memory &variance,
        const memory &workspace) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
    const_mkldnn_primitive_t outputs[] = { dst.get(),
        mean.get(), variance.get(), workspace.get() };
    check_num_parameters(aprimitive_desc.get(), 2, 4,
            "batch normalization forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a batch normalization forward primitive");
/// Creates a batch normalization forward primitive that takes only src and
/// writes dst, mean and variance (1 in, 3 out).
batch_normalization_forward(const primitive_desc &aprimitive_desc,
        const primitive::at &src, const memory &dst, const memory &mean,
        const memory &variance) {
    mkldnn_primitive_t result;
    mkldnn_primitive_at_t inputs[] = { src.data };
    const_mkldnn_primitive_t outputs[] = { dst.get(),
        mean.get(), variance.get() };
    check_num_parameters(aprimitive_desc.get(), 1, 3,
            "batch normalization forward");
    error::wrap_c_api(mkldnn_primitive_create(&result,
                aprimitive_desc.get(), inputs, outputs),
            "could not create a batch normalization forward primitive");
2820 /// @warning batch_normalization_forward has two constructors with very
2821 /// similar signatures:
2822 /// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
2823 /// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
2824 /// The only way to distinguish between them is to explicitly
2825 /// cast all input parameters to their type; that is, to
2826 /// const primitive:at &.
2827 /// @note To make users' experience a little better, this constructor
2828 /// checks whether parameters match the corresponding primitive
2829 /// descriptor, and if not, calls the other (proper) constructor.
2830 batch_normalization_forward(const primitive_desc &aprimitive_desc,
2831 const primitive::at &src, const memory &dst, const memory &mean,
2832 const memory &variance, const memory &workspace) {
2833 mkldnn_primitive_t result;
2834 mkldnn_primitive_at_t inputs[2] = { src.data };
2835 const_mkldnn_primitive_t outputs[4] = { dst.get(),
2836 mean.get(), variance.get(), workspace.get() };
2838 if (1) { // check whether this is the `wrong` constructor
2839 const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
2840 aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
2841 const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
2842 aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
2843 if (n_inputs_expected == 2 && n_outputs_expected == 3) {
2844 // shift parameters, get rid of workspace, and add weights...
2845 auto _weights = dst;
2846 inputs[1] = {_weights.get(), 0};
2848 auto _dst = mean, _mean = variance, _variance = workspace;
2849 outputs[0] = _dst.get();
2850 outputs[1] = _mean.get();
2851 outputs[2] = _variance.get();
2852 outputs[3] = nullptr;
2855 error::wrap_c_api(mkldnn_primitive_create(&result,
2856 aprimitive_desc.get(), inputs, outputs),
2857 "could not create a batch normalization forward primitive");
2861 batch_normalization_forward(const primitive_desc &aprimitive_desc,
2862 const primitive::at &src, const primitive::at &weights,
2863 const memory &dst) {
2864 mkldnn_primitive_t result;
2865 mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2866 const_mkldnn_primitive_t outputs[] = { dst.get() };
2867 check_num_parameters(aprimitive_desc.get(), 2, 1,
2868 "batch normalization forward");
2869 error::wrap_c_api(mkldnn_primitive_create(&result,
2870 aprimitive_desc.get(), inputs, outputs),
2871 "could not create a batch normalization forward primitive");
2875 batch_normalization_forward(const primitive_desc &aprimitive_desc,
2876 const primitive::at &src, const memory &dst) {
2877 mkldnn_primitive_t result;
2878 mkldnn_primitive_at_t inputs[] = { src.data };
2879 const_mkldnn_primitive_t outputs[] = { dst.get() };
2880 check_num_parameters(aprimitive_desc.get(), 1, 1,
2881 "batch normalization forward");
2882 error::wrap_c_api(mkldnn_primitive_create(&result,
2883 aprimitive_desc.get(), inputs, outputs),
2884 "could not create a batch normalization forward primitive");
2889 struct batch_normalization_backward : public primitive {
2891 mkldnn_batch_normalization_desc_t data;
2892 template <typename T>
2893 desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2894 const memory::desc &data_desc, T epsilon, unsigned flags) {
2896 mkldnn_batch_normalization_backward_desc_init(&data,
2897 mkldnn::convert_to_c(aprop_kind),
2898 &diff_data_desc.data, &data_desc.data,
2899 static_cast<float>(epsilon), flags),
2900 "could not create a batch normalization backward descriptor");
2904 struct primitive_desc : public mkldnn::primitive_desc {
2905 primitive_desc(const desc &desc, const engine &e,
2906 const batch_normalization_forward::primitive_desc &hint_fwd_pd)
2907 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2909 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2910 const batch_normalization_forward::primitive_desc &hint_fwd_pd)
2911 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2913 REG_QUERY_MPD(src, src, 0);
2914 REG_QUERY_MPD(mean, src, 1);
2915 REG_QUERY_MPD(variance, src, 2);
2916 REG_QUERY_MPD(weights, weights, 0);
2917 REG_QUERY_MPD(dst, dst, 0);
2918 REG_QUERY_MPD(diff_dst, diff_dst, 0);
2919 REG_QUERY_MPD(workspace, workspace, 0);
2921 REG_QUERY_MPD(diff_src, diff_src, 0);
2922 REG_QUERY_MPD(diff_weights, diff_weights, 0);
2925 // Prop_kind == backward
2926 batch_normalization_backward(const primitive_desc &aprimitive_desc,
2927 const primitive::at &src, const primitive::at &mean,
2928 const primitive::at &variance, const primitive::at &diff_dst,
2929 const primitive::at &weights, const memory &diff_src,
2930 const memory &diff_weights) {
2931 mkldnn_primitive_t result;
2932 mkldnn_primitive_at_t inputs[] = { src.data,
2933 mean.data, variance.data, diff_dst.data, weights.data };
2934 const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2935 diff_weights.get() };
2936 check_num_parameters(aprimitive_desc.get(), 5, 2,
2937 "batch normalization backward");
2938 error::wrap_c_api(mkldnn_primitive_create(&result,
2939 aprimitive_desc.get(), inputs, outputs),
2940 "could not create a batch normalization backward primitive");
2944 // Prop_kind == backward (+ws)
2945 batch_normalization_backward(const primitive_desc &aprimitive_desc,
2946 const primitive::at &src, const primitive::at &mean,
2947 const primitive::at &variance, const primitive::at &diff_dst,
2948 const primitive::at &weights, const primitive::at &workspace,
2949 const memory &diff_src, const memory &diff_weights) {
2950 mkldnn_primitive_t result;
2951 mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2952 diff_dst.data, weights.data, workspace.data };
2953 const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2954 diff_weights.get() };
2955 check_num_parameters(aprimitive_desc.get(), 6, 2,
2956 "batch normalization backward");
2957 error::wrap_c_api(mkldnn_primitive_create(&result,
2958 aprimitive_desc.get(), inputs, outputs),
2959 "could not create a batch normalization backward primitive");
2963 // Prop_kind == backward_data (+ws or +weights)
2964 /// @warning This constructor works for backward_data propagation
2965 /// - w/ weights but w/o workspace, or
2966 /// - w/ workspace but w/o weights
2967 batch_normalization_backward(const primitive_desc &aprimitive_desc,
2968 const primitive::at &src, const primitive::at &mean,
2969 const primitive::at &variance,const primitive::at &diff_dst,
2970 const primitive::at &weights_or_workspace, const memory &diff_src) {
2971 mkldnn_primitive_t result;
2972 mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2973 diff_dst.data, weights_or_workspace.data };
2974 const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2975 check_num_parameters(aprimitive_desc.get(), 5, 1,
2976 "batch normalization backward");
2977 error::wrap_c_api(mkldnn_primitive_create(&result,
2978 aprimitive_desc.get(), inputs, outputs),
2979 "could not create a batch normalization backward primitive");
2983 // Prop_kind == backward_data
2984 batch_normalization_backward(const primitive_desc &aprimitive_desc,
2985 const primitive::at &src, const primitive::at &mean,
2986 const primitive::at &variance, const primitive::at &diff_dst,
2987 const memory &diff_src) {
2988 mkldnn_primitive_t result;
2989 mkldnn_primitive_at_t inputs[] = { src.data,
2990 mean.data, variance.data, diff_dst.data };
2991 const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2992 check_num_parameters(aprimitive_desc.get(), 4, 1,
2993 "batch normalization backward");
2994 error::wrap_c_api(mkldnn_primitive_create(&result,
2995 aprimitive_desc.get(), inputs, outputs),
2996 "could not create a batch normalization backward primitive");
3003 /// @addtogroup cpp_api_inner_product Inner Product
3004 /// A primitive to compute an inner product.
3006 /// @sa @ref c_api_inner_product in @ref c_api
3009 struct inner_product_forward: public primitive {
3011 mkldnn_inner_product_desc_t data;
3012 desc(prop_kind aprop_kind, const memory::desc &src_desc,
3013 const memory::desc &weights_desc,
3014 const memory::desc &bias_desc,
3015 const memory::desc &dst_desc) {
3017 mkldnn_inner_product_forward_desc_init(&data,
3018 mkldnn::convert_to_c(aprop_kind), &src_desc.data,
3019 &weights_desc.data, &bias_desc.data, &dst_desc.data),
3020 "could not create a inner product forward descriptor");
3023 desc(prop_kind aprop_kind, const memory::desc &src_desc,
3024 const memory::desc &weights_desc,
3025 const memory::desc &dst_desc) {
3027 mkldnn_inner_product_forward_desc_init(&data,
3028 mkldnn::convert_to_c(aprop_kind), &src_desc.data,
3029 &weights_desc.data, nullptr, &dst_desc.data),
3030 "could not create a inner product forward descriptor");
3034 struct primitive_desc : public mkldnn::primitive_desc {
3035 primitive_desc(const desc &desc, const engine &e)
3036 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3038 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3039 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3041 REG_QUERY_MPD(src, src, 0);
3042 REG_QUERY_MPD(weights, weights, 0);
3043 REG_QUERY_MPD(bias, weights, 1);
3044 REG_QUERY_MPD(dst, dst, 0);
3047 inner_product_forward(const primitive_desc &aprimitive_desc,
3048 const primitive::at &src, const primitive::at weights,
3049 const primitive::at &bias, const memory &dst) {
3050 mkldnn_primitive_t result;
3051 mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
3053 const_mkldnn_primitive_t outputs[] = { dst.get() };
3054 check_num_parameters(aprimitive_desc.get(), 3, 1,
3055 "inner product forward");
3056 error::wrap_c_api(mkldnn_primitive_create(&result,
3057 aprimitive_desc.get(), inputs, outputs),
3058 "could not create a inner product forward primitive");
3062 inner_product_forward(const primitive_desc &aprimitive_desc,
3063 const primitive::at &src, const primitive::at weights,
3064 const memory &dst) {
3065 mkldnn_primitive_t result;
3066 mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3067 const_mkldnn_primitive_t outputs[] = { dst.get() };
3068 check_num_parameters(aprimitive_desc.get(), 2, 1,
3069 "inner product forward");
3070 error::wrap_c_api(mkldnn_primitive_create(&result,
3071 aprimitive_desc.get(), inputs, outputs),
3072 "could not create a inner product forward primitive");
3077 struct inner_product_backward_data: public primitive {
3079 mkldnn_inner_product_desc_t data;
3080 desc(const memory::desc &diff_src_desc,
3081 const memory::desc &weights_desc,
3082 const memory::desc &diff_dst_desc) {
3084 mkldnn_inner_product_backward_data_desc_init(&data,
3085 &diff_src_desc.data, &weights_desc.data,
3086 &diff_dst_desc.data),
3087 "could not create a inner product backward data descriptor");
3091 struct primitive_desc : public mkldnn::primitive_desc {
3092 primitive_desc(const desc &desc, const engine &e,
3093 const inner_product_forward::primitive_desc &hint_fwd_pd)
3094 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3096 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3097 const inner_product_forward::primitive_desc &hint_fwd_pd)
3098 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3100 REG_QUERY_MPD(diff_src, diff_src, 0);
3101 REG_QUERY_MPD(weights, weights, 0);
3102 REG_QUERY_MPD(diff_dst, diff_dst, 0);
3105 inner_product_backward_data(const primitive_desc &aprimitive_desc,
3106 const primitive::at &diff_dst, const primitive::at weights,
3107 const memory &diff_src) {
3108 mkldnn_primitive_t result;
3109 mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
3110 const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3111 check_num_parameters(aprimitive_desc.get(), 2, 1,
3112 "inner product backward data");
3113 error::wrap_c_api(mkldnn_primitive_create(&result,
3114 aprimitive_desc.get(), inputs, outputs),
3115 "could not create a inner product backward data primitive");
3120 struct inner_product_backward_weights: public primitive {
3122 mkldnn_inner_product_desc_t data;
3123 desc(const memory::desc &src_desc,
3124 const memory::desc &diff_weights_desc,
3125 const memory::desc &diff_bias_desc,
3126 const memory::desc &diff_dst_desc) {
3128 mkldnn_inner_product_backward_weights_desc_init(
3129 &data, &src_desc.data, &diff_weights_desc.data,
3130 &diff_bias_desc.data, &diff_dst_desc.data),
3131 "could not create a inner product backward weights descriptor");
3133 desc(const memory::desc &src_desc,
3134 const memory::desc &diff_weights_desc,
3135 const memory::desc &diff_dst_desc) {
3137 mkldnn_inner_product_backward_weights_desc_init(
3138 &data, &src_desc.data, &diff_weights_desc.data,
3139 nullptr, &diff_dst_desc.data),
3140 "could not create a inner product backward weights descriptor");
3144 struct primitive_desc : public mkldnn::primitive_desc {
3145 primitive_desc(const desc &desc, const engine &e,
3146 const inner_product_forward::primitive_desc &hint_fwd_pd)
3147 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3149 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3150 const inner_product_forward::primitive_desc &hint_fwd_pd)
3151 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3153 REG_QUERY_MPD(src, src, 0);
3154 REG_QUERY_MPD(diff_weights, diff_weights, 0);
3155 REG_QUERY_MPD(diff_bias, diff_weights, 1);
3156 REG_QUERY_MPD(diff_dst, diff_dst, 0);
3159 inner_product_backward_weights(const primitive_desc &aprimitive_desc,
3160 const primitive::at &src, const primitive::at diff_dst,
3161 const memory &diff_weights) {
3162 mkldnn_primitive_t result;
3163 mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3164 const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
3165 check_num_parameters(aprimitive_desc.get(), 2, 1,
3166 "inner product backward weights");
3167 error::wrap_c_api(mkldnn_primitive_create(&result,
3168 aprimitive_desc.get(), inputs, outputs),
3169 "could not create a inner product backward weights primitive");
3173 inner_product_backward_weights(const primitive_desc &aprimitive_desc,
3174 const primitive::at &src, const primitive::at diff_dst,
3175 const memory &diff_weights, const memory &diff_bias) {
3176 mkldnn_primitive_t result;
3177 mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3178 const_mkldnn_primitive_t outputs[] =
3179 { diff_weights.get(), diff_bias.get()};
3180 check_num_parameters(aprimitive_desc.get(), 2, 2,
3181 "inner product backward weights");
3182 error::wrap_c_api(mkldnn_primitive_create(&result,
3183 aprimitive_desc.get(), inputs, outputs),
3184 "could not create a inner product backward weights primitive");
3191 /// @addtogroup cpp_api_rnn RNN
3192 /// A primitive to compute common recurrent layers.
3194 /// @sa @ref c_api_rnn in @ref c_api
3199 mkldnn_rnn_cell_desc_t c_rnn_cell_;
3201 desc(algorithm kind, algorithm activation_f) {
3202 error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_,
3203 mkldnn::convert_to_c(kind),
3204 mkldnn::convert_to_c(activation_f), 0U, 0, 0),
3205 "could not init an rnn cell descriptor");
3207 desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {}
3209 operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
3211 algorithm get_cell_kind() const
3212 { return algorithm(c_rnn_cell_.cell_kind); }
3213 algorithm get_activation() const
3214 { return algorithm(c_rnn_cell_.activation_kind); }
3216 float get_alpha() const { return c_rnn_cell_.alpha; }
3217 void set_alpha(float alpha) {
3218 c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
3219 c_rnn_cell_.alpha = alpha;
3222 float get_clipping() const { return c_rnn_cell_.clipping; }
3223 void set_clipping(float clipping) {
3224 c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
3225 c_rnn_cell_.clipping = clipping;
3228 int get_gates_count() const {
3229 return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
3231 int get_state_count() const {
3232 return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
3237 struct rnn_forward : public primitive {
3239 mkldnn_rnn_desc_t data;
3240 desc(prop_kind aprop_kind, rnn_cell::desc cell,
3241 const rnn_direction direction,
3242 const memory::desc &src_layer_desc,
3243 const memory::desc &src_iter_desc,
3244 const memory::desc &weights_layer_desc,
3245 const memory::desc &weights_iter_desc,
3246 const memory::desc &bias_desc,
3247 const memory::desc &dst_layer_desc,
3248 const memory::desc &dst_iter_desc
3250 error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data,
3251 mkldnn::convert_to_c(aprop_kind), cell,
3252 mkldnn::convert_to_c(direction),
3253 &src_layer_desc.data, &src_iter_desc.data,
3254 &weights_layer_desc.data, &weights_iter_desc.data,
3256 &dst_layer_desc.data, &dst_iter_desc.data),
3257 "could not create an RNN forward descriptor");
3262 struct primitive_desc : public mkldnn::primitive_desc {
3263 primitive_desc(const desc &desc, const engine &e)
3264 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3266 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3267 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3269 REG_QUERY_MPD(src_layer, src, 0);
3270 REG_QUERY_MPD(src_iter, src, 1);
3271 REG_QUERY_MPD(weights_layer, weights, 0);
3272 REG_QUERY_MPD(weights_iter, weights, 1);
3273 REG_QUERY_MPD(bias, weights, 2);
3274 REG_QUERY_MPD(dst_layer, dst, 0);
3275 REG_QUERY_MPD(dst_iter, dst, 1);
3276 REG_QUERY_MPD(workspace, workspace, 0);
3279 rnn_forward(const primitive_desc &aprimitive_desc,
3280 const primitive::at &src_layer, const primitive::at &src_iter,
3281 const primitive::at &weights_layer,
3282 const primitive::at &weights_iter, const primitive::at &bias,
3283 const memory &dst_layer, const memory &dst_iter,
3284 const memory &workspace) {
3285 mkldnn_primitive_t result;
3286 mkldnn_primitive_at_t inputs[5];
3287 const_mkldnn_primitive_t outputs[3];
3289 inputs[idx++] = src_layer.data;
3290 if (!is_null_memory(src_iter.data.primitive))
3291 inputs[idx++] = src_iter.data;
3292 inputs[idx++] = weights_layer.data;
3293 inputs[idx++] = weights_iter.data;
3294 if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;
3297 outputs[idx++] = dst_layer.get();
3298 if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
3299 if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();
3301 error::wrap_c_api(mkldnn_primitive_create(&result,
3302 aprimitive_desc.get(), inputs, outputs),
3303 "could not create an RNN forward primitive");
3308 struct rnn_backward : public primitive {
3310 mkldnn_rnn_desc_t data;
3311 desc(prop_kind aprop_kind, rnn_cell::desc cell,
3312 const rnn_direction direction,
3313 const memory::desc &src_layer_desc,
3314 const memory::desc &src_iter_desc,
3315 const memory::desc &weights_layer_desc,
3316 const memory::desc &weights_iter_desc,
3317 const memory::desc &bias_desc,
3318 const memory::desc &dst_layer_desc,
3319 const memory::desc &dst_iter_desc,
3320 const memory::desc &diff_src_layer_desc,
3321 const memory::desc &diff_src_iter_desc,
3322 const memory::desc &diff_weights_layer_desc,
3323 const memory::desc &diff_weights_iter_desc,
3324 const memory::desc &diff_bias_desc,
3325 const memory::desc &diff_dst_layer_desc,
3326 const memory::desc &diff_dst_iter_desc) {
3327 error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data,
3328 mkldnn::convert_to_c(aprop_kind), cell,
3329 mkldnn::convert_to_c(direction),
3330 &src_layer_desc.data, &src_iter_desc.data,
3331 &weights_layer_desc.data, &weights_iter_desc.data,
3333 &dst_layer_desc.data, &dst_iter_desc.data,
3334 &diff_src_layer_desc.data, &diff_src_iter_desc.data,
3335 &diff_weights_layer_desc.data,
3336 &diff_weights_iter_desc.data, &diff_bias_desc.data,
3337 &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
3338 "could not create an RNN backward descriptor");
3343 struct primitive_desc : public mkldnn::primitive_desc {
3344 primitive_desc(const desc &desc, const engine &e,
3345 const rnn_forward::primitive_desc &hint_fwd_pd)
3346 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3348 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3349 const rnn_forward::primitive_desc &hint_fwd_pd)
3350 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3352 REG_QUERY_MPD(src_layer, src, 0);
3353 REG_QUERY_MPD(src_iter, src, 1);
3354 REG_QUERY_MPD(weights_layer, weights, 0);
3355 REG_QUERY_MPD(weights_iter, weights, 1);
3356 REG_QUERY_MPD(bias, weights, 2);
3357 REG_QUERY_MPD(dst_layer, dst, 0);
3358 REG_QUERY_MPD(dst_iter, dst, 1);
3359 REG_QUERY_MPD(workspace, workspace, 0);
3361 REG_QUERY_MPD(diff_src_layer, diff_src, 0);
3362 REG_QUERY_MPD(diff_src_iter, diff_src, 1);
3363 REG_QUERY_MPD(diff_weights_layer, diff_weights, 0);
3364 REG_QUERY_MPD(diff_weights_iter, diff_weights, 1);
3365 REG_QUERY_MPD(diff_bias, diff_weights, 2);
3366 REG_QUERY_MPD(diff_dst_layer, diff_dst, 0);
3367 REG_QUERY_MPD(diff_dst_iter, diff_dst, 1);
3370 // With last iteration (with and without input src_iter)
3371 rnn_backward(const primitive_desc &aprimitive_desc,
3372 const primitive::at &src_layer,
3373 const primitive::at &src_iter,
3374 const primitive::at &weights_layer,
3375 const primitive::at &weights_iter,
3376 const primitive::at &bias,
3377 const primitive::at &dst_layer,
3378 const primitive::at &dst_iter,
3379 const memory &diff_src_layer,
3380 const memory &diff_src_iter,
3381 const memory &diff_weights_layer,
3382 const memory &diff_weights_iter,
3383 const memory &diff_bias,
3384 const primitive::at &diff_dst_layer,
3385 const primitive::at &diff_dst_iter,
3386 const primitive::at &workspace) {
3387 mkldnn_primitive_t result;
3388 mkldnn_primitive_at_t inputs[10];
3389 const_mkldnn_primitive_t outputs[5];
3391 inputs[idx++] = src_layer.data;
3392 if (!is_null_memory(src_iter.data.primitive))
3393 inputs[idx++] = src_iter.data;
3394 inputs[idx++] = weights_layer.data;
3395 inputs[idx++] = weights_iter.data;
3396 if (!is_null_memory(bias.data.primitive))
3397 inputs[idx++] = bias.data;
3398 inputs[idx++] = dst_layer.data;
3399 if (!is_null_memory(dst_iter.data.primitive))
3400 inputs[idx++] = dst_iter.data;
3401 inputs[idx++] = diff_dst_layer.data;
3402 if (!is_null_memory(diff_dst_iter.data.primitive))
3403 inputs[idx++] = diff_dst_iter.data;
3404 inputs[idx++] = workspace.data;
3407 outputs[idx++] = diff_src_layer.get();
3408 if (!is_null_memory(diff_src_iter.get()))
3409 outputs[idx++] = diff_src_iter.get();
3410 outputs[idx++] = diff_weights_layer.get();
3411 outputs[idx++] = diff_weights_iter.get();
3412 if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
3413 error::wrap_c_api(mkldnn_primitive_create(&result,
3414 aprimitive_desc.get(), inputs, outputs),
3415 "could not create an RNN backward primitive");
3422 /// @addtogroup cpp_api_shuffle Shuffle
3423 /// A primitive to shuffle data along the axis.
3425 /// @sa @ref c_api_shuffle in @ref c_api
3428 struct shuffle_forward : public primitive {
3430 mkldnn_shuffle_desc_t data;
3431 desc(prop_kind aprop_kind, const memory::desc &data_desc,
3432 int axis, int group_size) {
3433 error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data,
3434 mkldnn::convert_to_c(aprop_kind), &data_desc.data,
3436 "could not create a shuffle forward descriptor");
3440 struct primitive_desc : public mkldnn::primitive_desc {
3441 primitive_desc(const desc &desc, const engine &e)
3442 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3444 REG_QUERY_MPD(src, src, 0);
3445 REG_QUERY_MPD(dst, dst, 0);
3448 shuffle_forward(const primitive_desc &aprimitive_desc,
3449 const primitive::at &src, const memory &dst) {
3450 mkldnn_primitive_t result;
3451 mkldnn_primitive_at_t inputs[] = { src.data };
3452 const_mkldnn_primitive_t outputs[] = { dst.get() };
3453 check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle forward");
3454 error::wrap_c_api(mkldnn_primitive_create(&result,
3455 aprimitive_desc.get(), inputs, outputs),
3456 "could not create a shuffle forward primitive");
3461 struct shuffle_backward : public primitive {
3463 mkldnn_shuffle_desc_t data;
3464 desc(const memory::desc &diff_data_desc, int axis, int group_size) {
3465 error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data,
3466 &diff_data_desc.data, axis, group_size),
3467 "could not create a shuffle backward descriptor");
3471 struct primitive_desc : public mkldnn::primitive_desc {
3472 primitive_desc(const desc &desc, const engine &e,
3473 const shuffle_forward::primitive_desc &hint_fwd_pd)
3474 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3476 REG_QUERY_MPD(diff_src, diff_src, 0);
3477 REG_QUERY_MPD(diff_dst, diff_dst, 0);
3480 shuffle_backward(const primitive_desc &aprimitive_desc,
3481 const primitive::at &diff_dst, const memory &diff_src) {
3482 mkldnn_primitive_t result;
3483 mkldnn_primitive_at_t inputs[] = { diff_dst.data};
3484 const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3485 check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle backward");
3486 error::wrap_c_api(mkldnn_primitive_create(&result,
3487 aprimitive_desc.get(), inputs, outputs),
3488 "could not create a shuffle backward primitive");
3495 /// @addtogroup cpp_api_binary_convolution Binary convolution
3496 /// A primitive to compute binary convolution using different algorithms.
3498 /// @sa @ref c_api_binary_convolution in @ref c_api
3501 struct binary_convolution_forward: public primitive {
3503 mkldnn_binary_convolution_desc_t data;
3504 desc(prop_kind aprop_kind, algorithm aalgorithm,
3505 const memory::desc &src_desc,
3506 const memory::desc &weights_desc,
3507 const memory::desc &dst_desc,
3508 const memory::dims strides,
3509 const memory::dims dilates,
3510 const memory::dims padding_l,
3511 const memory::dims padding_r,
3512 const float pad_value) {
3513 memory::validate_dims(strides);
3514 memory::validate_dims(dilates);
3515 memory::validate_dims(padding_l);
3516 memory::validate_dims(padding_r);
3518 mkldnn_dilated_binary_convolution_forward_desc_init(&data,
3519 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
3520 &src_desc.data, &weights_desc.data, &dst_desc.data,
3521 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
3523 "could not create a dilated binary convolution forward descriptor");
3527 struct primitive_desc : public mkldnn::primitive_desc {
3528 primitive_desc(const desc &desc, const engine &e)
3529 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3531 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3532 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3534 REG_QUERY_MPD(src, src, 0);
3535 REG_QUERY_MPD(weights, weights, 0);
3536 REG_QUERY_MPD(dst, dst, 0);
3539 binary_convolution_forward(const primitive_desc &aprimitive_desc,
3540 const primitive::at &src, const primitive::at &weights, const memory &dst) {
3541 mkldnn_primitive_t result;
3542 mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3543 const_mkldnn_primitive_t outputs[] = { dst.get() };
3544 check_num_parameters(aprimitive_desc.get(), 2, 1,
3545 "binary convolution forward");
3546 error::wrap_c_api(mkldnn_primitive_create(&result,
3547 aprimitive_desc.get(), inputs, outputs),
3548 "could not create a binary convolution forward primitive");
3555 /// @addtogroup cpp_api_binarization Binarization
3558 struct binarization_forward : public primitive {
3560 mkldnn_binarization_desc_t data;
3562 desc(prop_kind aprop_kind, algorithm alg_kind,
3563 const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &output_mask_desc,
3564 const memory::desc &dst_desc) {
3565 error::wrap_c_api(mkldnn_binarization_forward_desc_init(&data,
3566 mkldnn::convert_to_c(aprop_kind),
3567 mkldnn::convert_to_c(alg_kind),
3568 &src_desc.data, &dst_desc.data,
3569 &weights_desc.data, &output_mask_desc.data),
3570 "could not create a binarization forward descriptor");
3574 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
3575 primitive_desc(const desc &adesc, const engine &aengine) {
3576 mkldnn_primitive_desc_t result;
3577 error::wrap_c_api(mkldnn_primitive_desc_create(
3578 &result, &adesc.data, aengine.get(), nullptr),
3579 "could not create a binarization forward primitive descriptor");
3583 engine get_engine() { return engine::query(*this); }
3586 binarization_forward(const primitive_desc &aprimitive_desc,
3587 const primitive::at &src, const primitive::at &weights, const primitive::at &output_mask,
3588 const memory &dst) {
3589 mkldnn_primitive_t result;
3590 mkldnn_primitive_at_t inputs[] = { src.data, weights.data, output_mask.data};
3591 const_mkldnn_primitive_t outputs[] = { dst.get() };
3592 error::wrap_c_api(mkldnn_primitive_create(&result, aprimitive_desc.get(), inputs, outputs),
3593 "could not create a binarization forward primitive");
3600 /// @addtogroup cpp_api_deformable_convolution Deformable convolution
3601 /// A primitive to compute deformable convolution.
3603 /// @sa @ref c_api_deformable_convolution in @ref c_api
3606 struct deformable_convolution_forward: public primitive {
3608 mkldnn_deformable_convolution_desc_t data;
3609 std::vector<mkldnn_memory_desc_t> c_api_inputs;
3611 desc(prop_kind aprop_kind, algorithm aalgorithm, std::vector<memory::desc> inputs,
3612 const memory::desc &weights_desc,
3613 const memory::desc &bias_desc,
3614 const memory::desc &dst_desc,
3615 const memory::dims strides,
3616 const memory::dims dilates,
3617 const memory::dims padding_l,
3618 const memory::dims padding_r,
3619 const padding_kind apadding_kind,
3620 const int deformable_group) {
3622 for (size_t i = 0; i < inputs.size(); i++) {
3623 c_api_inputs.push_back(inputs[i].data);
3626 memory::validate_dims(strides);
3627 memory::validate_dims(dilates);
3628 memory::validate_dims(padding_l);
3629 memory::validate_dims(padding_r);
3630 error::wrap_c_api(mkldnn_deformable_convolution_forward_desc_init(&data,
3631 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
3632 &c_api_inputs[0], c_api_inputs.size(), &weights_desc.data, &bias_desc.data,
3633 &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
3634 mkldnn::convert_to_c(apadding_kind), deformable_group),
3635 "could not create a deformable convolution forward descriptor");
3637 desc(prop_kind aprop_kind, algorithm aalgorithm, std::vector<memory::desc> inputs,
3638 const memory::desc &weights_desc,
3639 const memory::desc &dst_desc,
3640 const memory::dims strides,
3641 const memory::dims dilates,
3642 const memory::dims padding_l,
3643 const memory::dims padding_r,
3644 const padding_kind apadding_kind,
3645 const int deformable_group) {
3647 for (size_t i = 0; i < inputs.size(); i++) {
3648 c_api_inputs.push_back(inputs[i].data);
3651 memory::validate_dims(strides);
3652 memory::validate_dims(dilates);
3653 memory::validate_dims(padding_l);
3654 memory::validate_dims(padding_r);
3655 error::wrap_c_api(mkldnn_deformable_convolution_forward_desc_init(&data,
3656 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
3657 &c_api_inputs[0], c_api_inputs.size(), &weights_desc.data, nullptr,
3658 &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
3659 mkldnn::convert_to_c(apadding_kind), deformable_group),
3660 "could not create a deformable convolution forward descriptor");
3664 struct primitive_desc : public mkldnn::primitive_desc {
3665 primitive_desc(const desc &desc, const engine &e)
3666 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3668 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3669 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3671 REG_QUERY_MPD(src, src, 0);
3672 REG_QUERY_MPD(weights, weights, 0);
3673 REG_QUERY_MPD(bias, weights, 1);
3674 REG_QUERY_MPD(dst, dst, 0);
3677 deformable_convolution_forward(const primitive_desc &aprimitive_desc,
3678 const std::vector<primitive::at> &inputs, const primitive::at &weights,
3679 const primitive::at &bias, const memory &dst) {
3680 mkldnn_primitive_t result;
3682 mkldnn_primitive_at_t p_inputs[] = { inputs[0].data, inputs[1].data, weights.data,
3684 const_mkldnn_primitive_t outputs[] = { dst.get() };
3685 error::wrap_c_api(mkldnn_primitive_create(&result,
3686 aprimitive_desc.get(), &p_inputs[0], outputs),
3687 "could not create a deformable convolution forward bias primitive");
3691 deformable_convolution_forward(const primitive_desc &aprimitive_desc,
3692 const std::vector<primitive::at> &inputs, const primitive::at &weights, const memory &dst) {
3693 mkldnn_primitive_t result;
3694 mkldnn_primitive_at_t p_inputs[] = { inputs[0].data, inputs[1].data, weights.data, };
3695 const_mkldnn_primitive_t outputs[] = { dst.get() };
3696 check_num_parameters(aprimitive_desc.get(), 3, 1,
3697 "deformable convolution forward");
3698 error::wrap_c_api(mkldnn_primitive_create(&result,
3699 aprimitive_desc.get(), &p_inputs[0], outputs),
3700 "could not create a deformable convolution forward primitive");
3709 /// @addtogroup cpp_api_stream Stream
3710 /// Execution stream operations.
3712 /// @sa @ref c_api_stream in @ref c_api
#ifndef DOXYGEN_SHOULD_SKIP_THIS
// handle_traits specialization for stream handles: handle<mkldnn_stream_t>
// releases the wrapped C stream through mkldnn_stream_destroy.
template <> struct handle_traits<mkldnn_stream_t> {
    static constexpr auto destructor = &mkldnn_stream_destroy;
/// C++ wrapper for an execution stream handle (#mkldnn_stream_t).
struct stream: public handle<mkldnn_stream_t> {
    // Inherit the constructors of handle.
    using handle::handle;

    /// Stream kinds. Each enumerator is defined as the corresponding
    /// mkldnn_stream_kind_t value, so the two enums convert by plain cast.
    enum kind { any = mkldnn_stream_kind_t::mkldnn_any_stream,
        eager = mkldnn_stream_kind_t::mkldnn_eager,
        lazy = mkldnn_stream_kind_t::mkldnn_lazy };
3728 static mkldnn_stream_kind_t convert_to_c(kind akind) {
3729 return static_cast<mkldnn_stream_kind_t>(akind);
/// Constructs a stream.
///
/// @param akind Kind of stream to create (see stream::kind).
stream(kind akind) {
    mkldnn_stream_t astream;
    // The C status is checked by error::wrap_c_api, which reports a failure
    // with the message below.
    error::wrap_c_api(mkldnn_stream_create(&astream,
            convert_to_c(akind)),
            "could not create a stream");
3740 /// Submits a vector of primitives to a stream for computations.
3742 /// @param primitives The vector of primitives to submit.
3743 /// @returns The stream.
3744 stream &submit(std::vector<primitive> primitives) {
3745 // TODO: find a proper way to convert vector<primitive> to
3746 // vector<mkldnn_primitive_t>
3747 if (primitives.size() == 0) return *this;
3748 std::vector<mkldnn_primitive_t> c_api_primitives;
3749 c_api_primitives.reserve(primitives.size());
3750 auto convert_to_c = [](primitive p) { return p.get(); };
3751 std::transform(primitives.begin(), primitives.end(),
3752 std::back_inserter(c_api_primitives), convert_to_c);
3754 mkldnn_primitive_t c_api_error_primitive;
3756 mkldnn_stream_submit(get(),
3757 c_api_primitives.size(), &c_api_primitives[0],
3758 &c_api_error_primitive),
3759 "could not submit primitives to a stream",
3760 &c_api_error_primitive);
3765 /// Waits for all computations submitted to the stream to complete.
3767 /// @param block Specifies whether the operation should wait indefinitely or
3768 /// return immediately.
3769 /// @returns @c true if all computations completed.
3770 /// @returns @c false if not all computations completed.
3771 bool wait(bool block = true) {
3772 mkldnn_primitive_t c_api_error_primitive;
3773 mkldnn_status_t status = mkldnn_stream_wait(get(),
3774 block, &c_api_error_primitive);
3775 if (status != mkldnn_success
3776 && status != mkldnn_try_again)
3777 error::wrap_c_api(status, "could not wait on a stream",
3778 &c_api_error_primitive);
3779 return (status == mkldnn_success);
3783 mkldnn_primitive_t c_api_error_primitive;
3785 mkldnn_stream_rerun(get(), &c_api_error_primitive),
3786 "could not rerun a stream", &c_api_error_primitive);
3791 #undef REG_QUERY_MPD
3797 } // namespace mkldnn