updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / include / mkldnn.hpp
1 /*******************************************************************************
2 * Copyright 2016-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef MKLDNN_HPP
18 #define MKLDNN_HPP
19
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
21 #include <stdlib.h>
22 #include <memory>
23 #include <vector>
24 #include <algorithm>
25 #include <iterator>
26 #include <string>
27
28 #include "mkldnn.h"
29 #endif
30
31 namespace mkldnn {
32
33 /// @addtogroup cpp_api C++ API
34 /// @{
35
36 /// @addtogroup cpp_api_utils Utils
37 /// @{
38
/// Associates an Intel(R) MKL-DNN C handle type with the function that
/// destroys it. The primary template is empty; every handle type wrapped by
/// #mkldnn::handle is expected to provide a specialization that exposes a
/// static @p destructor member.
template <typename T>
class handle_traits {};

/// A reference-counted wrapper around an Intel(R) MKL-DNN C handle. It serves
/// as the base class for primitive (#mkldnn_primitive_t), engine
/// (#mkldnn_engine_t), and stream (#mkldnn_stream_t) handles, and can be
/// passed by value. Two kinds of wrapping are supported:
///  - Owning wrappers for newly constructed handles.
///    @n The handle is stored in a @p std::shared_ptr whose deleter is
///    @p traits::destructor, so it is destroyed together with the last copy
///    of the wrapper.
///  - Weak wrappers for pre-existing handles returned by the C API (for
///    example, through #mkldnn_primitive_get_output()).
///    @n The handle is stored with a deleter that does nothing, because the
///    wrapper of the original object is assumed to destroy the handle (this
///    model is similar to @p std::weak_ptr).
template <typename T, typename traits = handle_traits<T>>
class handle {
private:
    /// Shared storage for the C handle; the deleter is chosen in reset().
    std::shared_ptr<typename std::remove_pointer<T>::type> _data;

    // Construction and assignment from const rvalues are explicitly disabled.
    handle(const handle &&) = delete;
    handle &operator=(const handle &&other) = delete;

protected:
    /// Compares the wrapped handle with a raw C handle.
    bool operator==(const T other) const { return _data.get() == other; }
    bool operator!=(const T other) const { return !(*this == other); }

public:
    /// Constructs a C handle wrapper.
    /// @param t The C handle to wrap.
    /// @param weak A flag to specify whether to construct a weak wrapper.
    handle(T t = 0, bool weak = false): _data() { reset(t, weak); }

    /// Copies a wrapper; both copies share the same underlying handle.
    handle(const handle &other): _data(other._data) {}

    /// Makes this wrapper share the handle held by @p other.
    handle &operator=(const handle &other) {
        _data = other._data;
        return *this;
    }

    /// Resets the value of a C handle.
    /// @param t The new value of the C handle.
    /// @param weak A flag to specify whether the wrapper should be weak.
    void reset(T t, bool weak = false) {
        if (!weak) {
            // Owning wrapper: the traits-provided destructor is the deleter.
            _data.reset(t, traits::destructor);
            return;
        }
        // Weak wrapper: the deleter leaves the handle untouched and merely
        // yields a zero value of the destructor's return type.
        typedef decltype(traits::destructor(0)) destroy_result;
        _data.reset(t, [](T) { return destroy_result(0); });
    }

    /// Returns the value of the underlying C handle.
    T get() const { return _data.get(); }

    /// Two wrappers are equal when they refer to the same C handle.
    bool operator==(const handle &other) const {
        return _data.get() == other._data.get();
    }
    bool operator!=(const handle &other) const { return !(*this == other); }
};
90
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93     static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
94 };
95
96 template <> struct handle_traits<mkldnn_primitive_t> {
97     static constexpr auto destructor = &mkldnn_primitive_destroy;
98 };
99
100 template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
101     static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
102 };
103 #endif
104
105 /// Base class for all computational primitives.
106 class primitive: public handle<mkldnn_primitive_t> {
107     friend struct error;
108     friend struct stream;
109     friend class primitive_at;
110     using handle::handle;
111 public:
112     /// A proxy to C primitive kind enum
113     enum class kind {
114         undefined_primitive = mkldnn_undefined_primitive,
115         memory = mkldnn_memory,
116         view = mkldnn_view,
117         reorder = mkldnn_reorder,
118         concat = mkldnn_concat,
119         concat_inplace = mkldnn_concat_inplace,
120         sum = mkldnn_sum,
121         convolution = mkldnn_convolution,
122         deconvolution = mkldnn_deconvolution,
123         shuffle = mkldnn_shuffle,
124         eltwise = mkldnn_eltwise,
125         depthwise = mkldnn_depthwise,
126         softmax = mkldnn_softmax,
127         pooling = mkldnn_pooling,
128         lrn = mkldnn_lrn,
129         batch_normalization = mkldnn_batch_normalization,
130         inner_product = mkldnn_inner_product,
131         rnn = mkldnn_rnn,
132         binary_convolution = mkldnn_binary_convolution,
133         binarization = mkldnn_binarization,
134         deformable_convolution = mkldnn_deformable_convolution,
135     };
136
137     /// A wrapper structure to specify a particular output of a primitive.
138     struct at {
139         /// The underlying C API structure.
140         mkldnn_primitive_at_t data;
141         /// Constructs a wrapper specifying @p aprimitive output with index @p
142         /// at.
143         ///
144         /// @param aprimitive The target primitive.
145         /// @param at The output index.
146
147         at(const primitive &aprimitive, size_t at = 0)
148             : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
149         /// Returns the specified output.
150         inline operator primitive() const;
151     };
152
153     /// Returns the descriptor of the underlying C API primitive.
154     inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
155     // TODO: use the C++ API wrapper structure.
156 };
157
158 inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) {
159     return static_cast<mkldnn_primitive_kind_t>(akind);
160 }
161 /// Intel(R) MKL-DNN exception class.
162 ///
163 /// This class captures the status returned by the failed C API function, error
164 /// message, and, optionally, handle of the primitive that caused the error.
165 struct error: public std::exception {
166     mkldnn_status_t status;
167     std::string message;
168     primitive error_primitive;
169
170     /// Constructs an error instance.
171     ///
172     /// @param astatus The error status returned by the C API.
173     /// @param amessage The error message.
174     /// @param aerror_primitive (optional) A C handle of the primitive that
175     ///                         caused the error.
176
177     error(mkldnn_status_t astatus, std::string amessage,
178             mkldnn_primitive_t aerror_primitive = 0)
179         : status(astatus)
180         , message(amessage)
181         , error_primitive(aerror_primitive, true)
182     {}
183
184     /// A convenience function for wrapping calls to the C API. Checks the
185     /// return status and throws an #error in case of failure.
186     ///
187     /// @param status The error status returned by the C API.
188     /// @param message The error message.
189     /// @param error_primitive (optional) A C handle of the primitive that
190     ///                        caused the error.
191
192     static void wrap_c_api(mkldnn_status_t status,
193             const std::string &message,
194             mkldnn_primitive_t *error_primitive = 0)
195     {
196         if (status != mkldnn_success) {
197             if (nullptr != error_primitive)
198                 throw error(status, message, *error_primitive);
199             else
200                 throw error(status, message, nullptr);
201         }
202     }
203 };
204
205 inline primitive::at::operator primitive() const {
206     const_mkldnn_primitive_t output;
207     error::wrap_c_api(
208             mkldnn_primitive_get_output(data.primitive,
209                 data.output_index, &output),
210             "could not get an output primitive");
211     return primitive(const_cast<mkldnn_primitive_t>(output), true);
212 }
213
214 const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const {
215     const_mkldnn_primitive_desc_t pd;
216     error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd),
217             "could not get primitive descriptor by primitive");
218     return pd;
219 }
220 /// @}
221
222 /// @addtogroup cpp_api_enums Common data types and enumerations
223 /// A proxy to @ref c_api_types in @ref c_api.
224 ///
225 /// @{
226
227 enum round_mode {
228     round_nearest = mkldnn_round_nearest,
229     round_down = mkldnn_round_down,
230 };
231
232 inline mkldnn_round_mode_t convert_to_c(round_mode mode) {
233     return static_cast<mkldnn_round_mode_t>(mode);
234 }
235
236 enum padding_kind {
237     zero = mkldnn_padding_zero
238 };
239
240 inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) {
241     return static_cast<mkldnn_padding_kind_t>(kind);
242 }
243
244 enum prop_kind {
245     forward_training = mkldnn_forward_training,
246     forward_scoring = mkldnn_forward_scoring,
247     forward_inference = mkldnn_forward_inference,
248     forward = mkldnn_forward,
249     backward = mkldnn_backward,
250     backward_data = mkldnn_backward_data,
251     backward_weights = mkldnn_backward_weights,
252     backward_bias = mkldnn_backward_bias
253 };
254
255 inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) {
256     return static_cast<mkldnn_prop_kind_t>(kind);
257 }
258
259 enum algorithm {
260     algorithm_undef = mkldnn_alg_kind_undef,
261     convolution_auto = mkldnn_convolution_auto,
262     convolution_direct = mkldnn_convolution_direct,
263     convolution_winograd = mkldnn_convolution_winograd,
264     deconvolution_direct = mkldnn_deconvolution_direct,
265     deconvolution_winograd = mkldnn_deconvolution_winograd,
266     eltwise_relu = mkldnn_eltwise_relu,
267     eltwise_tanh = mkldnn_eltwise_tanh,
268     eltwise_elu = mkldnn_eltwise_elu,
269     eltwise_square = mkldnn_eltwise_square,
270     eltwise_abs = mkldnn_eltwise_abs,
271     eltwise_sqrt = mkldnn_eltwise_sqrt,
272     eltwise_linear = mkldnn_eltwise_linear,
273     eltwise_bounded_relu = mkldnn_eltwise_bounded_relu,
274     eltwise_soft_relu = mkldnn_eltwise_soft_relu,
275     eltwise_logistic = mkldnn_eltwise_logistic,
276     eltwise_clamp = mkldnn_eltwise_clamp,
277     eltwise_exp = mkldnn_eltwise_exp,
278     eltwise_not = mkldnn_eltwise_not,
279     depthwise_scale_shift = mkldnn_depthwise_scale_shift,
280     depthwise_prelu = mkldnn_depthwise_prelu,
281     lrn_across_channels = mkldnn_lrn_across_channels,
282     lrn_within_channel  = mkldnn_lrn_within_channel,
283     pooling_max = mkldnn_pooling_max,
284     pooling_avg = mkldnn_pooling_avg,
285     pooling_avg_include_padding = mkldnn_pooling_avg_include_padding,
286     pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding,
287     vanilla_rnn = mkldnn_vanilla_rnn,
288     vanilla_lstm = mkldnn_vanilla_lstm,
289     vanilla_gru = mkldnn_vanilla_gru,
290     gru_linear_before_reset = mkldnn_gru_linear_before_reset,
291     roi_pooling_max = mkldnn_roi_pooling_max,
292     roi_pooling_bilinear = mkldnn_roi_pooling_bilinear,
293     binary_convolution_direct = mkldnn_binary_convolution_direct,
294     binarization_depthwise = mkldnn_binarization_depthwise,
295     deformable_convolution_direct = mkldnn_deformable_convolution_direct,
296 };
297
298 inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
299     return static_cast<mkldnn_alg_kind_t>(aalgorithm);
300 }
301
302 enum batch_normalization_flag {
303     use_global_stats = mkldnn_use_global_stats,
304     use_scale_shift = mkldnn_use_scaleshift,
305     fuse_bn_relu = mkldnn_fuse_bn_relu
306 };
307
308 inline mkldnn_batch_normalization_flag_t convert_to_c(
309         batch_normalization_flag aflag) {
310     return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
311 }
312
313 enum rnn_direction {
314     unidirectional_left2right = mkldnn_unidirectional_left2right,
315     unidirectional_right2left = mkldnn_unidirectional_right2left,
316     unidirectional = mkldnn_unidirectional,
317     bidirectional_concat = mkldnn_bidirectional_concat,
318     bidirectional_sum = mkldnn_bidirectional_sum,
319 };
320
321 inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) {
322     return static_cast<mkldnn_rnn_direction_t>(adir);
323 }
324
325 enum query {
326     undef = mkldnn_query_undef,
327
328     eengine = mkldnn_query_engine,
329     primitive_kind = mkldnn_query_primitive_kind,
330
331     num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32,
332     num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32,
333
334     time_estimate_f64 = mkldnn_query_time_estimate_f64,
335     memory_consumption_s64 = mkldnn_query_memory_consumption_s64,
336
337     impl_info_str = mkldnn_query_impl_info_str,
338
339     op_d = mkldnn_query_op_d,
340     memory_d = mkldnn_query_memory_d,
341     convolution_d = mkldnn_query_convolution_d,
342     deconvolution_d = mkldnn_query_deconvolution_d,
343     shuffle_d = mkldnn_query_shuffle_d,
344     eltwise_d = mkldnn_query_eltwise_d,
345     depthwise_d = mkldnn_query_depthwise_d,
346     softmax_d = mkldnn_query_softmax_d,
347     pooling_d = mkldnn_query_pooling_d,
348     lrn_d = mkldnn_query_lrn_d,
349     batch_normalization_d = mkldnn_query_batch_normalization_d,
350     inner_product_d = mkldnn_query_inner_product_d,
351     rnn_d = mkldnn_query_rnn_d,
352     binary_convolution_d = mkldnn_query_binary_convolution_d,
353     binarization_d = mkldnn_query_binarization_d,
354     deformable_convolution_d = mkldnn_query_deformable_convolution_d,
355
356     input_pd = mkldnn_query_input_pd,
357     output_pd = mkldnn_query_output_pd,
358     src_pd = mkldnn_query_src_pd,
359     diff_src_pd = mkldnn_query_diff_src_pd,
360     weights_pd = mkldnn_query_weights_pd,
361     diff_weights_pd = mkldnn_query_diff_weights_pd,
362     dst_pd = mkldnn_query_dst_pd,
363     diff_dst_pd = mkldnn_query_diff_dst_pd,
364     workspace_pd = mkldnn_query_workspace_pd,
365 };
366
367 inline mkldnn_query_t convert_to_c(query aquery) {
368     return static_cast<mkldnn_query_t>(aquery);
369 }
370
371 /// @}
372
373 /// @addtogroup cpp_api_attr Attributes
374 /// An extension for controlling primitive behavior.
375 ///
376 /// @sa @ref c_api_attributes in @ref c_api
377 /// @{
378
379 #ifndef DOXYGEN_SHOULD_SKIP_THIS
380 template <> struct handle_traits<mkldnn_post_ops_t> {
381     static constexpr auto destructor = &mkldnn_post_ops_destroy;
382 };
383 #endif
384
385 struct post_ops: public handle<mkldnn_post_ops_t> {
386     post_ops() {
387         mkldnn_post_ops_t result;
388         error::wrap_c_api(mkldnn_post_ops_create(&result),
389                 "could not create post operation sequence");
390         reset(result);
391     }
392
393     int len() const { return mkldnn_post_ops_len(get()); }
394
395     primitive::kind kind(int index) const {
396         error::wrap_c_api(
397                 index < len() ? mkldnn_success : mkldnn_invalid_arguments,
398                 "post_ops index is out of range");
399         return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
400                     index));
401     }
402
403     void append_sum(float scale = 1.) {
404         error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale),
405                 "could not append sum");
406     }
407
408     void get_params_sum(int index, float &scale) const {
409         error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale),
410                 "could not get sum params");
411     }
412
413     void append_eltwise(float scale, algorithm alg, float alpha,
414             float beta) {
415         error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale,
416                     convert_to_c(alg), alpha, beta),
417                 "could not append eltwise");
418     }
419
420     void get_params_eltwise(int index, float &scale, algorithm &alg,
421             float &alpha, float &beta) const {
422         mkldnn_alg_kind_t c_alg;
423         error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index,
424                     &scale, &c_alg, &alpha, &beta),
425                 "could not get eltwise params");
426         alg = static_cast<algorithm>(c_alg);
427     }
428
429     void append_depthwise(algorithm alg, const float* weights_data,
430             const float* biases_data) {
431         error::wrap_c_api(mkldnn_post_ops_append_depthwise(get(),
432                     convert_to_c(alg), weights_data, biases_data),
433                 "could not append depthwise");
434     }
435
436     void get_params_depthwise(int index, algorithm &alg,
437             const float** weights_data, const float** biases_data) const {
438         mkldnn_alg_kind_t c_alg;
439         error::wrap_c_api(mkldnn_post_ops_get_params_depthwise(get(), index,
440                     &c_alg, weights_data, biases_data),
441                 "could not get depthwise params");
442         alg = static_cast<algorithm>(c_alg);
443     }
444
445     void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, mkldnn_data_type_t in_dt,
446             const float* weights_data, const float* biases_data) {
447         error::wrap_c_api(mkldnn_post_ops_append_dw_conv(get(),
448                 in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt, weights_data, biases_data),
449                           "could not append dw conv");
450     }
451
452     void get_params_dw_conv(int index, int &in_h, int &in_w, int &ker_h, int &ker_w, int &str_h, int &str_w, mkldnn_data_type_t& in_dt,
453             const float** weights_data, const float** biases_data) const {
454         error::wrap_c_api(mkldnn_post_ops_get_params_dw_conv(get(), index,
455                 &in_h, &in_w, &ker_h, &ker_w, &str_h, &str_w, &in_dt, weights_data, biases_data),
456                           "could not get dw conv params");
457     }
458
459     void append_binarization(algorithm alg, const float* weights_data, const float* output_mask) {
460         error::wrap_c_api(mkldnn_post_ops_append_binarization(get(), convert_to_c(alg), weights_data, output_mask),
461                 "could not append binarization");
462     }
463
464     void get_params_binarization(int index, algorithm &alg, const float** weights_data, const float** output_mask) const {
465         mkldnn_alg_kind_t c_alg;
466         error::wrap_c_api(mkldnn_post_ops_get_params_binarization(get(), index, &c_alg, weights_data, output_mask),
467                 "could not get binarization params");
468         alg = static_cast<algorithm>(c_alg);
469     }
470 };
471
472 #ifndef DOXYGEN_SHOULD_SKIP_THIS
473 template <> struct handle_traits<mkldnn_primitive_attr_t> {
474     static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
475 };
476 #endif
477
478 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
479     primitive_attr() {
480         mkldnn_primitive_attr_t result;
481         error::wrap_c_api(mkldnn_primitive_attr_create(&result),
482                 "could not create a primitive attr");
483         reset(result);
484     }
485
486     round_mode get_int_output_round_mode() const {
487         mkldnn_round_mode_t result;
488         error::wrap_c_api(mkldnn_primitive_attr_get_int_output_round_mode(
489                     get(), &result), "could not get int output round mode");
490         return round_mode(result);
491     }
492
493     void set_int_output_round_mode(round_mode mode) {
494         error::wrap_c_api(mkldnn_primitive_attr_set_int_output_round_mode(
495                     get(), mkldnn::convert_to_c(mode)),
496                 "could not set int output round mode");
497     }
498
499     void get_output_scales(int &mask, std::vector<float> &scales) const
500     {
501         int count, c_mask;
502         const float *c_scales;
503         error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(),
504                     &count, &c_mask, &c_scales),
505                 "could not get int output scales");
506         scales.resize(count);
507
508         mask = c_mask;
509         for (int c = 0; c < count; ++c)
510             scales[c] = c_scales[c];
511     }
512
513     void set_output_scales(int mask, const std::vector<float> &scales)
514     {
515         error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(),
516                     (int)scales.size(), mask, &scales[0]),
517                 "could not set int output scales");
518     }
519
520     const post_ops get_post_ops() const {
521         post_ops result;
522         const_mkldnn_post_ops_t c_result;
523         error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result),
524                 "could not get post operation sequence");
525         result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
526         return result;
527     }
528
529     void set_post_ops(post_ops ops) {
530         error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()),
531                 "could not set post operation sequence");
532     }
533
534     void set_rnn_data_qparams(const float scale, const float shift)
535     {
536         error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(),
537                     scale, shift), "could not set rnn data int scale/shift");
538     }
539
540     void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
541     {
542         error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(),
543                     (int)scales.size(), mask, &scales[0]),
544                 "could not set rnn weights int scales");
545     }
546 };
547
548 /// @}
549
550 /// @addtogroup cpp_api_engine Engine
551 /// Engine operations.
552 ///
553 /// @sa @ref c_api_engine in @ref c_api
554 /// @{
555
556 #ifndef DOXYGEN_SHOULD_SKIP_THIS
557 template <> struct handle_traits<mkldnn_engine_t> {
558     static constexpr auto destructor = &mkldnn_engine_destroy;
559 };
560 #endif
561
562 /// An execution engine.
563 struct engine: public handle<mkldnn_engine_t> {
564     friend class primitive;
565     // gcc bug??? using handle::handle;
566
567     /// Kinds of engines.
568     enum kind {
569         /// An unspecified engine
570         any = mkldnn_any_engine,
571         /// CPU engine
572         cpu = mkldnn_cpu,
573     };
574
575     /// Returns the number of engines of a certain kind.
576     ///
577     /// @param akind The kind of engines to count.
578
579     static size_t get_count(kind akind) {
580         return mkldnn_engine_get_count(convert_to_c(akind));
581     }
582
583     /// Constructs an engine.
584     ///
585     /// @param akind The kind of engine to construct.
586     /// @param index The index of the engine. Must be less than the value
587     ///              returned by #get_count() for this particular kind of engine.
588
589     engine(kind akind, size_t index) {
590         mkldnn_engine_t aengine;
591         error::wrap_c_api(
592                 mkldnn_engine_create(&aengine,
593                     convert_to_c(akind), index),
594                 "could not create an engine");
595         reset(aengine);
596     }
597
598     explicit engine(const mkldnn_engine_t& aengine)
599         : handle(aengine, true) {}
600
601     engine(const handle<mkldnn_primitive_desc_t> &pd) {
602         mkldnn_engine_t engine_q;
603         error::wrap_c_api(
604                 mkldnn_primitive_desc_query(pd.get(),
605                     mkldnn::convert_to_c(eengine), 0, &engine_q),
606                 "could not get engine from primitive_desc");
607         reset(engine_q, true);
608     }
609
610     template <class primitive_desc>
611     static engine query(const primitive_desc &pd) {
612         mkldnn_engine_t engine_q;
613         error::wrap_c_api(
614                 mkldnn_primitive_desc_query(pd.get(),
615                     mkldnn::convert_to_c(eengine), 0, &engine_q),
616                 "could not get engine from primitive_desc");
617
618         return engine(engine_q);
619     }
620
621 private:
622     static mkldnn_engine_kind_t convert_to_c(kind akind) {
623         return static_cast<mkldnn_engine_kind_t>(akind);
624     }
625 };
626
627 /// @}
628
629 /// @addtogroup cpp_api_memory_related Memory and memory related operations
630 /// @{
631
632 /// @addtogroup cpp_api_memory Memory
633 /// A primitive to describe and store data.
634 ///
635 /// For more information, refer to @ref c_api_memory in @ref c_api.
636 /// @{
637
638 /// Memory primitive that describes the data.
639 struct memory: public primitive  {
640     private:
641     std::shared_ptr<char> _handle;
642
643     public:
644     typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
645
646     template <typename T> static void validate_dims(std::vector<T> v) {
647         if (v.size() > TENSOR_MAX_DIMS)
648             throw error(mkldnn_invalid_arguments,
649                     "invalid dimensions");
650     }
651
652     /// Data type specification. See #mkldnn_data_type_t for a detailed
653     /// description.
654     enum data_type {
655         data_undef = mkldnn_data_type_undef,
656         f32 = mkldnn_f32,
657         s32 = mkldnn_s32,
658         bf16 = mkldnn_bf16,
659         s16 = mkldnn_s16,
660         s8 = mkldnn_s8,
661         u8 = mkldnn_u8,
662         bin = mkldnn_bin,
663     };
664
665     /// Memory format specification. See #mkldnn_memory_format_t
666     /// for a detailed description.
667     enum format {
668         format_undef = mkldnn_format_undef,
669         any = mkldnn_any,
670         blocked = mkldnn_blocked,
671         x = mkldnn_x,
672         nc = mkldnn_nc,
673         ncw = mkldnn_ncw,
674         nwc = mkldnn_nwc,
675         nCw16c = mkldnn_nCw16c,
676         nchw = mkldnn_nchw,
677         nhwc = mkldnn_nhwc,
678         chwn = mkldnn_chwn,
679         nCw4c = mkldnn_nCw4c,
680         nCw8c = mkldnn_nCw8c,
681         nChw4c = mkldnn_nChw4c,
682         nChw8c = mkldnn_nChw8c,
683         nChw16c = mkldnn_nChw16c,
684         ncdhw = mkldnn_ncdhw,
685         ndhwc = mkldnn_ndhwc,
686         nCdhw4c = mkldnn_nCdhw4c,
687         nCdhw8c = mkldnn_nCdhw8c,
688         nCdhw16c = mkldnn_nCdhw16c,
689         oi = mkldnn_oi,
690         io = mkldnn_io,
691         oiw = mkldnn_oiw,
692         wio = mkldnn_wio,
693         Owi4o = mkldnn_Owi4o,
694         OIw4i4o = mkldnn_OIw4i4o,
695         Owi8o = mkldnn_Owi8o,
696         OIw8o8i = mkldnn_OIw8o8i,
697         OIw8i8o = mkldnn_OIw8i8o,
698         OIw16i16o = mkldnn_OIw16i16o,
699         OIw16o16i = mkldnn_OIw16o16i,
700         Oiw4o = mkldnn_Oiw4o,
701         Oiw16o = mkldnn_Oiw16o,
702         Owi16o = mkldnn_Owi16o,
703         OIw8i16o2i = mkldnn_OIw8i16o2i,
704         OIw8o16i2o = mkldnn_OIw8o16i2o,
705         IOw8o16i2o = mkldnn_IOw8o16i2o,
706         IOw16o16i = mkldnn_IOw16o16i,
707         OIw4i16o4i = mkldnn_OIw4i16o4i,
708         OIw4i16o4i_s8s8 = mkldnn_OIw4i16o4i_s8s8,
709         oihw = mkldnn_oihw,
710         ihwo = mkldnn_ihwo,
711         hwio = mkldnn_hwio,
712         iohw = mkldnn_iohw,
713         hwio_s8s8 = mkldnn_hwio_s8s8,
714         dhwio = mkldnn_dhwio,
715         oidhw = mkldnn_oidhw,
716         OIdhw4i4o = mkldnn_OIdhw4i4o,
717         Odhwi4o = mkldnn_Odhwi4o,
718         OIdhw8i8o = mkldnn_OIdhw8i8o,
719         OIdhw8o8i = mkldnn_OIdhw8o8i,
720         Odhwi8o = mkldnn_Odhwi8o,
721         OIdhw16i16o = mkldnn_OIdhw16i16o,
722         OIdhw16o16i = mkldnn_OIdhw16o16i,
723         Oidhw4o = mkldnn_Oidhw4o,
724         Oidhw16o = mkldnn_Oidhw16o,
725         Odhwi16o = mkldnn_Odhwi16o,
726         oIhw8i = mkldnn_oIhw8i,
727         oIhw16i = mkldnn_oIhw16i,
728         oIdhw8i = mkldnn_oIdhw8i,
729         oIdhw16i = mkldnn_oIdhw16i,
730         OIhw4i4o = mkldnn_OIhw4i4o,
731         OIhw8i8o = mkldnn_OIhw8i8o,
732         OIhw16i16o = mkldnn_OIhw16i16o,
733         OIhw8o8i = mkldnn_OIhw8o8i,
734         OIhw16o16i = mkldnn_OIhw16o16i,
735         IOhw16o16i = mkldnn_IOhw16o16i,
736         OIhw8i16o2i = mkldnn_OIhw8i16o2i,
737         IOhw8i16o2i = mkldnn_IOhw8i16o2i,
738         OIhw8o16i2o = mkldnn_OIhw8o16i2o,
739         IOhw8o16i2o = mkldnn_IOhw8o16i2o,
740         OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
741         OIdhw8o16i2o = mkldnn_OIdhw8o16i2o,
742         IOdhw8o16i2o = mkldnn_IOdhw8o16i2o,
743         OIhw4i16o4i = mkldnn_OIhw4i16o4i,
744         OIhw4i16o4i_s8s8 = mkldnn_OIhw4i16o4i_s8s8,
745         Oihw8o = mkldnn_Oihw8o,
746         Oihw4o = mkldnn_Oihw4o,
747         Oihw16o = mkldnn_Oihw16o,
748         Ohwi8o = mkldnn_Ohwi8o,
749         Ohwi4o = mkldnn_Ohwi4o,
750         Ohwi16o = mkldnn_Ohwi16o,
751         OhIw16o4i = mkldnn_OhIw16o4i,
752         OhIw8o4i = mkldnn_OhIw8o4i,
753         OhIw8o32i = mkldnn_OhIw8o32i,
754         OhIw16o32i = mkldnn_OhIw16o32i,
755         OhIw8o4i_s8s8 = mkldnn_OhIw8o4i_s8s8,
756         goiw = mkldnn_goiw,
757         gOwi4o = mkldnn_gOwi4o,
758         gOIw4i4o = mkldnn_gOIw4i4o,
759         gOwi8o = mkldnn_gOwi8o,
760         gOIw8o8i = mkldnn_gOIw8o8i,
761         gOIw8i8o = mkldnn_gOIw8i8o,
762         gOIw16i16o = mkldnn_gOIw16i16o,
763         gOIw16o16i = mkldnn_gOIw16o16i,
764         gOiw4o = mkldnn_gOiw4o,
765         gOiw16o = mkldnn_gOiw16o,
766         gOwi16o = mkldnn_gOwi16o,
767         gIOw16o16i = mkldnn_gIOw16o16i,
768         gOIw8i16o2i = mkldnn_gOIw8i16o2i,
769         gOIw8o16i2o = mkldnn_gOIw8o16i2o,
770         gIOw8o16i2o = mkldnn_gIOw8o16i2o,
771         gOIw4i16o4i = mkldnn_gOIw4i16o4i,
772         gOIw4i16o4i_s8s8 = mkldnn_gOIw4i16o4i_s8s8,
773         goihw = mkldnn_goihw,
774         hwigo = mkldnn_hwigo,
775         giohw = mkldnn_giohw,
776         hwigo_s8s8 = mkldnn_hwigo_s8s8,
777         dhwigo_s8s8 = mkldnn_dhwigo_s8s8,
778         dhwigo = mkldnn_dhwigo,
779         dhwio_s8s8 = mkldnn_dhwio_s8s8,
780         gOIdhw4i4o = mkldnn_gOIdhw4i4o,
781         gOdhwi4o = mkldnn_gOdhwi4o,
782         gOIdhw8i8o = mkldnn_gOIdhw8i8o,
783         gOIdhw8o8i = mkldnn_gOIdhw8o8i,
784         gOdhwi8o = mkldnn_gOdhwi8o,
785         gOIhw4i4o = mkldnn_gOIhw4i4o,
786         gOIhw8i8o = mkldnn_gOIhw8i8o,
787         gOIhw16i16o = mkldnn_gOIhw16i16o,
788         gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
789         gIOhw8i16o2i = mkldnn_gIOhw8i16o2i,
790         gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
791         gIOhw8o16i2o = mkldnn_gIOhw8o16i2o,
792         gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
793         gOIdhw8o16i2o = mkldnn_gOIdhw8o16i2o,
794         gIOdhw8o16i2o = mkldnn_gIOdhw8o16i2o,
795         gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
796         gOIhw4i16o4i_s8s8 = mkldnn_gOIhw4i16o4i_s8s8,
797         gOIhw2i8o4i = mkldnn_gOIhw2i8o4i,
798         gOIhw2i8o4i_s8s8 = mkldnn_gOIhw2i8o4i_s8s8,
799         gOihw8o = mkldnn_gOihw8o,
800         gOihw4o = mkldnn_gOihw4o,
801         gOihw16o = mkldnn_gOihw16o,
802         gOhwi4o = mkldnn_gOhwi4o,
803         gOhwi8o = mkldnn_gOhwi8o,
804         gOhwi16o = mkldnn_gOhwi16o,
805         Goihw8g = mkldnn_Goihw8g,
806         Goihw8g_s8s8 = mkldnn_Goihw8g_s8s8,
807         Goiw16g = mkldnn_Goiw16g,
808         Goiw16g_s8s8 = mkldnn_Goiw16g_s8s8,
809         Goihw16g = mkldnn_Goihw16g,
810         Goihw16g_s8s8 = mkldnn_Goihw16g_s8s8,
811         gOIhw4o4i = mkldnn_gOIhw4o4i,
812         gOIhw4o4i_s8s8 = mkldnn_gOIhw4o4i_s8s8,
813         gOIhw8o8i = mkldnn_gOIhw8o8i,
814         gOIhw16o16i = mkldnn_gOIhw16o16i,
815         gIOhw16o16i = mkldnn_gIOhw16o16i,
816         gOhIw16o4i = mkldnn_gOhIw16o4i,
817         gOhIw8o4i = mkldnn_gOhIw8o4i,
818         gOhIw8o4i_s8s8 = mkldnn_gOhIw8o4i_s8s8,
819         goidhw = mkldnn_goidhw,
820         gOIdhw16i16o = mkldnn_gOIdhw16i16o,
821         gOIdhw16o16i = mkldnn_gOIdhw16o16i,
822         gOidhw4o = mkldnn_gOidhw4o,
823         gOidhw16o = mkldnn_gOidhw16o,
824         gOdhwi16o = mkldnn_gOdhwi16o,
825         OdhIw8o4i = mkldnn_OdhIw8o4i,
826         OdhIw8o4i_s8s8 = mkldnn_OdhIw8o4i_s8s8,
827         gOdhIw8o4i = mkldnn_gOdhIw8o4i,
828         gOdhIw8o4i_s8s8 = mkldnn_gOdhIw8o4i_s8s8,
829         Goidhw8g = mkldnn_Goidhw8g,
830         Goidhw8g_s8s8 = mkldnn_Goidhw8g_s8s8,
831         Goidhw16g = mkldnn_Goidhw16g,
832         Goidhw16g_s8s8 = mkldnn_Goidhw16g_s8s8,
833         ntc = mkldnn_ntc,
834         tnc = mkldnn_tnc,
835         ldsnc = mkldnn_ldsnc,
836         ldigo = mkldnn_ldigo,
837         ldgoi = mkldnn_ldgoi,
838         ldgo = mkldnn_ldgo,
839         rnn_packed = mkldnn_rnn_packed,
840         wino_fmt = mkldnn_wino_fmt,
841         format_last = mkldnn_format_last,
842     };
843
844     /// A memory descriptor.
845     struct desc {
846         friend struct memory;
847         /// The underlying C API data structure.
848         mkldnn_memory_desc_t data;
849
850         /// Constructs a memory descriptor.
851         ///
852         /// @param adims Data dimensions
853         /// @param adata_type Data precision/type.
854         /// @param aformat Data layout format.
855         desc(dims adims, data_type adata_type,
856                 format aformat) {
857             validate_dims(adims);
858             error::wrap_c_api(
859                     mkldnn_memory_desc_init(&data, (int)adims.size(),
860                         adims.size() == 0 ? nullptr : &adims[0],
861                         convert_to_c(adata_type), convert_to_c(aformat)),
862                     "could not initialize a memory descriptor");
863         }
864
865         /// Constructs a memory descriptor from a C API data structure.
866         ///
867         /// @param adata A C API #mkldnn_memory_desc_t structure.
868         desc(const mkldnn_memory_desc_t &adata): data(adata) {}
869     };
870
871     /// A memory primitive descriptor.
872     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
873         friend struct memory;
874
875         // TODO: make private
876         primitive_desc() {}
877
878         /// Constructs a memory primitive descriptor.
879         primitive_desc(const desc &adesc, const engine &aengine) {
880             mkldnn_primitive_desc_t result;
881             error::wrap_c_api(
882                     mkldnn_memory_primitive_desc_create(&result,
883                         &adesc.data, aengine.get()),
884                     "could not initialize a memory primitive descriptor");
885             reset(result);
886         }
887
888         /// Returns the memory primitive descriptor.
889         memory::desc desc() {
890             auto memory_d = mkldnn_primitive_desc_query_memory_d(get());
891             return memory::desc(*memory_d); }
892
893         /// Returns the number of bytes required to allocate the memory described
894         /// including the padding area.
895         size_t get_size() const {
896              return mkldnn_memory_primitive_desc_get_size(get());
897         }
898
899         bool operator==(const primitive_desc &other) const {
900             return (0 == mkldnn_memory_primitive_desc_equal(get(),
901                         other.get())) ? false : true;
902         }
903
904         bool operator!=(const primitive_desc &other) const {
905             return !operator==(other);
906         }
907
908         engine get_engine() { return engine::query(*this); }
909     };
910
    /// Constructs a memory primitive from a generic primitive.
    ///
    /// The base class is copy-initialized from @p aprimitive; no check that
    /// the primitive actually is a memory primitive is performed here.
    ///
    /// @param aprimitive The primitive to treat as memory.
    memory(const primitive &aprimitive): primitive(aprimitive) {}
915     /// Constructs a memory primitive.
916     ///
917     /// @param adesc Memory primitive descriptor.
918     memory(const primitive_desc &adesc) {
919         mkldnn_primitive_t result;
920         error::wrap_c_api(
921                 mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
922                 "could not create a memory primitive");
923         reset(result);
924         auto _malloc = [](size_t size, int alignment) {
925             void *ptr;
926 #ifdef _WIN32
927             ptr = _aligned_malloc(size, alignment);
928             int rc = ((ptr)? 0 : errno);
929 #else
930             int rc = ::posix_memalign(&ptr, alignment, size);
931 #endif /* _WIN32 */
932             return (rc == 0) ? (char*)ptr : nullptr;
933         };
934         auto _free = [](char* p) {
935 #ifdef _WIN32
936             _aligned_free((void*)p);
937 #else
938             ::free((void*)p);
939 #endif /* _WIN32 */
940         };
941         _handle.reset(_malloc(adesc.get_size(), 4096), _free);
942         set_data_handle(_handle.get());
943     }
944
945     memory(const primitive_desc &adesc, void *ahandle) {
946         mkldnn_primitive_t result;
947         error::wrap_c_api(
948                 mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
949                 "could not create a memory primitive");
950         reset(result);
951         set_data_handle(ahandle);
952     }
953
954     /// Returns the descriptor of the memory primitive.
955     primitive_desc get_primitive_desc() const {
956         primitive_desc adesc;
957         const_mkldnn_primitive_desc_t cdesc;
958         error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(),
959                     &cdesc),
960                 "could not get primitive descriptor from a memory primitive");
961         /* FIXME: no const_cast should be here */
962         adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
963         return adesc;
964     }
965
966     /// Returns a handle of the data contained in the memory primitive. On
967     /// the CPU engine, this is a pointer to the allocated memory.
968     inline void *get_data_handle() const {
969         void *handle;
970         error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle),
971                 "could not get native handle");
972         return handle;
973     }
974
    /// Sets the native data handle of the memory primitive.
    ///
    /// Throws (through error::wrap_c_api) if the C API call fails.
    ///
    /// @param handle New native data handle.
    inline void set_data_handle(void *handle) const {
        error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle),
                "could not set native handle");
    }
979
    /// Sets the native data handle through the "no pads proc" variant of the
    /// C API setter.
    // NOTE(review): judging by the name, this variant presumably skips the
    // processing of the padding area -- confirm against the C API.
    ///
    /// @param handle New native data handle.
    inline void set_data_handle_no_pads_proc(void *handle) const {
        error::wrap_c_api(mkldnn_memory_set_data_handle_no_pads_proc(get(), handle),
                          "could not set native handle");
    }
984
    // Must go away or be private:
    /// Converts a C++ data type enumerator to the C API type by a plain
    /// static_cast (no validation).
    static mkldnn_data_type_t convert_to_c(data_type adata_type) {
        return static_cast<mkldnn_data_type_t>(adata_type);
    }
    /// Converts a C++ memory format enumerator to the C API type by a plain
    /// static_cast (no validation).
    static mkldnn_memory_format_t convert_to_c(format aformat) {
        return static_cast<mkldnn_memory_format_t>(aformat);
    }
992 };
993
994 inline memory::desc zero_md() {
995     auto zero = mkldnn_memory_desc_t();
996     zero.primitive_kind = mkldnn_memory;
997     return memory::desc(zero);
998 }
999
1000 inline memory null_memory(engine eng) {
1001     mkldnn::memory::desc zero = zero_md();
1002     return memory({zero, eng}, nullptr);
1003 }
1004
1005 inline void check_num_parameters(const const_mkldnn_primitive_desc_t
1006     &aprimitive_desc, int n_inputs, int n_outputs,
1007     const std::string &prim_name) {
1008     const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
1009             aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
1010     const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
1011             aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
1012     if (n_outputs_expected > n_outputs ) {
1013         std::string message = "could not create " + prim_name +
1014             " primitive, not enought output parameters";
1015         throw error(mkldnn_invalid_arguments, message, nullptr);
1016     }
1017     if (n_inputs_expected > n_inputs ) {
1018         std::string message = "could not create " + prim_name +
1019             " primitive, not enought input parameters";
1020         throw error(mkldnn_invalid_arguments, message, nullptr);
1021     }
1022 }
1023
1024
1025 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
1026     const_mkldnn_primitive_desc_t aprimitive_pd;
1027     mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
1028     const mkldnn_memory_desc_t *aprimitive_md = mkldnn_primitive_desc_query_memory_d(
1029         aprimitive_pd);
1030
1031     return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
1032 }
1033
1034 inline bool operator==(mkldnn_data_type_t a, memory::data_type b) {
1035     return a == memory::convert_to_c(b);
1036 }
1037 inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) {
1038     return !(a == b);
1039 }
1040 inline bool operator==(memory::data_type a, mkldnn_data_type_t b) {
1041     return b == a;
1042 }
1043 inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) {
1044     return !(a == b);
1045 }
1046
1047 inline bool operator==(mkldnn_memory_format_t a, memory::format b) {
1048     return a == memory::convert_to_c(b);
1049 }
1050 inline bool operator!=(mkldnn_memory_format_t a, memory::format b) {
1051     return !(a == b);
1052 }
1053 inline bool operator==(memory::format a, mkldnn_memory_format_t b) {
1054     return b == a;
1055 }
1056 inline bool operator!=(memory::format a, mkldnn_memory_format_t b) {
1057     return !(a == b);
1058 }
1059
1060 /// @}
1061
1062 /// @addtogroup cpp_api_reorder Reorder
1063 /// A primitive to copy data between memory formats.
1064 ///
1065 /// @sa @ref c_api_reorder in @ref c_api
1066 /// @{
1067
1068 struct reorder : public primitive {
1069     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1070         primitive_desc(const memory::primitive_desc &input,
1071                        const memory::primitive_desc &output) {
1072             mkldnn_primitive_desc_t result;
1073             error::wrap_c_api(mkldnn_reorder_primitive_desc_create(
1074                         &result, input.get(), output.get()),
1075                     "could not create a reorder primitive descriptor");
1076             reset(result);
1077         }
1078
1079         primitive_desc(const memory::primitive_desc &input,
1080                 const memory::primitive_desc &output,
1081                 const primitive_attr &aattr) {
1082             mkldnn_primitive_desc_t result;
1083             error::wrap_c_api(mkldnn_reorder_primitive_desc_create_v2(
1084                         &result, input.get(), output.get(), aattr.get()),
1085                     "could not create a reorder primitive descriptor");
1086             reset(result);
1087         }
1088
1089         engine get_engine() { return engine::query(*this); }
1090     };
1091
1092     reorder(const primitive_desc &aprimitive_desc,
1093             const primitive::at &input, const memory &output) {
1094         mkldnn_primitive_t result;
1095         mkldnn_primitive_at_t inputs[] = { input.data };
1096         const_mkldnn_primitive_t outputs[] = { output.get() };
1097         error::wrap_c_api(mkldnn_primitive_create(&result,
1098                     aprimitive_desc.get(), inputs, outputs),
1099                 "could not create a reorder primitive");
1100         reset(result);
1101     }
1102
1103     reorder(const primitive::at &input, const memory &output) {
1104         auto input_mpd = memory(input).get_primitive_desc();
1105         auto output_mpd = output.get_primitive_desc();
1106
1107         auto reorder_d = primitive_desc(input_mpd, output_mpd);
1108
1109         mkldnn_primitive_t result;
1110         mkldnn_primitive_at_t inputs[] = { input.data };
1111         const_mkldnn_primitive_t outputs[] = { output.get() };
1112         error::wrap_c_api(mkldnn_primitive_create(&result,
1113                     reorder_d.get(), inputs, outputs),
1114                 "could not create a reorder primitive");
1115         reset(result);
1116     }
1117 };
1118
1119 /// @}
1120
1121 /// @addtogroup cpp_api_view View
/// A primitive that provides a view on a memory primitive.
1123 ///
1124 /// @sa mkldnn_view_primitive_desc_create in @ref c_api
1125 /// @{
1126
1127 struct view : public primitive {
1128     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1129         primitive_desc(const memory::primitive_desc &input, memory::dims dims,
1130                 memory::dims offsets) {
1131             mkldnn_primitive_desc_t result;
1132
1133             error::wrap_c_api(mkldnn_view_primitive_desc_create(
1134                     &result, input.get(), &dims[0], &offsets[0]),
1135                 "could not create a view primitive descriptor");
1136             reset(result);
1137         }
1138
1139         memory::primitive_desc dst_primitive_desc() const {
1140             memory::primitive_desc adesc;
1141             mkldnn_primitive_desc_t cdesc;
1142             const_mkldnn_primitive_desc_t const_cdesc =
1143                 mkldnn_primitive_desc_query_pd(get(),
1144                                mkldnn::convert_to_c(dst_pd), 0);
1145             error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc,
1146                         const_cdesc),
1147                     "could not clone a dst primitive descriptor");
1148             adesc.reset(cdesc);
1149             return adesc;
1150         }
1151
1152         engine get_engine() { return engine::query(*this); }
1153     };
1154
1155     view(const primitive_desc &view_pd, primitive::at input) {
1156         mkldnn_primitive_t result;
1157         mkldnn_primitive_at_t inputs[] = { input.data };
1158         error::wrap_c_api(mkldnn_primitive_create(&result,
1159                     view_pd.get(), inputs, nullptr),
1160                 "could not create a view primitive");
1161         reset(result);
1162     }
1163
1164     view(memory input, memory::dims dims, memory::dims offsets) {
1165         mkldnn_primitive_t result;
1166         primitive_desc view_pd(input.get_primitive_desc(), dims,
1167                 offsets);
1168         mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
1169         error::wrap_c_api(mkldnn_primitive_create(&result,
1170                     view_pd.get(), inputs, nullptr),
1171                 "could not create a view primitive");
1172         reset(result);
1173     }
1174 };
1175
1176 /// @}
1177
1178 /// @addtogroup cpp_api_concat Concat
/// A primitive to concatenate data along an arbitrary dimension.
1180 ///
1181 /// @sa @ref c_api_concat in @ref c_api
1182 /// @{
1183
1184 struct concat : public primitive {
1185     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1186         std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1187                 std::vector<memory::primitive_desc> inputs) {
1188             std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1189             c_api_inputs.reserve(inputs.size());
1190             auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1191             std::transform(inputs.begin(), inputs.end(),
1192                     std::back_inserter(c_api_inputs), convert_to_c);
1193             return c_api_inputs;
1194         }
1195
1196         primitive_desc(const memory::desc &output, int concat_dimension,
1197                 std::vector<memory::primitive_desc> inputs) {
1198             mkldnn_primitive_desc_t result;
1199
1200             auto c_api_inputs = cpp_to_c(inputs);
1201
1202             error::wrap_c_api(mkldnn_concat_primitive_desc_create(
1203                     &result, &output.data, (int)c_api_inputs.size(),
1204                     concat_dimension, &c_api_inputs[0]),
1205                 "could not create a concat primitive descriptor");
1206             reset(result);
1207         }
1208
1209         primitive_desc(int concat_dimension,
1210                 std::vector<memory::primitive_desc> inputs) {
1211             mkldnn_primitive_desc_t result;
1212
1213             auto c_api_inputs = cpp_to_c(inputs);
1214
1215             error::wrap_c_api(mkldnn_concat_primitive_desc_create(
1216                     &result, nullptr, (int)c_api_inputs.size(),
1217                     concat_dimension, &c_api_inputs[0]),
1218                 "could not create a concat primitive descriptor");
1219             reset(result);
1220         }
1221
1222         memory::primitive_desc dst_primitive_desc() const {
1223             memory::primitive_desc adesc;
1224             mkldnn_primitive_desc_t cdesc;
1225             const_mkldnn_primitive_desc_t const_cdesc =
1226                 mkldnn_primitive_desc_query_pd(get(),
1227                                mkldnn::convert_to_c(dst_pd), 0);
1228             error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1229                     "could not clone a dst primitive descriptor");
1230             adesc.reset(cdesc);
1231             return adesc;
1232         }
1233
1234         engine get_engine() { return engine::query(*this); }
1235     };
1236
1237     concat(const primitive_desc &concat_pd,
1238             std::vector<primitive::at> &inputs, const memory &output) {
1239         mkldnn_primitive_t result;
1240
1241         std::vector<mkldnn_primitive_at_t> p_inputs;
1242         for (size_t i = 0; i < inputs.size(); i++)
1243             p_inputs.push_back(inputs[i].data);
1244         const_mkldnn_primitive_t outputs[] = { output.get() };
1245
1246         error::wrap_c_api(mkldnn_primitive_create(&result,
1247                     concat_pd.get(), &p_inputs[0], outputs),
1248                 "could not create a concat primitive");
1249         reset(result);
1250     }
1251 };
1252
1253 /// @}
1254
1255 /// @addtogroup cpp_api_sum Sum
1256 /// A primitive to sum data.
1257 ///
1258 /// @sa @ref c_api_sum in @ref c_api
1259 /// @{
1260
1261 struct sum : public primitive {
1262     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1263         std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1264                 std::vector<memory::primitive_desc> inputs) {
1265             std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1266             c_api_inputs.reserve(inputs.size());
1267             auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1268             std::transform(inputs.begin(), inputs.end(),
1269                     std::back_inserter(c_api_inputs), convert_to_c);
1270             return c_api_inputs;
1271         }
1272
1273         primitive_desc(const memory::desc &output,
1274                 const std::vector<float> &scales,
1275                 std::vector<memory::primitive_desc> inputs) {
1276             mkldnn_primitive_desc_t result;
1277
1278             auto c_api_inputs = cpp_to_c(inputs);
1279
1280             error::wrap_c_api(
1281                 scales.size() == inputs.size() ? mkldnn_success
1282                                                : mkldnn_invalid_arguments,
1283                 "number of scales not equal to number of inputs");
1284
1285             error::wrap_c_api(mkldnn_sum_primitive_desc_create(
1286                     &result, &output.data, (int)c_api_inputs.size(),
1287                     &scales[0], &c_api_inputs[0]),
1288                 "could not create a sum primitive descriptor");
1289             reset(result);
1290         }
1291
1292         primitive_desc(const std::vector<float> &scales,
1293                 std::vector<memory::primitive_desc> inputs) {
1294             mkldnn_primitive_desc_t result;
1295
1296             auto c_api_inputs = cpp_to_c(inputs);
1297
1298             error::wrap_c_api(
1299                 scales.size() == inputs.size() ? mkldnn_success
1300                                                : mkldnn_invalid_arguments,
1301                 "number of scales not equal to number of inputs");
1302
1303             error::wrap_c_api(mkldnn_sum_primitive_desc_create(
1304                     &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1305                     &c_api_inputs[0]),
1306                 "could not create a sum primitive descriptor");
1307             reset(result);
1308         }
1309
1310         memory::primitive_desc dst_primitive_desc() const {
1311             memory::primitive_desc adesc;
1312             mkldnn_primitive_desc_t cdesc;
1313             const_mkldnn_primitive_desc_t const_cdesc =
1314                 mkldnn_primitive_desc_query_pd(get(),
1315                                mkldnn::convert_to_c(dst_pd), 0);
1316             error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc,
1317                     const_cdesc),
1318                     "could not clone a dst primitive descriptor");
1319             adesc.reset(cdesc);
1320             return adesc;
1321         }
1322
1323         engine get_engine() { return engine::query(*this); }
1324     };
1325
1326     sum(const primitive_desc &sum_pd,
1327             std::vector<primitive::at> &inputs, const memory &output) {
1328         mkldnn_primitive_t result;
1329
1330         std::vector<mkldnn_primitive_at_t> p_inputs;
1331         for (size_t i = 0; i < inputs.size(); i++)
1332             p_inputs.push_back(inputs[i].data);
1333         const_mkldnn_primitive_t outputs[] = { output.get() };
1334
1335         error::wrap_c_api(mkldnn_primitive_create(&result,
1336                     sum_pd.get(), &p_inputs[0], outputs),
1337                 "could not create a sum primitive");
1338         reset(result);
1339     }
1340 };
1341
1342 /// @}
1343
1344 /// @}
1345
1346 /// @addtogroup cpp_api_primitives Primitives
1347 /// @{
1348
1349 /// @addtogroup cpp_api_primitive_descriptors Primitive descriptors
1350 /// @{
1351
1352 /// A base class for all primitive descriptors.
1353 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1354     primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr,
1355             const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1356         mkldnn_primitive_desc_iterator_t iterator = nullptr;
1357         mkldnn_status_t status = mkldnn_primitive_desc_iterator_create_v2(
1358                 &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1359                 hint_fwd_pd);
1360         error::wrap_c_api(status,
1361                 "could not create a primitive descriptor iterator");
1362         pd_iterator.reset(iterator);
1363         fetch_impl();
1364     }
1365
1366     engine get_engine() { return engine::query(*this); }
1367
1368     primitive_attr get_primitive_attr() const {
1369         const_mkldnn_primitive_attr_t const_cattr;
1370         error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr),
1371                 "could not get attributes");
1372         mkldnn_primitive_attr_t cattr;
1373         error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
1374                 "could not clone attributes");
1375
1376         primitive_attr attr;
1377         attr.reset(cattr);
1378         return attr;
1379     }
1380
1381     /// Returns implementation name
1382     const char *impl_info_str() const {
1383         const char *res;
1384         error::wrap_c_api(mkldnn_primitive_desc_query(get(),
1385                     mkldnn_query_impl_info_str, 0, &res),
1386                 "could not query implementation info string");
1387         return res;
1388     }
1389
1390     /// Advances the next implementation for the given op descriptor.
1391     ///
1392     /// Returns:
1393     /// - @c true on success
1394     /// - @c false if the last implementation reached, and
1395     ///   the primitive descriptor itself is kept unchanged
1396     bool next_impl() {
1397         mkldnn_status_t status = mkldnn_primitive_desc_iterator_next(
1398                 pd_iterator.get());
1399         if (status == mkldnn_iterator_ends) return false;
1400         error::wrap_c_api(status, "primitive descriptor iterator next failed");
1401
1402         fetch_impl();
1403         return true;
1404     }
1405
1406     /// Queries and returns requested memory primitive descriptor.
1407     memory::primitive_desc query_mpd(query what, int idx = 0) const {
1408         std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
1409             weights_pd, diff_weights_pd, dst_pd, diff_dst_pd, workspace_pd};
1410         if (!std::any_of(valid_w.cbegin(), valid_w.cend(),
1411                     [=](query q) { return what == q; }))
1412             throw error(mkldnn_invalid_arguments, "invalid memory query");
1413
1414         const_mkldnn_primitive_desc_t const_cdesc
1415             = mkldnn_primitive_desc_query_pd(get(),
1416                     mkldnn::convert_to_c(what), idx);
1417
1418         // TODO: is there a better way to inform about this?
1419         if (const_cdesc == nullptr)
1420             throw error(mkldnn_not_required, "queried memory is not required");
1421
1422         mkldnn_primitive_desc_t cdesc;
1423         error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1424                 "could not clone a memory primitive descriptor");
1425
1426         memory::primitive_desc ret;
1427         ret.reset(cdesc);
1428         return ret;
1429     }
1430
1431     // register specialized queries, e.g. src_primitive_desc()
1432 #   define REG_QUERY_MPD(name, what, idx) \
1433     memory::primitive_desc name ## _primitive_desc() const \
1434     { return query_mpd(what ## _pd, idx); }
1435
1436   private:
1437     handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1438     void fetch_impl() {
1439         mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
1440                 pd_iterator.get());
1441         error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error,
1442                 "could not fetch a primitive descriptor from the iterator");
1443         reset(pd);
1444     }
1445 };
1446
1447 /// @}
1448
1449 /// @addtogroup cpp_api_convolution Convolution
1450 /// A primitive to compute convolution using different algorithms.
1451 ///
1452 /// @sa @ref c_api_convolution in @ref c_api
1453 /// @{
1454
/// Forward convolution primitive, with or without a bias input.
struct convolution_forward: public primitive {
    /// Operation descriptor of a forward convolution. Every constructor
    /// validates the dims arguments with memory::validate_dims() and then
    /// initializes #data through the C API, throwing (through
    /// error::wrap_c_api) on failure.
    struct desc {
        /// The underlying C API data structure.
        mkldnn_convolution_desc_t data;

        /// Non-dilated convolution with bias.
        desc(prop_kind aprop_kind, algorithm aalgorithm,
                const memory::desc &src_desc,
                const memory::desc &weights_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_desc,
                const memory::dims strides,
                const memory::dims padding_l,
                const memory::dims padding_r,
                const padding_kind apadding_kind) {
            memory::validate_dims(strides);
            memory::validate_dims(padding_l);
            memory::validate_dims(padding_r);
            error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
                        mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
                        &src_desc.data, &weights_desc.data, &bias_desc.data,
                        &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
                        mkldnn::convert_to_c(apadding_kind)),
                    "could not create a convolution forward descriptor");
        }

        /// Non-dilated convolution without bias (nullptr is passed to the
        /// C API in place of the bias descriptor).
        desc(prop_kind aprop_kind, algorithm aalgorithm,
                const memory::desc &src_desc,
                const memory::desc &weights_desc,
                const memory::desc &dst_desc,
                const memory::dims strides,
                const memory::dims padding_l,
                const memory::dims padding_r,
                const padding_kind apadding_kind) {
            memory::validate_dims(strides);
            memory::validate_dims(padding_l);
            memory::validate_dims(padding_r);
            error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
                        mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
                        &src_desc.data, &weights_desc.data, nullptr,
                        &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
                        mkldnn::convert_to_c(apadding_kind)),
                    "could not create a convolution forward descriptor");
        }

        /// Dilated convolution with bias.
        desc(prop_kind aprop_kind, algorithm aalgorithm,
                const memory::desc &src_desc,
                const memory::desc &weights_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_desc,
                const memory::dims strides,
                const memory::dims dilates,
                const memory::dims padding_l,
                const memory::dims padding_r,
                const padding_kind apadding_kind) {
            memory::validate_dims(strides);
            memory::validate_dims(dilates);
            memory::validate_dims(padding_l);
            memory::validate_dims(padding_r);
            error::wrap_c_api(
                mkldnn_dilated_convolution_forward_desc_init(&data,
                    mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
                        &src_desc.data, &weights_desc.data, &bias_desc.data,
                        &dst_desc.data, &strides[0], &dilates[0],
                        &padding_l[0], &padding_r[0],
                        mkldnn::convert_to_c(apadding_kind)),
                    "could not create a dilated convolution forward descriptor");
        }

        /// Dilated convolution without bias (nullptr is passed to the C API
        /// in place of the bias descriptor).
        desc(prop_kind aprop_kind, algorithm aalgorithm,
                const memory::desc &src_desc,
                const memory::desc &weights_desc,
                const memory::desc &dst_desc,
                const memory::dims strides,
                const memory::dims dilates,
                const memory::dims padding_l,
                const memory::dims padding_r,
                const padding_kind apadding_kind) {
            memory::validate_dims(strides);
            memory::validate_dims(dilates);
            memory::validate_dims(padding_l);
            memory::validate_dims(padding_r);
            error::wrap_c_api(
                mkldnn_dilated_convolution_forward_desc_init(&data,
                    mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
                        &src_desc.data, &weights_desc.data, nullptr,
                        &dst_desc.data, &strides[0], &dilates[0],
                        &padding_l[0], &padding_r[0],
                        mkldnn::convert_to_c(apadding_kind)),
                    "could not create a dilated convolution forward descriptor");
        }
    };

    /// Primitive descriptor of a forward convolution. No forward hint is
    /// passed to the base class (the last argument is nullptr).
    struct primitive_desc : public mkldnn::primitive_desc {
        /// Creates the primitive descriptor without attributes.
        primitive_desc(const desc &desc, const engine &e)
            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}

        /// Creates the primitive descriptor with the given attributes.
        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}

        // Accessors generated by REG_QUERY_MPD: src_primitive_desc(),
        // weights_primitive_desc(), bias_primitive_desc() and
        // dst_primitive_desc(). The bias is queried as weights index 1.
        REG_QUERY_MPD(src, src, 0);
        REG_QUERY_MPD(weights, weights, 0);
        REG_QUERY_MPD(bias, weights, 1);
        REG_QUERY_MPD(dst, dst, 0);
    };

    /// Creates a forward convolution primitive with bias.
    ///
    /// @param aprimitive_desc Convolution primitive descriptor.
    /// @param src Source input.
    /// @param weights Weights input.
    /// @param bias Bias input.
    /// @param dst Destination memory.
    convolution_forward(const primitive_desc &aprimitive_desc,
            const primitive::at &src, const primitive::at &weights,
            const primitive::at &bias, const memory &dst) {
        mkldnn_primitive_t result;
        mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
                    bias.data };
        const_mkldnn_primitive_t outputs[] = { dst.get() };
        error::wrap_c_api(mkldnn_primitive_create(&result,
                    aprimitive_desc.get(), inputs, outputs),
                "could not create a convolution forward bias primitive");
        reset(result);
    }

    /// Creates a forward convolution primitive without bias.
    ///
    /// @param aprimitive_desc Convolution primitive descriptor.
    /// @param src Source input.
    /// @param weights Weights input.
    /// @param dst Destination memory.
    convolution_forward(const primitive_desc &aprimitive_desc,
            const primitive::at &src, const primitive::at &weights,
            const memory &dst) {
        mkldnn_primitive_t result;
        mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
        const_mkldnn_primitive_t outputs[] = { dst.get() };
        // Throws if the descriptor expects more than the two inputs and one
        // output supplied here (e.g. a descriptor created with bias).
        check_num_parameters(aprimitive_desc.get(), 2, 1,
            "convolution forward");
        error::wrap_c_api(mkldnn_primitive_create(&result,
                    aprimitive_desc.get(), inputs, outputs),
                "could not create a convolution forward primitive");
        reset(result);
    }
};
1582
1583 struct convolution_backward_data : public primitive {
1584     struct desc {
1585         mkldnn_convolution_desc_t data;
1586         desc(algorithm aalgorithm,
1587                 const memory::desc &diff_src_desc,
1588                 const memory::desc &weights_desc,
1589                 const memory::desc &diff_dst_desc,
1590                 const memory::dims strides,
1591                 const memory::dims padding_l,
1592                 const memory::dims padding_r,
1593                 const padding_kind apadding_kind) {
1594             memory::validate_dims(strides);
1595             memory::validate_dims(padding_l);
1596             memory::validate_dims(padding_r);
1597             error::wrap_c_api(mkldnn_convolution_backward_data_desc_init(
1598                         &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1599                         &weights_desc.data, &diff_dst_desc.data,
1600                         &strides[0], &padding_l[0], &padding_r[0],
1601                         mkldnn::convert_to_c(apadding_kind)),
1602                     "could not create a convolution backward data descriptor");
1603         }
1604         desc(algorithm aalgorithm,
1605                 const memory::desc &diff_src_desc,
1606                 const memory::desc &weights_desc,
1607                 const memory::desc &diff_dst_desc,
1608                 const memory::dims strides,
1609                 const memory::dims dilates,
1610                 const memory::dims padding_l,
1611                 const memory::dims padding_r,
1612                 const padding_kind apadding_kind) {
1613             memory::validate_dims(strides);
1614             memory::validate_dims(dilates);
1615             memory::validate_dims(padding_l);
1616             memory::validate_dims(padding_r);
1617             error::wrap_c_api(
1618                 mkldnn_dilated_convolution_backward_data_desc_init(
1619                     &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1620                     &weights_desc.data, &diff_dst_desc.data,
1621                     &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1622                     mkldnn::convert_to_c(apadding_kind)),
1623                     "could not create a convolution backward data descriptor");
1624         }
1625     };
1626
1627     struct primitive_desc : public mkldnn::primitive_desc {
1628         primitive_desc(const desc &desc, const engine &e,
1629                 const convolution_forward::primitive_desc &hint_fwd_pd)
1630             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1631
1632         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1633                 const convolution_forward::primitive_desc &hint_fwd_pd)
1634             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1635
1636         REG_QUERY_MPD(diff_src, diff_src, 0);
1637         REG_QUERY_MPD(weights, weights, 0);
1638         REG_QUERY_MPD(diff_dst, diff_dst, 0);
1639     };
1640
1641     convolution_backward_data(const primitive_desc &aprimitive_desc,
1642             const primitive::at &diff_dst, const primitive::at &weights,
1643             const memory &diff_src) {
1644         mkldnn_primitive_t result;
1645         mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data  };
1646         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1647         check_num_parameters(aprimitive_desc.get(), 2, 1,
1648             "convolution backward data");
1649         error::wrap_c_api(mkldnn_primitive_create(&result,
1650                     aprimitive_desc.get(), inputs, outputs),
1651                 "could not create a convolution backward data primitive");
1652         reset(result);
1653     }
1654 };
1655
1656 struct convolution_backward_weights : public primitive {
1657     struct desc {
1658         mkldnn_convolution_desc_t data;
1659         desc(algorithm aalgorithm,
1660                 const memory::desc &src_desc,
1661                 const memory::desc &diff_weights_desc,
1662                 const memory::desc &diff_bias_desc,
1663                 const memory::desc &diff_dst_desc,
1664                 const memory::dims strides,
1665                 const memory::dims padding_l,
1666                 const memory::dims padding_r,
1667                 const padding_kind apadding_kind) {
1668             memory::validate_dims(strides);
1669             memory::validate_dims(padding_l);
1670             memory::validate_dims(padding_r);
1671             error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
1672                         &data, convert_to_c(aalgorithm), &src_desc.data,
1673                         &diff_weights_desc.data, &diff_bias_desc.data,
1674                         &diff_dst_desc.data,
1675                         &strides[0], &padding_l[0], &padding_r[0],
1676                         mkldnn::convert_to_c(apadding_kind)),
1677                     "could not create a convolution backward weights descriptor");
1678         }
1679         desc(algorithm aalgorithm,
1680                 const memory::desc &src_desc,
1681                 const memory::desc &diff_weights_desc,
1682                 const memory::desc &diff_dst_desc,
1683                 const memory::dims strides,
1684                 const memory::dims padding_l,
1685                 const memory::dims padding_r,
1686                 const padding_kind apadding_kind) {
1687             memory::validate_dims(strides);
1688             memory::validate_dims(padding_l);
1689             memory::validate_dims(padding_r);
1690             error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
1691                         &data, convert_to_c(aalgorithm), &src_desc.data,
1692                         &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1693                         &strides[0], &padding_l[0], &padding_r[0],
1694                         mkldnn::convert_to_c(apadding_kind)),
1695                     "could not create a convolution backward weights descriptor");
1696         }
1697         desc(algorithm aalgorithm,
1698                 const memory::desc &src_desc,
1699                 const memory::desc &diff_weights_desc,
1700                 const memory::desc &diff_bias_desc,
1701                 const memory::desc &diff_dst_desc,
1702                 const memory::dims strides,
1703                 const memory::dims dilates,
1704                 const memory::dims padding_l,
1705                 const memory::dims padding_r,
1706                 const padding_kind apadding_kind) {
1707             memory::validate_dims(strides);
1708             memory::validate_dims(dilates);
1709             memory::validate_dims(padding_l);
1710             memory::validate_dims(padding_r);
1711             error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
1712                         &data, convert_to_c(aalgorithm), &src_desc.data,
1713                         &diff_weights_desc.data, &diff_bias_desc.data,
1714                         &diff_dst_desc.data,
1715                         &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1716                         mkldnn::convert_to_c(apadding_kind)),
1717                     "could not create a convolution backward weights descriptor");
1718         }
1719         desc(algorithm aalgorithm,
1720                 const memory::desc &src_desc,
1721                 const memory::desc &diff_weights_desc,
1722                 const memory::desc &diff_dst_desc,
1723                 const memory::dims strides,
1724                 const memory::dims dilates,
1725                 const memory::dims padding_l,
1726                 const memory::dims padding_r,
1727                 const padding_kind apadding_kind) {
1728             memory::validate_dims(strides);
1729             memory::validate_dims(dilates);
1730             memory::validate_dims(padding_l);
1731             memory::validate_dims(padding_r);
1732             error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
1733                         &data, convert_to_c(aalgorithm), &src_desc.data,
1734                         &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1735                         &strides[0], &dilates[0],  &padding_l[0], &padding_r[0],
1736                         mkldnn::convert_to_c(apadding_kind)),
1737                     "could not create a convolution backward weights descriptor");
1738         }
1739
1740     };
1741
1742     struct primitive_desc : public mkldnn::primitive_desc {
1743         primitive_desc(const desc &desc, const engine &e,
1744                 const convolution_forward::primitive_desc &hint_fwd_pd)
1745             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1746
1747         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1748                 const convolution_forward::primitive_desc &hint_fwd_pd)
1749             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1750
1751         REG_QUERY_MPD(src, src, 0);
1752         REG_QUERY_MPD(diff_weights, diff_weights, 0);
1753         REG_QUERY_MPD(diff_bias, diff_weights, 1);
1754         REG_QUERY_MPD(diff_dst, diff_dst, 0);
1755     };
1756
1757     convolution_backward_weights(const primitive_desc &aprimitive_desc,
1758             const primitive::at &src, const primitive::at &diff_dst,
1759             const memory &diff_weights, const memory &diff_bias) {
1760         mkldnn_primitive_t result;
1761         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1762         const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1763                     diff_bias.get() };
1764         check_num_parameters(aprimitive_desc.get(), 2, 2,
1765             "convolution backward weights");
1766         error::wrap_c_api(mkldnn_primitive_create(&result,
1767                     aprimitive_desc.get(), inputs, outputs),
1768                 "could not create a convolution backward weights primitive");
1769         reset(result);
1770     }
1771     convolution_backward_weights(const primitive_desc &aprimitive_desc,
1772             const primitive::at &src, const primitive::at &diff_dst,
1773             const memory &diff_weights) {
1774         mkldnn_primitive_t result;
1775         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1776         const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1777         check_num_parameters(aprimitive_desc.get(), 2, 1,
1778             "convolution backward weights");
1779         error::wrap_c_api(mkldnn_primitive_create(&result,
1780                     aprimitive_desc.get(), inputs, outputs),
1781                 "could not create a convolution backward weights primitive");
1782         reset(result);
1783     }
1784 };
1785
1786 /// @}
1787
1788 /// @addtogroup cpp_api_deconvolution Deconvolution
1789 /// A primitive to compute deconvolution using different algorithms.
1790 ///
1791 /// @sa @ref c_api_deconvolution in @ref c_api
1792 /// @{
1793
1794 struct deconvolution_forward: public primitive {
1795     struct desc {
1796         mkldnn_deconvolution_desc_t data;
1797         desc(prop_kind aprop_kind, algorithm aalgorithm,
1798                 const memory::desc &src_desc,
1799                 const memory::desc &weights_desc,
1800                 const memory::desc &bias_desc,
1801                 const memory::desc &dst_desc,
1802                 const memory::dims strides,
1803                 const memory::dims padding_l,
1804                 const memory::dims padding_r,
1805                 const padding_kind apadding_kind) {
1806             memory::validate_dims(strides);
1807             memory::validate_dims(padding_l);
1808             memory::validate_dims(padding_r);
1809             error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
1810                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1811                         &src_desc.data, &weights_desc.data, &bias_desc.data,
1812                         &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1813                         mkldnn::convert_to_c(apadding_kind)),
1814                     "could not create a deconvolution forward descriptor");
1815         }
1816         desc(prop_kind aprop_kind, algorithm aalgorithm,
1817                 const memory::desc &src_desc,
1818                 const memory::desc &weights_desc,
1819                 const memory::desc &dst_desc,
1820                 const memory::dims strides,
1821                 const memory::dims padding_l,
1822                 const memory::dims padding_r,
1823                 const padding_kind apadding_kind) {
1824             memory::validate_dims(strides);
1825             memory::validate_dims(padding_l);
1826             memory::validate_dims(padding_r);
1827             error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
1828                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1829                         &src_desc.data, &weights_desc.data, nullptr,
1830                         &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1831                         mkldnn::convert_to_c(apadding_kind)),
1832                     "could not create a deconvolution forward descriptor");
1833         }
1834         desc(prop_kind aprop_kind, algorithm aalgorithm,
1835                 const memory::desc &src_desc,
1836                 const memory::desc &weights_desc,
1837                 const memory::desc &bias_desc,
1838                 const memory::desc &dst_desc,
1839                 const memory::dims strides,
1840                 const memory::dims dilates,
1841                 const memory::dims padding_l,
1842                 const memory::dims padding_r,
1843                 const padding_kind apadding_kind) {
1844             memory::validate_dims(strides);
1845             memory::validate_dims(dilates);
1846             memory::validate_dims(padding_l);
1847             memory::validate_dims(padding_r);
1848             error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
1849                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1850                         &src_desc.data, &weights_desc.data, &bias_desc.data,
1851                         &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1852                         &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1853                     "could not create a dilated deconvolution forward descriptor");
1854         }
1855         desc(prop_kind aprop_kind, algorithm aalgorithm,
1856                 const memory::desc &src_desc,
1857                 const memory::desc &weights_desc,
1858                 const memory::desc &dst_desc,
1859                 const memory::dims strides,
1860                 const memory::dims dilates,
1861                 const memory::dims padding_l,
1862                 const memory::dims padding_r,
1863                 const padding_kind apadding_kind) {
1864             memory::validate_dims(strides);
1865             memory::validate_dims(dilates);
1866             memory::validate_dims(padding_l);
1867             memory::validate_dims(padding_r);
1868             error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
1869                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1870                         &src_desc.data, &weights_desc.data, nullptr,
1871                         &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1872                         &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1873                     "could not create a dilated deconvolution forward descriptor");
1874         }
1875     };
1876
1877     struct primitive_desc : public mkldnn::primitive_desc {
1878         primitive_desc(const desc &desc, const engine &e)
1879             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1880
1881         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1882             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1883
1884         REG_QUERY_MPD(src, src, 0);
1885         REG_QUERY_MPD(weights, weights, 0);
1886         REG_QUERY_MPD(bias, weights, 1);
1887         REG_QUERY_MPD(dst, dst, 0);
1888     };
1889
1890     deconvolution_forward(const primitive_desc &aprimitive_desc,
1891             const primitive::at &src, const primitive::at &weights,
1892             const primitive::at &bias, const memory &dst) {
1893         mkldnn_primitive_t result;
1894         mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1895                     bias.data };
1896         const_mkldnn_primitive_t outputs[] = { dst.get() };
1897         check_num_parameters(aprimitive_desc.get(), 3, 1,
1898             "deconvolution forward");
1899         error::wrap_c_api(mkldnn_primitive_create(&result,
1900                     aprimitive_desc.get(), inputs, outputs),
1901                 "could not create a deconvolution forward bias primitive");
1902         reset(result);
1903     }
1904
1905     deconvolution_forward(const primitive_desc &aprimitive_desc,
1906             const primitive::at &src, const primitive::at &weights,
1907             const memory &dst) {
1908         mkldnn_primitive_t result;
1909         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1910         const_mkldnn_primitive_t outputs[] = { dst.get() };
1911         check_num_parameters(aprimitive_desc.get(), 2, 1,
1912             "deconvolution forward");
1913         error::wrap_c_api(mkldnn_primitive_create(&result,
1914                     aprimitive_desc.get(), inputs, outputs),
1915                 "could not create a deconvolution forward primitive");
1916         reset(result);
1917     }
1918 };
1919
1920 struct deconvolution_backward_data : public primitive {
1921     struct desc {
1922         mkldnn_deconvolution_desc_t data;
1923         desc(algorithm aalgorithm,
1924                 const memory::desc &diff_src_desc,
1925                 const memory::desc &weights_desc,
1926                 const memory::desc &diff_dst_desc,
1927                 const memory::dims strides,
1928                 const memory::dims padding_l,
1929                 const memory::dims padding_r,
1930                 const padding_kind apadding_kind) {
1931             memory::validate_dims(strides);
1932             memory::validate_dims(padding_l);
1933             memory::validate_dims(padding_r);
1934             error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init(
1935                         &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1936                         &weights_desc.data, &diff_dst_desc.data,
1937                         &strides[0], &padding_l[0], &padding_r[0],
1938                         mkldnn::convert_to_c(apadding_kind)),
1939                     "could not create a deconvolution backward data descriptor");
1940         }
1941         desc(algorithm aalgorithm,
1942                 const memory::desc &diff_src_desc,
1943                 const memory::desc &weights_desc,
1944                 const memory::desc &diff_dst_desc,
1945                 const memory::dims strides,
1946                 const memory::dims dilates,
1947                 const memory::dims padding_l,
1948                 const memory::dims padding_r,
1949                 const padding_kind apadding_kind) {
1950             memory::validate_dims(strides);
1951             memory::validate_dims(dilates);
1952             memory::validate_dims(padding_l);
1953             memory::validate_dims(padding_r);
1954             error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init(
1955                         &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1956                         &weights_desc.data, &diff_dst_desc.data,
1957                         &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1958                         mkldnn::convert_to_c(apadding_kind)),
1959                     "could not create a dilated deconvolution backward data descriptor");
1960         }
1961     };
1962
1963     struct primitive_desc : public mkldnn::primitive_desc {
1964         primitive_desc(const desc &desc, const engine &e,
1965                 const deconvolution_forward::primitive_desc &hint_fwd_pd)
1966             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1967
1968         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1969                 const deconvolution_forward::primitive_desc &hint_fwd_pd)
1970             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1971
1972         REG_QUERY_MPD(diff_src, diff_src, 0);
1973         REG_QUERY_MPD(weights, weights, 0);
1974         REG_QUERY_MPD(diff_dst, diff_dst, 0);
1975     };
1976
1977     deconvolution_backward_data(const primitive_desc &aprimitive_desc,
1978             const primitive::at &diff_dst, const primitive::at &weights,
1979             const memory &diff_src) {
1980         mkldnn_primitive_t result;
1981         mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data  };
1982         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1983         check_num_parameters(aprimitive_desc.get(), 2, 1,
1984             "deconvolution backward data");
1985         error::wrap_c_api(mkldnn_primitive_create(&result,
1986                     aprimitive_desc.get(), inputs, outputs),
1987                 "could not create a deconvolution backward data primitive");
1988         reset(result);
1989     }
1990 };
1991
1992 struct deconvolution_backward_weights : public primitive {
1993     struct desc {
1994         mkldnn_deconvolution_desc_t data;
1995         desc(algorithm aalgorithm,
1996                 const memory::desc &src_desc,
1997                 const memory::desc &diff_weights_desc,
1998                 const memory::desc &diff_bias_desc,
1999                 const memory::desc &diff_dst_desc,
2000                 const memory::dims strides,
2001                 const memory::dims padding_l,
2002                 const memory::dims padding_r,
2003                 const padding_kind apadding_kind) {
2004             memory::validate_dims(strides);
2005             memory::validate_dims(padding_l);
2006             memory::validate_dims(padding_r);
2007             error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
2008                         &data, convert_to_c(aalgorithm), &src_desc.data,
2009                         &diff_weights_desc.data, &diff_bias_desc.data,
2010                         &diff_dst_desc.data,
2011                         &strides[0], &padding_l[0], &padding_r[0],
2012                         mkldnn::convert_to_c(apadding_kind)),
2013                     "could not create a deconvolution backward weights descriptor");
2014         }
2015         desc(algorithm aalgorithm,
2016                 const memory::desc &src_desc,
2017                 const memory::desc &diff_weights_desc,
2018                 const memory::desc &diff_dst_desc,
2019                 const memory::dims strides,
2020                 const memory::dims padding_l,
2021                 const memory::dims padding_r,
2022                 const padding_kind apadding_kind) {
2023             memory::validate_dims(strides);
2024             memory::validate_dims(padding_l);
2025             memory::validate_dims(padding_r);
2026             error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
2027                         &data, convert_to_c(aalgorithm), &src_desc.data,
2028                         &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
2029                         &strides[0], &padding_l[0], &padding_r[0],
2030                         mkldnn::convert_to_c(apadding_kind)),
2031                     "could not create a deconvolution backward weights descriptor");
2032         }
2033         desc(algorithm aalgorithm,
2034                 const memory::desc &src_desc,
2035                 const memory::desc &diff_weights_desc,
2036                 const memory::desc &diff_bias_desc,
2037                 const memory::desc &diff_dst_desc,
2038                 const memory::dims strides,
2039                 const memory::dims dilates,
2040                 const memory::dims padding_l,
2041                 const memory::dims padding_r,
2042                 const padding_kind apadding_kind) {
2043             memory::validate_dims(strides);
2044             memory::validate_dims(dilates);
2045             memory::validate_dims(padding_l);
2046             memory::validate_dims(padding_r);
2047             error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
2048                         &data, convert_to_c(aalgorithm), &src_desc.data,
2049                         &diff_weights_desc.data, &diff_bias_desc.data,
2050                         &diff_dst_desc.data,
2051                         &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2052                         mkldnn::convert_to_c(apadding_kind)),
2053                     "could not create a dilated  deconvolution backward weights descriptor");
2054         }
2055         desc(algorithm aalgorithm,
2056                 const memory::desc &src_desc,
2057                 const memory::desc &diff_weights_desc,
2058                 const memory::desc &diff_dst_desc,
2059                 const memory::dims strides,
2060                 const memory::dims dilates,
2061                 const memory::dims padding_l,
2062                 const memory::dims padding_r,
2063                 const padding_kind apadding_kind) {
2064             memory::validate_dims(strides);
2065             memory::validate_dims(dilates);
2066             memory::validate_dims(padding_l);
2067             memory::validate_dims(padding_r);
2068             error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
2069                         &data, convert_to_c(aalgorithm), &src_desc.data,
2070                         &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
2071                         &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2072                         mkldnn::convert_to_c(apadding_kind)),
2073                     "could not create a dilated deconvolution backward weights descriptor");
2074         }
2075     };
2076
2077     struct primitive_desc : public mkldnn::primitive_desc {
2078         primitive_desc(const desc &desc, const engine &e,
2079                 const deconvolution_forward::primitive_desc &hint_fwd_pd)
2080             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2081
2082         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2083                 const deconvolution_forward::primitive_desc &hint_fwd_pd)
2084             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2085
2086         REG_QUERY_MPD(src, src, 0);
2087         REG_QUERY_MPD(diff_weights, diff_weights, 0);
2088         REG_QUERY_MPD(diff_bias, diff_weights, 1);
2089         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2090     };
2091
2092     deconvolution_backward_weights(const primitive_desc &aprimitive_desc,
2093             const primitive::at &src, const primitive::at &diff_dst,
2094             const memory &diff_weights, const memory &diff_bias) {
2095         mkldnn_primitive_t result;
2096         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2097         const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
2098                     diff_bias.get() };
2099         check_num_parameters(aprimitive_desc.get(), 2, 2,
2100             "deconvolution backward weights");
2101         error::wrap_c_api(mkldnn_primitive_create(&result,
2102                     aprimitive_desc.get(), inputs, outputs),
2103                 "could not create a deconvolution backward weights primitive");
2104         reset(result);
2105     }
2106     deconvolution_backward_weights(const primitive_desc &aprimitive_desc,
2107             const primitive::at &src, const primitive::at &diff_dst,
2108             const memory &diff_weights) {
2109         mkldnn_primitive_t result;
2110         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2111         const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2112         check_num_parameters(aprimitive_desc.get(), 2, 1,
2113             "deconvolution backward weights");
2114         error::wrap_c_api(mkldnn_primitive_create(&result,
2115                     aprimitive_desc.get(), inputs, outputs),
2116                 "could not create a deconvolution backward weights primitive");
2117         reset(result);
2118     }
2119 };
2120
2121 /// @}
2122
2123 /// @addtogroup cpp_api_roi_pooling ROIPooling
2124 /// @{
2125
2126 struct roi_pooling_forward : public primitive {
2127     struct desc {
2128         mkldnn_roi_pooling_desc_t data;
2129         std::vector<mkldnn_memory_desc_t> c_api_inputs;
2130
2131         desc(prop_kind aprop_kind, algorithm aalgorithm, std::vector<memory::desc> inputs,
2132              const memory::desc &dst_desc, int pooled_h, int pooled_w, double spatial_scale) {
2133
2134             for(size_t i = 0; i < inputs.size(); i++) {
2135                 c_api_inputs.push_back(inputs[i].data);
2136             }
2137
2138             error::wrap_c_api(mkldnn_roi_pooling_forward_desc_init(&data,
2139                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &c_api_inputs[0],
2140                         c_api_inputs.size(),
2141                         &dst_desc.data, pooled_h, pooled_w, spatial_scale),
2142                     "could not create a roi pooling forward descriptor");
2143         }
2144     };
2145
2146     struct primitive_desc : public handle<mkldnn_primitive_desc_t>{
2147         primitive_desc(const desc &adesc, const engine &aengine) {
2148             mkldnn_primitive_desc_t result;
2149             error::wrap_c_api(mkldnn_primitive_desc_create(
2150                         &result, &adesc.data, aengine.get(), nullptr),
2151                     "could not create a roi pooling forward primitive descriptor");
2152             reset(result);
2153         }
2154     };
2155
2156     roi_pooling_forward(const primitive_desc &aprimitive_desc,
2157             std::vector<primitive::at> &inputs, const memory &dst) {
2158         mkldnn_primitive_t result;
2159
2160         std::vector<mkldnn_primitive_at_t> p_inputs;
2161         for (size_t i = 0; i < inputs.size(); i++) {
2162             p_inputs.push_back(inputs[i].data);
2163         }
2164
2165         const_mkldnn_primitive_t outputs[] = { dst.get() };
2166         error::wrap_c_api(mkldnn_primitive_create(&result,
2167                     aprimitive_desc.get(), &p_inputs[0], outputs),
2168                 "could not create a roi pooling forward primitive");
2169         reset(result);
2170     }
2171 };
2172
2173 /// @}
2174
2175 /// @addtogroup cpp_api_lrn LRN
2176 /// A primitive to perform local response normalization (LRN) across or within
2177 /// channels.
2178 ///
2179 /// @sa @ref c_api_lrn in @ref c_api
2180 /// @{
2181
2182 struct lrn_forward : public primitive {
2183     struct desc {
2184         mkldnn_lrn_desc_t data;
2185         desc(prop_kind aprop_kind, algorithm aalgorithm,
2186             const memory::desc &src_desc,
2187             int local_size, float alpha, float beta, float k)
2188         {
2189             error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
2190                 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2191                 &src_desc.data, local_size, alpha, beta, k),
2192                 "could not create a lrn forward descriptor");
2193         }
2194         desc(prop_kind aprop_kind, algorithm aalgorithm,
2195             const memory::desc &src_desc,
2196             int local_size, float alpha, float beta)
2197         {
2198             error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
2199                 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2200                 &src_desc.data, local_size, alpha, beta, float(1.0)),
2201                 "could not create a lrn forward descriptor");
2202         }
2203     };
2204
2205     struct primitive_desc : public mkldnn::primitive_desc {
2206         primitive_desc(const desc &desc, const engine &e)
2207             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2208
2209         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2210             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2211
2212         REG_QUERY_MPD(src, src, 0);
2213         REG_QUERY_MPD(dst, dst, 0);
2214         REG_QUERY_MPD(workspace, workspace, 0);
2215     };
2216
2217     lrn_forward(const primitive_desc &aprimitive_desc,
2218             const primitive::at &src, const memory &workspace,
2219             const memory &dst) {
2220         mkldnn_primitive_t result;
2221         mkldnn_primitive_at_t inputs[] = { src.data };
2222         const_mkldnn_primitive_t outputs[] = { dst.get(),
2223                 workspace.get() };
2224         check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
2225         error::wrap_c_api(mkldnn_primitive_create(&result,
2226                 aprimitive_desc.get(), inputs, outputs),
2227             "could not create a lrn forward primitive");
2228         reset(result);
2229     }
2230
2231     lrn_forward(const primitive_desc &aprimitive_desc,
2232             const primitive::at &src, const memory &dst) {
2233         mkldnn_primitive_t result;
2234         mkldnn_primitive_at_t inputs[] = { src.data };
2235         const_mkldnn_primitive_t outputs[] = { dst.get() };
2236         check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
2237         error::wrap_c_api(mkldnn_primitive_create(&result,
2238                 aprimitive_desc.get(), inputs, outputs),
2239             "could not create a lrn forward primitive");
2240         reset(result);
2241     }
2242 };
2243
2244 struct lrn_backward : public primitive {
2245     struct desc {
2246         mkldnn_lrn_desc_t data;
2247         desc(algorithm aalgorithm,
2248             const memory::desc &data_desc,
2249             const memory::desc &diff_data_desc,
2250             int local_size, float alpha, float beta, float k)
2251         {
2252             error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
2253                 convert_to_c(aalgorithm), &diff_data_desc.data,
2254                 &data_desc.data, local_size, alpha, beta, k),
2255                 "could not create a lrn backward descriptor");
2256         }
2257         desc(algorithm aalgorithm,
2258             const memory::desc &data_desc,
2259             const memory::desc &diff_data_desc,
2260             int local_size, float alpha, float beta)
2261         {
2262             error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
2263                 convert_to_c(aalgorithm), &diff_data_desc.data,
2264                 &data_desc.data, local_size, alpha, beta, float(1.0)),
2265                 "could not create a lrn backward descriptor");
2266         }
2267     };
2268
2269     struct primitive_desc : public mkldnn::primitive_desc {
2270         primitive_desc(const desc &desc, const engine &e,
2271                 const lrn_forward::primitive_desc &hint_fwd_pd)
2272             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2273
2274         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2275                 const lrn_forward::primitive_desc &hint_fwd_pd)
2276             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2277
2278         REG_QUERY_MPD(diff_src, diff_src, 0);
2279         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2280         REG_QUERY_MPD(workspace, workspace, 0);
2281     };
2282
2283     lrn_backward(const primitive_desc &aprimitive_desc,
2284             const primitive::at &src, const primitive::at &diff_dst,
2285             const primitive::at &workspace, const memory &diff_src) {
2286         mkldnn_primitive_t result;
2287         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
2288                 workspace.data };
2289         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2290         check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
2291         error::wrap_c_api(mkldnn_primitive_create(&result,
2292                 aprimitive_desc.get(), inputs, outputs),
2293             "could not create a lrn backward primitive");
2294         reset(result);
2295     }
2296
2297     lrn_backward(const primitive_desc &aprimitive_desc,
2298             const primitive::at &src, const primitive::at &diff_dst,
2299             const memory &diff_src) {
2300         mkldnn_primitive_t result;
2301         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2302         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2303         check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
2304         error::wrap_c_api(mkldnn_primitive_create(&result,
2305                 aprimitive_desc.get(), inputs, outputs),
2306             "could not create a lrn backward primitive");
2307         reset(result);
2308     }
2309 };
2310
2311 /// @}
2312
2313 /// @addtogroup cpp_api_pooling Pooling
2314 /// A primitive to perform max or average pooling.
2315 ///
2316 /// @sa @ref c_api_pooling in @ref c_api
2317 /// @{
2318
2319 struct pooling_forward : public primitive {
2320     struct desc {
2321         mkldnn_pooling_desc_t data;
2322         desc(prop_kind aprop_kind, algorithm aalgorithm,
2323                 const memory::desc &src_desc,
2324                 const memory::desc &dst_desc,
2325                 const memory::dims strides,
2326                 const memory::dims kernel,
2327                 const memory::dims padding_l,
2328                 const memory::dims padding_r,
2329                 const padding_kind apadding_kind) {
2330             memory::validate_dims(strides);
2331             memory::validate_dims(kernel);
2332             memory::validate_dims(padding_l);
2333             memory::validate_dims(padding_r);
2334             error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data,
2335                         mkldnn::convert_to_c(aprop_kind),
2336                         convert_to_c(aalgorithm),
2337                         &src_desc.data, &dst_desc.data,
2338                         &strides[0], &kernel[0],
2339                         &padding_l[0], &padding_r[0],
2340                         mkldnn::convert_to_c(apadding_kind)),
2341                     "could not init a forward pooling descriptor");
2342         }
2343     };
2344
2345     struct primitive_desc : public mkldnn::primitive_desc {
2346         primitive_desc(const desc &desc, const engine &e)
2347             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2348
2349         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2350             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2351
2352         REG_QUERY_MPD(src, src, 0);
2353         REG_QUERY_MPD(dst, dst, 0);
2354         REG_QUERY_MPD(workspace, workspace, 0);
2355     };
2356
2357     pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2358             const memory &dst) {
2359         mkldnn_primitive_t result;
2360         mkldnn_primitive_at_t inputs[] = { src.data };
2361         const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
2362         check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
2363         error::wrap_c_api(mkldnn_primitive_create(&result,
2364                     aprimitive_desc.get(), inputs, outputs),
2365                 "could not create a pooling forward primitive");
2366         reset(result);
2367     }
2368
2369     pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2370             const memory &dst, const memory &workspace) {
2371         mkldnn_primitive_t result;
2372         mkldnn_primitive_at_t inputs[] = { src.data };
2373         const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
2374         check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
2375         error::wrap_c_api(mkldnn_primitive_create(&result,
2376                     aprimitive_desc.get(), inputs, outputs),
2377                 "could not create a pooling forward primitive");
2378         reset(result);
2379     }
2380 };
2381
2382 struct pooling_backward : public primitive {
2383     struct desc {
2384         mkldnn_pooling_desc_t data;
2385         desc(algorithm aalgorithm,
2386                 const memory::desc &diff_src_desc,
2387                 const memory::desc &diff_dst_desc,
2388                 const memory::dims &strides,
2389                 const memory::dims &kernel,
2390                 const memory::dims &padding_l,
2391                 const memory::dims &padding_r,
2392                 const padding_kind apadding_kind) {
2393             memory::validate_dims(strides);
2394             memory::validate_dims(kernel);
2395             memory::validate_dims(padding_l);
2396             memory::validate_dims(padding_r);
2397             error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data,
2398                         convert_to_c(aalgorithm),
2399                         &diff_src_desc.data, &diff_dst_desc.data,
2400                         &strides[0], &kernel[0],
2401                         &padding_l[0], &padding_r[0],
2402                         mkldnn::convert_to_c(apadding_kind)),
2403                     "could not init a backward pooling descriptor");
2404         }
2405     };
2406
2407     struct primitive_desc : public mkldnn::primitive_desc {
2408         primitive_desc(const desc &desc, const engine &e,
2409                 const pooling_forward::primitive_desc &hint_fwd_pd)
2410             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2411
2412         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2413                 const pooling_forward::primitive_desc &hint_fwd_pd)
2414             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2415
2416         REG_QUERY_MPD(diff_src, diff_src, 0);
2417         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2418         REG_QUERY_MPD(workspace, workspace, 0);
2419     };
2420
2421     pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2422             const memory &diff_src) {
2423         mkldnn_primitive_t result;
2424         mkldnn_primitive_at_t inputs[] = { diff_dst.data };
2425         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2426         check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
2427         error::wrap_c_api(mkldnn_primitive_create(&result,
2428                     aprimitive_desc.get(), inputs, outputs),
2429                 "could not create a pooling backward primitive");
2430         reset(result);
2431     }
2432
2433     pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2434             const primitive::at &workspace, const memory &diff_src) {
2435         mkldnn_primitive_t result;
2436         mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
2437         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2438         check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
2439         error::wrap_c_api(mkldnn_primitive_create(&result,
2440                     aprimitive_desc.get(), inputs, outputs),
2441                 "could not create a pooling backward primitive");
2442         reset(result);
2443     }
2444 };
2445
2446 /// @}
2447
2448 /// @addtogroup cpp_api_eltwise Eltwise
2449 /// A primitive to compute element-wise operations like parametric rectifier
2450 /// linear unit (ReLU).
2451 ///
2452 /// @sa @ref c_api_eltwise in @ref c_api
2453 /// @{
2454
2455 struct eltwise_forward : public primitive {
2456     struct desc {
2457         mkldnn_eltwise_desc_t data;
2458         template <typename T>
2459         desc(prop_kind aprop_kind, algorithm alg_kind,
2460                 const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2461             error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data,
2462                         mkldnn::convert_to_c(aprop_kind),
2463                         mkldnn::convert_to_c(alg_kind), &src_desc.data,
2464                         static_cast<float>(alpha), static_cast<float>(beta)),
2465                     "could not create a eltwise forward descriptor");
2466         }
2467     };
2468
2469     struct primitive_desc : public mkldnn::primitive_desc {
2470         primitive_desc(const desc &desc, const engine &e)
2471             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2472
2473         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2474             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2475
2476         REG_QUERY_MPD(src, src, 0);
2477         REG_QUERY_MPD(dst, dst, 0);
2478     };
2479
2480     eltwise_forward(const primitive_desc &aprimitive_desc,
2481             const primitive::at &src, const memory &dst) {
2482         mkldnn_primitive_t result;
2483         mkldnn_primitive_at_t inputs[] = { src.data };
2484         const_mkldnn_primitive_t outputs[] = { dst.get() };
2485         check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
2486         error::wrap_c_api(mkldnn_primitive_create(&result,
2487                 aprimitive_desc.get(), inputs, outputs),
2488             "could not create a eltwise forward primitive");
2489         reset(result);
2490     }
2491 };
2492
2493 struct eltwise_backward : public primitive {
2494     struct desc {
2495         mkldnn_eltwise_desc_t data;
2496
2497         template <typename T>
2498         desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2499                 const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2500             error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data,
2501                         mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2502                         &data_desc.data, static_cast<float>(alpha),
2503                         static_cast<float>(beta)),
2504                     "could not create a eltwise backward descriptor");
2505         }
2506     };
2507
2508     struct primitive_desc : public mkldnn::primitive_desc {
2509         primitive_desc(const desc &desc, const engine &e,
2510                 const eltwise_forward::primitive_desc &hint_fwd_pd)
2511             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2512
2513         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2514                 const eltwise_forward::primitive_desc &hint_fwd_pd)
2515             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2516
2517         REG_QUERY_MPD(src, src, 0);
2518         REG_QUERY_MPD(diff_src, diff_src, 0);
2519         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2520     };
2521
2522     eltwise_backward(const primitive_desc &aprimitive_desc,
2523             const primitive::at &src, const primitive::at &diff_dst,
2524             const memory &diff_src) {
2525         mkldnn_primitive_t result;
2526         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2527         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2528         check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
2529         error::wrap_c_api(mkldnn_primitive_create(&result,
2530                 aprimitive_desc.get(), inputs, outputs),
2531             "could not create a eltwise backward primitive");
2532         reset(result);
2533     }
2534 };
2535
2536 /// @}
2537
2538 /// @addtogroup cpp_api_depthwise Depthwise
2539 /// @{
2540
2541 struct depthwise_forward : public primitive {
2542     struct desc {
2543         mkldnn_depthwise_desc_t data;
2544
2545         desc(prop_kind aprop_kind, algorithm alg_kind,
2546              const memory::desc &src_desc, const memory::desc &dst_desc, const memory::desc &weights_desc,
2547              const memory::desc &bias_desc) {
2548             error::wrap_c_api(mkldnn_depthwise_forward_desc_init(&data,
2549                                                                  mkldnn::convert_to_c(aprop_kind),
2550                                                                      mkldnn::convert_to_c(alg_kind),
2551                                                                      &src_desc.data, &dst_desc.data,
2552                                                                  &weights_desc.data, &bias_desc.data),
2553                               "could not create a depthwise forward descriptor");
2554         }
2555
2556         desc(prop_kind aprop_kind, algorithm alg_kind,
2557              const memory::desc &src_desc, const memory::desc &dst_desc, const memory::desc &weights_desc) {
2558             error::wrap_c_api(mkldnn_depthwise_forward_desc_init(&data,
2559                                                                  mkldnn::convert_to_c(aprop_kind),
2560                                                                  mkldnn::convert_to_c(alg_kind),
2561                                                                  &src_desc.data, &dst_desc.data,
2562                                                                  &weights_desc.data, nullptr),
2563                               "could not create a depthwise forward descriptor");
2564         }
2565     };
2566
2567     struct primitive_desc : public mkldnn::primitive_desc {
2568         primitive_desc(const desc &desc, const engine &e)
2569             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2570
2571         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2572             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2573
2574         REG_QUERY_MPD(src, src, 0);
2575         REG_QUERY_MPD(dst, dst, 0);
2576     };
2577
2578     depthwise_forward(const primitive_desc &aprimitive_desc,
2579                       const primitive::at &src, const primitive::at &weights,
2580                       const primitive::at &bias, const memory &dst) {
2581         mkldnn_primitive_t result;
2582         mkldnn_primitive_at_t inputs[] = { src.data, weights.data, bias.data };
2583         const_mkldnn_primitive_t outputs[] = { dst.get() };
2584         error::wrap_c_api(mkldnn_primitive_create(&result,
2585                                                   aprimitive_desc.get(), inputs, outputs),
2586                           "could not create a depthwise forward primitive");
2587         reset(result);
2588     }
2589
2590     depthwise_forward(const primitive_desc &aprimitive_desc,
2591                       const primitive::at &src, const primitive::at &weights,
2592                       const memory &dst) {
2593         mkldnn_primitive_t result;
2594         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2595         const_mkldnn_primitive_t outputs[] = { dst.get() };
2596         error::wrap_c_api(mkldnn_primitive_create(&result,
2597                                                   aprimitive_desc.get(), inputs, outputs),
2598                           "could not create a depthwise forward primitive");
2599         reset(result);
2600     }
2601 };
2602
2603 /// @}
2604
2605 /// @addtogroup cpp_api_softmax Softmax
2606 /// A primitive to perform softmax.
2607 ///
2608 /// @sa @ref c_api_softmax in @ref c_api
2609 /// @{
2610
2611 struct softmax_forward : public primitive {
2612     struct desc {
2613         mkldnn_softmax_desc_t data;
2614         desc(prop_kind aprop_kind, const memory::desc &data_desc,
2615              int softmax_axis) {
2616             error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data,
2617                     mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2618                     softmax_axis),
2619                 "could not create a softmax forward descriptor");
2620         }
2621     };
2622
2623     struct primitive_desc : public mkldnn::primitive_desc {
2624         primitive_desc(const desc &desc, const engine &e)
2625             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2626
2627         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2628             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2629
2630         REG_QUERY_MPD(src, src, 0);
2631         REG_QUERY_MPD(dst, dst, 0);
2632     };
2633
2634     softmax_forward(const primitive_desc &aprimitive_desc,
2635             const primitive::at &src, const memory &dst) {
2636         mkldnn_primitive_t result;
2637         mkldnn_primitive_at_t inputs[] = { src.data };
2638         const_mkldnn_primitive_t outputs[] = { dst.get() };
2639         check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
2640         error::wrap_c_api(mkldnn_primitive_create(&result,
2641                 aprimitive_desc.get(), inputs, outputs),
2642             "could not create a softmax forward primitive");
2643         reset(result);
2644     }
2645 };
2646
2647 struct softmax_backward : public primitive {
2648     struct desc {
2649         mkldnn_softmax_desc_t data;
2650         desc(const memory::desc &diff_desc, const memory::desc &data_desc,
2651                 int softmax_axis) {
2652             error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data,
2653                         &diff_desc.data, &data_desc.data, softmax_axis),
2654                     "could not init a backward softmax descriptor");
2655         }
2656     };
2657
2658     struct primitive_desc : public mkldnn::primitive_desc {
2659         primitive_desc(const desc &desc, const engine &e,
2660                 const softmax_forward::primitive_desc &hint_fwd_pd)
2661             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2662
2663         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2664                 const softmax_forward::primitive_desc &hint_fwd_pd)
2665             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2666
2667         REG_QUERY_MPD(dst, dst, 0);
2668         REG_QUERY_MPD(diff_src, diff_src, 0);
2669         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2670         REG_QUERY_MPD(workspace, workspace, 0);
2671     };
2672
2673     softmax_backward(const primitive_desc &aprimitive_desc,
2674             const primitive::at &dst, const primitive::at &diff_dst,
2675             const memory &diff_src) {
2676         mkldnn_primitive_t result;
2677         mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
2678         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2679         error::wrap_c_api(mkldnn_primitive_create(&result,
2680                     aprimitive_desc.get(), inputs, outputs),
2681                 "could not create a softmax backward primitive");
2682         reset(result);
2683     }
2684 };
2685
2686 /// @}
2687
2688 /// @addtogroup cpp_api_batch_norm Batch normalization
2689 /// A primitive to perform batch normalization.
2690 ///
2691 /// @sa @ref c_api_batch_normalization in @ref c_api
2692 /// @{
2693
2694 struct batch_normalization_forward : public primitive {
2695     struct desc {
2696         mkldnn_batch_normalization_desc_t data;
2697         template <typename T>
2698         desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2699                 unsigned flags) {
2700             error::wrap_c_api(
2701                     mkldnn_batch_normalization_forward_desc_init(&data,
2702                         mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2703                         static_cast<float>(epsilon), flags),
2704                 "could not create a batch normalization forward descriptor");
2705         }
2706     };
2707
2708     struct primitive_desc : public mkldnn::primitive_desc {
2709         primitive_desc(const desc &desc, const engine &e)
2710             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2711
2712         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2713             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2714
2715         REG_QUERY_MPD(src, src, 0);
2716         REG_QUERY_MPD(weights, weights, 0);
2717         REG_QUERY_MPD(dst, dst, 0);
2718         REG_QUERY_MPD(workspace, workspace, 0);
2719
2720         memory::primitive_desc mean_primitive_desc() const
2721         { return stat_primitive_desc(mean); }
2722         memory::primitive_desc variance_primitive_desc() const
2723         { return stat_primitive_desc(var); }
2724
2725     private:
2726         enum { mean = 1, var = 2, };
2727         memory::primitive_desc stat_primitive_desc(int kind) const {
2728             mkldnn_batch_normalization_desc_t *p;
2729             error::wrap_c_api(mkldnn_primitive_desc_query(
2730                     get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
2731                     "could not get a batch-normalization descriptor");
2732             return query_mpd(p->flags & use_global_stats ? src_pd : dst_pd, kind);
2733         }
2734     };
2735
2736     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2737             const primitive::at &src, const primitive::at &mean,
2738             const primitive::at &variance, const primitive::at &weights,
2739             const memory &dst) {
2740         mkldnn_primitive_t result;
2741         mkldnn_primitive_at_t inputs[] = { src.data,
2742             mean.data, variance.data, weights.data };
2743         const_mkldnn_primitive_t outputs[] = { dst.get() };
2744         check_num_parameters(aprimitive_desc.get(), 4, 1,
2745             "batch normalization forward");
2746         error::wrap_c_api(mkldnn_primitive_create(&result,
2747                 aprimitive_desc.get(), inputs, outputs),
2748             "could not create a batch normalization forward primitive");
2749         reset(result);
2750     }
2751
2752     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2753             const primitive::at &src, const primitive::at &mean,
2754             const primitive::at &variance, const memory &dst) {
2755         mkldnn_primitive_t result;
2756         mkldnn_primitive_at_t inputs[] = { src.data,
2757             mean.data, variance.data };
2758         const_mkldnn_primitive_t outputs[] = { dst.get() };
2759         check_num_parameters(aprimitive_desc.get(), 3, 1,
2760             "batch normalization forward");
2761         error::wrap_c_api(mkldnn_primitive_create(&result,
2762                 aprimitive_desc.get(), inputs, outputs),
2763             "could not create a batch normalization forward primitive");
2764         reset(result);
2765     }
2766
2767     /// @warning batch_normalization_forward has two constructors with very
2768     ///          similar signatures:
2769     ///           - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
2770     ///           - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
2771     ///          The only way to distinguish between them is to explicitly
2772     ///          cast all input parameters to their type; that is, to
2773     ///          const primitive:at &.
2774     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2775             const primitive::at &src, const primitive::at &weights,
2776             const memory &dst, const memory &mean, const memory &variance) {
2777         mkldnn_primitive_t result;
2778         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2779         const_mkldnn_primitive_t outputs[] = { dst.get(),
2780             mean.get(), variance.get() };
2781         check_num_parameters(aprimitive_desc.get(), 2, 3,
2782             "batch normalization forward");
2783         error::wrap_c_api(mkldnn_primitive_create(&result,
2784                 aprimitive_desc.get(), inputs, outputs),
2785             "could not create a batch normalization forward primitive");
2786         reset(result);
2787     }
2788
2789     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2790             const primitive::at &src, const primitive::at &weights,
2791             const memory &dst, const memory &mean, const memory &variance,
2792             const memory &workspace) {
2793         mkldnn_primitive_t result;
2794         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2795         const_mkldnn_primitive_t outputs[] = { dst.get(),
2796             mean.get(), variance.get(), workspace.get() };
2797         check_num_parameters(aprimitive_desc.get(), 2, 4,
2798             "batch normalization forward");
2799         error::wrap_c_api(mkldnn_primitive_create(&result,
2800                 aprimitive_desc.get(), inputs, outputs),
2801             "could not create a batch normalization forward primitive");
2802         reset(result);
2803     }
2804
2805     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2806             const primitive::at &src, const memory &dst, const memory &mean,
2807             const memory &variance) {
2808         mkldnn_primitive_t result;
2809         mkldnn_primitive_at_t inputs[] = { src.data };
2810         const_mkldnn_primitive_t outputs[] = { dst.get(),
2811             mean.get(), variance.get() };
2812         check_num_parameters(aprimitive_desc.get(), 1, 3,
2813             "batch normalization forward");
2814         error::wrap_c_api(mkldnn_primitive_create(&result,
2815                 aprimitive_desc.get(), inputs, outputs),
2816             "could not create a batch normalization forward primitive");
2817         reset(result);
2818     }
2819
2820     /// @warning batch_normalization_forward has two constructors with very
2821     ///          similar signatures:
2822     ///           - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
2823     ///           - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
2824     ///          The only way to distinguish between them is to explicitly
2825     ///          cast all input parameters to their type; that is, to
2826     ///          const primitive:at &.
2827     /// @note To make users' experience a little better, this constructor
2828     ///       checks whether parameters match the corresponding primitive
2829     ///       descriptor, and if not, calls the other (proper) constructor.
2830     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2831             const primitive::at &src, const memory &dst, const memory &mean,
2832             const memory &variance, const memory &workspace) {
2833         mkldnn_primitive_t result;
2834         mkldnn_primitive_at_t inputs[2] = { src.data };
2835         const_mkldnn_primitive_t outputs[4] = { dst.get(),
2836             mean.get(), variance.get(), workspace.get() };
2837
2838         if (1) { // check whether this is the `wrong` constructor
2839             const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
2840                     aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
2841             const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
2842                     aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
2843             if (n_inputs_expected == 2 && n_outputs_expected == 3) {
2844                 // shift parameters, get rid of workspace, and add weights...
2845                 auto _weights = dst;
2846                 inputs[1] = {_weights.get(), 0};
2847
2848                 auto _dst = mean, _mean = variance, _variance = workspace;
2849                 outputs[0] = _dst.get();
2850                 outputs[1] = _mean.get();
2851                 outputs[2] = _variance.get();
2852                 outputs[3] = nullptr;
2853             }
2854         }
2855         error::wrap_c_api(mkldnn_primitive_create(&result,
2856                 aprimitive_desc.get(), inputs, outputs),
2857             "could not create a batch normalization forward primitive");
2858         reset(result);
2859     }
2860
2861     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2862             const primitive::at &src, const primitive::at &weights,
2863             const memory &dst) {
2864         mkldnn_primitive_t result;
2865         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2866         const_mkldnn_primitive_t outputs[] = { dst.get() };
2867         check_num_parameters(aprimitive_desc.get(), 2, 1,
2868             "batch normalization forward");
2869         error::wrap_c_api(mkldnn_primitive_create(&result,
2870                 aprimitive_desc.get(), inputs, outputs),
2871             "could not create a batch normalization forward primitive");
2872         reset(result);
2873     }
2874
2875     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2876             const primitive::at &src, const memory &dst) {
2877         mkldnn_primitive_t result;
2878         mkldnn_primitive_at_t inputs[] = { src.data };
2879         const_mkldnn_primitive_t outputs[] = { dst.get() };
2880         check_num_parameters(aprimitive_desc.get(), 1, 1,
2881             "batch normalization forward");
2882         error::wrap_c_api(mkldnn_primitive_create(&result,
2883                 aprimitive_desc.get(), inputs, outputs),
2884             "could not create a batch normalization forward primitive");
2885         reset(result);
2886     }
2887 };
2888
2889 struct batch_normalization_backward : public primitive {
2890     struct desc {
2891         mkldnn_batch_normalization_desc_t data;
2892         template <typename T>
2893         desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2894                 const memory::desc &data_desc, T epsilon, unsigned flags) {
2895             error::wrap_c_api(
2896                     mkldnn_batch_normalization_backward_desc_init(&data,
2897                         mkldnn::convert_to_c(aprop_kind),
2898                         &diff_data_desc.data, &data_desc.data,
2899                         static_cast<float>(epsilon), flags),
2900                 "could not create a batch normalization backward descriptor");
2901         }
2902     };
2903
2904     struct primitive_desc : public mkldnn::primitive_desc {
2905         primitive_desc(const desc &desc, const engine &e,
2906                 const batch_normalization_forward::primitive_desc &hint_fwd_pd)
2907             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2908
2909         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2910                 const batch_normalization_forward::primitive_desc &hint_fwd_pd)
2911             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2912
2913         REG_QUERY_MPD(src, src, 0);
2914         REG_QUERY_MPD(mean, src, 1);
2915         REG_QUERY_MPD(variance, src, 2);
2916         REG_QUERY_MPD(weights, weights, 0);
2917         REG_QUERY_MPD(dst, dst, 0);
2918         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2919         REG_QUERY_MPD(workspace, workspace, 0);
2920
2921         REG_QUERY_MPD(diff_src, diff_src, 0);
2922         REG_QUERY_MPD(diff_weights, diff_weights, 0);
2923     };
2924
2925     // Prop_kind == backward
2926     batch_normalization_backward(const primitive_desc &aprimitive_desc,
2927             const primitive::at &src, const primitive::at &mean,
2928             const primitive::at &variance, const primitive::at &diff_dst,
2929             const primitive::at &weights, const memory &diff_src,
2930             const memory &diff_weights) {
2931         mkldnn_primitive_t result;
2932         mkldnn_primitive_at_t inputs[] = { src.data,
2933             mean.data, variance.data, diff_dst.data, weights.data };
2934         const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2935                 diff_weights.get() };
2936         check_num_parameters(aprimitive_desc.get(), 5, 2,
2937             "batch normalization backward");
2938         error::wrap_c_api(mkldnn_primitive_create(&result,
2939                 aprimitive_desc.get(), inputs, outputs),
2940             "could not create a batch normalization backward primitive");
2941         reset(result);
2942     }
2943
2944     // Prop_kind == backward (+ws)
2945     batch_normalization_backward(const primitive_desc &aprimitive_desc,
2946             const primitive::at &src, const primitive::at &mean,
2947             const primitive::at &variance, const primitive::at &diff_dst,
2948             const primitive::at &weights, const primitive::at &workspace,
2949             const memory &diff_src, const memory &diff_weights) {
2950         mkldnn_primitive_t result;
2951         mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2952             diff_dst.data, weights.data, workspace.data };
2953         const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2954                 diff_weights.get() };
2955         check_num_parameters(aprimitive_desc.get(), 6, 2,
2956             "batch normalization backward");
2957         error::wrap_c_api(mkldnn_primitive_create(&result,
2958                 aprimitive_desc.get(), inputs, outputs),
2959             "could not create a batch normalization backward primitive");
2960         reset(result);
2961     }
2962
2963     // Prop_kind == backward_data (+ws or +weights)
2964     /// @warning This constructor works for backward_data propagation
2965     ///          - w/ weights but w/o workspace, or
2966     ///          - w/ workspace but w/o weights
2967     batch_normalization_backward(const primitive_desc &aprimitive_desc,
2968             const primitive::at &src, const primitive::at &mean,
2969             const primitive::at &variance,const primitive::at &diff_dst,
2970             const primitive::at &weights_or_workspace, const memory &diff_src) {
2971         mkldnn_primitive_t result;
2972         mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2973             diff_dst.data, weights_or_workspace.data };
2974         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2975         check_num_parameters(aprimitive_desc.get(), 5, 1,
2976             "batch normalization backward");
2977         error::wrap_c_api(mkldnn_primitive_create(&result,
2978                 aprimitive_desc.get(), inputs, outputs),
2979             "could not create a batch normalization backward primitive");
2980         reset(result);
2981     }
2982
2983     // Prop_kind == backward_data
2984     batch_normalization_backward(const primitive_desc &aprimitive_desc,
2985             const primitive::at &src, const primitive::at &mean,
2986             const primitive::at &variance, const primitive::at &diff_dst,
2987             const memory &diff_src) {
2988         mkldnn_primitive_t result;
2989         mkldnn_primitive_at_t inputs[] = { src.data,
2990             mean.data, variance.data, diff_dst.data };
2991         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2992         check_num_parameters(aprimitive_desc.get(), 4, 1,
2993             "batch normalization backward");
2994         error::wrap_c_api(mkldnn_primitive_create(&result,
2995                 aprimitive_desc.get(), inputs, outputs),
2996             "could not create a batch normalization backward primitive");
2997         reset(result);
2998     }
2999 };
3000
3001 /// @}
3002
3003 /// @addtogroup cpp_api_inner_product Inner Product
3004 /// A primitive to compute an inner product.
3005 ///
3006 /// @sa @ref c_api_inner_product in @ref c_api
3007 /// @{
3008
3009 struct inner_product_forward: public primitive {
3010     struct desc {
3011         mkldnn_inner_product_desc_t data;
3012         desc(prop_kind aprop_kind, const memory::desc &src_desc,
3013                 const memory::desc &weights_desc,
3014                 const memory::desc &bias_desc,
3015                 const memory::desc &dst_desc) {
3016             error::wrap_c_api(
3017                     mkldnn_inner_product_forward_desc_init(&data,
3018                         mkldnn::convert_to_c(aprop_kind), &src_desc.data,
3019                         &weights_desc.data, &bias_desc.data, &dst_desc.data),
3020                     "could not create a inner product forward descriptor");
3021         }
3022
3023         desc(prop_kind aprop_kind, const memory::desc &src_desc,
3024                 const memory::desc &weights_desc,
3025                 const memory::desc &dst_desc) {
3026             error::wrap_c_api(
3027                     mkldnn_inner_product_forward_desc_init(&data,
3028                         mkldnn::convert_to_c(aprop_kind), &src_desc.data,
3029                         &weights_desc.data, nullptr, &dst_desc.data),
3030                     "could not create a inner product forward descriptor");
3031         }
3032     };
3033
3034     struct primitive_desc : public mkldnn::primitive_desc {
3035         primitive_desc(const desc &desc, const engine &e)
3036             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3037
3038         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3039             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3040
3041         REG_QUERY_MPD(src, src, 0);
3042         REG_QUERY_MPD(weights, weights, 0);
3043         REG_QUERY_MPD(bias, weights, 1);
3044         REG_QUERY_MPD(dst, dst, 0);
3045     };
3046
3047     inner_product_forward(const primitive_desc &aprimitive_desc,
3048             const primitive::at &src, const primitive::at weights,
3049             const primitive::at &bias, const memory &dst) {
3050         mkldnn_primitive_t result;
3051         mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
3052                 bias.data };
3053         const_mkldnn_primitive_t outputs[] = { dst.get() };
3054         check_num_parameters(aprimitive_desc.get(), 3, 1,
3055             "inner product forward");
3056         error::wrap_c_api(mkldnn_primitive_create(&result,
3057                 aprimitive_desc.get(), inputs, outputs),
3058             "could not create a inner product forward primitive");
3059         reset(result);
3060     }
3061
3062     inner_product_forward(const primitive_desc &aprimitive_desc,
3063             const primitive::at &src, const primitive::at weights,
3064             const memory &dst) {
3065         mkldnn_primitive_t result;
3066         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3067         const_mkldnn_primitive_t outputs[] = { dst.get() };
3068         check_num_parameters(aprimitive_desc.get(), 2, 1,
3069             "inner product forward");
3070         error::wrap_c_api(mkldnn_primitive_create(&result,
3071                 aprimitive_desc.get(), inputs, outputs),
3072             "could not create a inner product forward primitive");
3073         reset(result);
3074     }
3075 };
3076
3077 struct inner_product_backward_data: public primitive {
3078     struct desc {
3079         mkldnn_inner_product_desc_t data;
3080         desc(const memory::desc &diff_src_desc,
3081                 const memory::desc &weights_desc,
3082                 const memory::desc &diff_dst_desc) {
3083             error::wrap_c_api(
3084                     mkldnn_inner_product_backward_data_desc_init(&data,
3085                         &diff_src_desc.data, &weights_desc.data,
3086                         &diff_dst_desc.data),
3087                 "could not create a inner product backward data descriptor");
3088         }
3089     };
3090
3091     struct primitive_desc : public mkldnn::primitive_desc {
3092         primitive_desc(const desc &desc, const engine &e,
3093                 const inner_product_forward::primitive_desc &hint_fwd_pd)
3094             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3095
3096         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3097                 const inner_product_forward::primitive_desc &hint_fwd_pd)
3098             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3099
3100         REG_QUERY_MPD(diff_src, diff_src, 0);
3101         REG_QUERY_MPD(weights, weights, 0);
3102         REG_QUERY_MPD(diff_dst, diff_dst, 0);
3103     };
3104
3105     inner_product_backward_data(const primitive_desc &aprimitive_desc,
3106             const primitive::at &diff_dst, const primitive::at weights,
3107             const memory &diff_src) {
3108         mkldnn_primitive_t result;
3109         mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
3110         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3111         check_num_parameters(aprimitive_desc.get(), 2, 1,
3112             "inner product backward data");
3113         error::wrap_c_api(mkldnn_primitive_create(&result,
3114                 aprimitive_desc.get(), inputs, outputs),
3115             "could not create a inner product backward data primitive");
3116         reset(result);
3117     }
3118 };
3119
3120 struct inner_product_backward_weights: public primitive {
3121     struct desc {
3122         mkldnn_inner_product_desc_t data;
3123         desc(const memory::desc &src_desc,
3124                 const memory::desc &diff_weights_desc,
3125                 const memory::desc &diff_bias_desc,
3126                 const memory::desc &diff_dst_desc) {
3127             error::wrap_c_api(
3128                     mkldnn_inner_product_backward_weights_desc_init(
3129                         &data, &src_desc.data, &diff_weights_desc.data,
3130                         &diff_bias_desc.data, &diff_dst_desc.data),
3131                 "could not create a inner product backward weights descriptor");
3132         }
3133         desc(const memory::desc &src_desc,
3134                 const memory::desc &diff_weights_desc,
3135                 const memory::desc &diff_dst_desc) {
3136             error::wrap_c_api(
3137                     mkldnn_inner_product_backward_weights_desc_init(
3138                         &data, &src_desc.data, &diff_weights_desc.data,
3139                         nullptr, &diff_dst_desc.data),
3140                 "could not create a inner product backward weights descriptor");
3141         }
3142     };
3143
3144     struct primitive_desc : public mkldnn::primitive_desc {
3145         primitive_desc(const desc &desc, const engine &e,
3146                 const inner_product_forward::primitive_desc &hint_fwd_pd)
3147             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3148
3149         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3150                 const inner_product_forward::primitive_desc &hint_fwd_pd)
3151             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3152
3153         REG_QUERY_MPD(src, src, 0);
3154         REG_QUERY_MPD(diff_weights, diff_weights, 0);
3155         REG_QUERY_MPD(diff_bias, diff_weights, 1);
3156         REG_QUERY_MPD(diff_dst, diff_dst, 0);
3157     };
3158
3159     inner_product_backward_weights(const primitive_desc &aprimitive_desc,
3160             const primitive::at &src, const primitive::at diff_dst,
3161             const memory &diff_weights) {
3162         mkldnn_primitive_t result;
3163         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3164         const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
3165         check_num_parameters(aprimitive_desc.get(), 2, 1,
3166             "inner product backward weights");
3167         error::wrap_c_api(mkldnn_primitive_create(&result,
3168                 aprimitive_desc.get(), inputs, outputs),
3169             "could not create a inner product backward weights primitive");
3170         reset(result);
3171     }
3172
3173     inner_product_backward_weights(const primitive_desc &aprimitive_desc,
3174             const primitive::at &src, const primitive::at diff_dst,
3175             const memory &diff_weights, const memory &diff_bias) {
3176         mkldnn_primitive_t result;
3177         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3178         const_mkldnn_primitive_t outputs[] =
3179                 { diff_weights.get(), diff_bias.get()};
3180         check_num_parameters(aprimitive_desc.get(), 2, 2,
3181             "inner product backward weights");
3182         error::wrap_c_api(mkldnn_primitive_create(&result,
3183                 aprimitive_desc.get(), inputs, outputs),
3184             "could not create a inner product backward weights primitive");
3185         reset(result);
3186     }
3187 };
3188
3189 /// @}
3190
3191 /// @addtogroup cpp_api_rnn RNN
/// A primitive to compute a common recurrent layer.
3193 ///
3194 /// @sa @ref c_api_rnn in @ref c_api
3195 /// @{
3196
3197 struct rnn_cell {
3198     struct desc {
3199         mkldnn_rnn_cell_desc_t c_rnn_cell_;
3200
3201         desc(algorithm kind, algorithm activation_f) {
3202             error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_,
3203                         mkldnn::convert_to_c(kind),
3204                         mkldnn::convert_to_c(activation_f), 0U, 0, 0),
3205                     "could not init an rnn cell descriptor");
3206         }
3207         desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {}
3208
3209         operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
3210
3211         algorithm get_cell_kind() const
3212         { return algorithm(c_rnn_cell_.cell_kind); }
3213         algorithm get_activation() const
3214         { return algorithm(c_rnn_cell_.activation_kind); }
3215
3216         float get_alpha() const { return c_rnn_cell_.alpha; }
3217         void set_alpha(float alpha) {
3218             c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
3219             c_rnn_cell_.alpha = alpha;
3220         }
3221
3222         float get_clipping() const { return c_rnn_cell_.clipping; }
3223         void set_clipping(float clipping) {
3224             c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
3225             c_rnn_cell_.clipping = clipping;
3226         }
3227
3228         int get_gates_count() const {
3229             return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
3230         }
3231         int get_state_count() const {
3232             return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
3233         }
3234     };
3235 };
3236
3237 struct rnn_forward : public primitive {
3238     struct desc {
3239         mkldnn_rnn_desc_t data;
3240         desc(prop_kind aprop_kind, rnn_cell::desc cell,
3241                 const rnn_direction direction,
3242                 const memory::desc &src_layer_desc,
3243                 const memory::desc &src_iter_desc,
3244                 const memory::desc &weights_layer_desc,
3245                 const memory::desc &weights_iter_desc,
3246                 const memory::desc &bias_desc,
3247                 const memory::desc &dst_layer_desc,
3248                 const memory::desc &dst_iter_desc
3249             ) {
3250             error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data,
3251                         mkldnn::convert_to_c(aprop_kind), cell,
3252                         mkldnn::convert_to_c(direction),
3253                         &src_layer_desc.data, &src_iter_desc.data,
3254                         &weights_layer_desc.data, &weights_iter_desc.data,
3255                         &bias_desc.data,
3256                         &dst_layer_desc.data, &dst_iter_desc.data),
3257                     "could not create an RNN forward descriptor");
3258         }
3259
3260     };
3261
3262     struct primitive_desc : public mkldnn::primitive_desc {
3263         primitive_desc(const desc &desc, const engine &e)
3264             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3265
3266         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3267             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3268
3269         REG_QUERY_MPD(src_layer, src, 0);
3270         REG_QUERY_MPD(src_iter, src, 1);
3271         REG_QUERY_MPD(weights_layer, weights, 0);
3272         REG_QUERY_MPD(weights_iter, weights, 1);
3273         REG_QUERY_MPD(bias, weights, 2);
3274         REG_QUERY_MPD(dst_layer, dst, 0);
3275         REG_QUERY_MPD(dst_iter, dst, 1);
3276         REG_QUERY_MPD(workspace, workspace, 0);
3277     };
3278
3279     rnn_forward(const primitive_desc &aprimitive_desc,
3280             const primitive::at &src_layer, const primitive::at &src_iter,
3281             const primitive::at &weights_layer,
3282             const primitive::at &weights_iter, const primitive::at &bias,
3283             const memory &dst_layer, const memory &dst_iter,
3284             const memory &workspace) {
3285         mkldnn_primitive_t result;
3286         mkldnn_primitive_at_t inputs[5];
3287         const_mkldnn_primitive_t outputs[3];
3288         int idx=0;
3289         inputs[idx++] = src_layer.data;
3290         if (!is_null_memory(src_iter.data.primitive))
3291             inputs[idx++] = src_iter.data;
3292         inputs[idx++] = weights_layer.data;
3293         inputs[idx++] = weights_iter.data;
3294         if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;
3295
3296         idx=0;
3297         outputs[idx++] = dst_layer.get();
3298         if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
3299         if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();
3300
3301         error::wrap_c_api(mkldnn_primitive_create(&result,
3302                     aprimitive_desc.get(), inputs, outputs),
3303                 "could not create an RNN forward primitive");
3304         reset(result);
3305     }
3306 };
3307
3308 struct rnn_backward : public primitive {
3309     struct desc {
3310         mkldnn_rnn_desc_t data;
3311         desc(prop_kind aprop_kind, rnn_cell::desc cell,
3312                 const rnn_direction direction,
3313                 const memory::desc &src_layer_desc,
3314                 const memory::desc &src_iter_desc,
3315                 const memory::desc &weights_layer_desc,
3316                 const memory::desc &weights_iter_desc,
3317                 const memory::desc &bias_desc,
3318                 const memory::desc &dst_layer_desc,
3319                 const memory::desc &dst_iter_desc,
3320                 const memory::desc &diff_src_layer_desc,
3321                 const memory::desc &diff_src_iter_desc,
3322                 const memory::desc &diff_weights_layer_desc,
3323                 const memory::desc &diff_weights_iter_desc,
3324                 const memory::desc &diff_bias_desc,
3325                 const memory::desc &diff_dst_layer_desc,
3326                 const memory::desc &diff_dst_iter_desc) {
3327             error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data,
3328                         mkldnn::convert_to_c(aprop_kind), cell,
3329                         mkldnn::convert_to_c(direction),
3330                         &src_layer_desc.data, &src_iter_desc.data,
3331                         &weights_layer_desc.data, &weights_iter_desc.data,
3332                         &bias_desc.data,
3333                         &dst_layer_desc.data, &dst_iter_desc.data,
3334                         &diff_src_layer_desc.data, &diff_src_iter_desc.data,
3335                         &diff_weights_layer_desc.data,
3336                         &diff_weights_iter_desc.data, &diff_bias_desc.data,
3337                         &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
3338                     "could not create an RNN backward descriptor");
3339         }
3340
3341     };
3342
3343     struct primitive_desc : public mkldnn::primitive_desc {
3344         primitive_desc(const desc &desc, const engine &e,
3345                 const rnn_forward::primitive_desc &hint_fwd_pd)
3346             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3347
3348         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3349                 const rnn_forward::primitive_desc &hint_fwd_pd)
3350             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3351
3352         REG_QUERY_MPD(src_layer, src, 0);
3353         REG_QUERY_MPD(src_iter, src, 1);
3354         REG_QUERY_MPD(weights_layer, weights, 0);
3355         REG_QUERY_MPD(weights_iter, weights, 1);
3356         REG_QUERY_MPD(bias, weights, 2);
3357         REG_QUERY_MPD(dst_layer, dst, 0);
3358         REG_QUERY_MPD(dst_iter, dst, 1);
3359         REG_QUERY_MPD(workspace, workspace, 0);
3360
3361         REG_QUERY_MPD(diff_src_layer, diff_src, 0);
3362         REG_QUERY_MPD(diff_src_iter, diff_src, 1);
3363         REG_QUERY_MPD(diff_weights_layer, diff_weights, 0);
3364         REG_QUERY_MPD(diff_weights_iter, diff_weights, 1);
3365         REG_QUERY_MPD(diff_bias, diff_weights, 2);
3366         REG_QUERY_MPD(diff_dst_layer, diff_dst, 0);
3367         REG_QUERY_MPD(diff_dst_iter, diff_dst, 1);
3368     };
3369
3370     // With last iteration (with and without input src_iter)
3371     rnn_backward(const primitive_desc &aprimitive_desc,
3372                  const primitive::at &src_layer,
3373                  const primitive::at &src_iter,
3374                  const primitive::at &weights_layer,
3375                  const primitive::at &weights_iter,
3376                  const primitive::at &bias,
3377                  const primitive::at &dst_layer,
3378                  const primitive::at &dst_iter,
3379                  const memory &diff_src_layer,
3380                  const memory &diff_src_iter,
3381                  const memory &diff_weights_layer,
3382                  const memory &diff_weights_iter,
3383                  const memory &diff_bias,
3384                  const primitive::at &diff_dst_layer,
3385                  const primitive::at &diff_dst_iter,
3386                  const primitive::at &workspace) {
3387         mkldnn_primitive_t result;
3388         mkldnn_primitive_at_t inputs[10];
3389         const_mkldnn_primitive_t outputs[5];
3390         int idx=0;
3391         inputs[idx++] = src_layer.data;
3392         if (!is_null_memory(src_iter.data.primitive))
3393             inputs[idx++] = src_iter.data;
3394         inputs[idx++] = weights_layer.data;
3395         inputs[idx++] = weights_iter.data;
3396         if (!is_null_memory(bias.data.primitive))
3397             inputs[idx++] = bias.data;
3398         inputs[idx++] = dst_layer.data;
3399         if (!is_null_memory(dst_iter.data.primitive))
3400             inputs[idx++] = dst_iter.data;
3401         inputs[idx++] = diff_dst_layer.data;
3402         if (!is_null_memory(diff_dst_iter.data.primitive))
3403             inputs[idx++] = diff_dst_iter.data;
3404         inputs[idx++] = workspace.data;
3405
3406         idx = 0;
3407         outputs[idx++] = diff_src_layer.get();
3408         if (!is_null_memory(diff_src_iter.get()))
3409             outputs[idx++] = diff_src_iter.get();
3410         outputs[idx++] = diff_weights_layer.get();
3411         outputs[idx++] = diff_weights_iter.get();
3412         if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
3413         error::wrap_c_api(mkldnn_primitive_create(&result,
3414                     aprimitive_desc.get(), inputs, outputs),
3415                 "could not create an RNN backward primitive");
3416         reset(result);
3417     }
3418 };
3419
3420 /// @}
3421
3422 /// @addtogroup cpp_api_shuffle Shuffle
3423 /// A primitive to shuffle data along the axis.
3424 ///
3425 /// @sa @ref c_api_shuffle in @ref c_api
3426 /// @{
3427
3428 struct shuffle_forward : public primitive {
3429     struct desc {
3430         mkldnn_shuffle_desc_t data;
3431         desc(prop_kind aprop_kind, const memory::desc &data_desc,
3432                 int axis, int group_size) {
3433             error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data,
3434                         mkldnn::convert_to_c(aprop_kind), &data_desc.data,
3435                         axis, group_size),
3436                     "could not create a shuffle forward descriptor");
3437         }
3438     };
3439
3440     struct primitive_desc : public mkldnn::primitive_desc {
3441         primitive_desc(const desc &desc, const engine &e)
3442             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3443
3444         REG_QUERY_MPD(src, src, 0);
3445         REG_QUERY_MPD(dst, dst, 0);
3446     };
3447
3448     shuffle_forward(const primitive_desc &aprimitive_desc,
3449             const primitive::at &src, const memory &dst) {
3450         mkldnn_primitive_t result;
3451         mkldnn_primitive_at_t inputs[] = { src.data };
3452         const_mkldnn_primitive_t outputs[] = { dst.get() };
3453         check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle forward");
3454         error::wrap_c_api(mkldnn_primitive_create(&result,
3455             aprimitive_desc.get(), inputs, outputs),
3456             "could not create a shuffle forward primitive");
3457         reset(result);
3458     }
3459 };
3460
3461 struct shuffle_backward : public primitive {
3462     struct desc {
3463         mkldnn_shuffle_desc_t data;
3464         desc(const memory::desc &diff_data_desc, int axis, int group_size) {
3465             error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data,
3466                         &diff_data_desc.data, axis, group_size),
3467                     "could not create a shuffle backward descriptor");
3468         }
3469     };
3470
3471     struct primitive_desc : public mkldnn::primitive_desc {
3472         primitive_desc(const desc &desc, const engine &e,
3473                 const shuffle_forward::primitive_desc &hint_fwd_pd)
3474             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3475
3476         REG_QUERY_MPD(diff_src, diff_src, 0);
3477         REG_QUERY_MPD(diff_dst, diff_dst, 0);
3478     };
3479
3480     shuffle_backward(const primitive_desc &aprimitive_desc,
3481             const primitive::at &diff_dst, const memory &diff_src) {
3482         mkldnn_primitive_t result;
3483         mkldnn_primitive_at_t inputs[] = { diff_dst.data};
3484         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3485         check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle backward");
3486         error::wrap_c_api(mkldnn_primitive_create(&result,
3487             aprimitive_desc.get(), inputs, outputs),
3488             "could not create a shuffle backward primitive");
3489         reset(result);
3490     }
3491 };
3492
3493 /// @}
3494
3495 /// @addtogroup cpp_api_binary_convolution Binary convolution
3496 /// A primitive to compute binary convolution using different algorithms.
3497 ///
3498 /// @sa @ref c_api_binary_convolution in @ref c_api
3499 /// @{
3500
3501 struct binary_convolution_forward: public primitive {
3502     struct desc {
3503         mkldnn_binary_convolution_desc_t data;
3504         desc(prop_kind aprop_kind, algorithm aalgorithm,
3505                 const memory::desc &src_desc,
3506                 const memory::desc &weights_desc,
3507                 const memory::desc &dst_desc,
3508                 const memory::dims strides,
3509                 const memory::dims dilates,
3510                 const memory::dims padding_l,
3511                 const memory::dims padding_r,
3512                 const float pad_value) {
3513             memory::validate_dims(strides);
3514             memory::validate_dims(dilates);
3515             memory::validate_dims(padding_l);
3516             memory::validate_dims(padding_r);
3517             error::wrap_c_api(
3518                 mkldnn_dilated_binary_convolution_forward_desc_init(&data,
3519                     mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
3520                         &src_desc.data, &weights_desc.data, &dst_desc.data,
3521                         &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
3522                         pad_value),
3523                     "could not create a dilated binary convolution forward descriptor");
3524         }
3525     };
3526
3527     struct primitive_desc : public mkldnn::primitive_desc {
3528         primitive_desc(const desc &desc, const engine &e)
3529             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3530
3531         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3532             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3533
3534         REG_QUERY_MPD(src, src, 0);
3535         REG_QUERY_MPD(weights, weights, 0);
3536         REG_QUERY_MPD(dst, dst, 0);
3537     };
3538
3539     binary_convolution_forward(const primitive_desc &aprimitive_desc,
3540             const primitive::at &src, const primitive::at &weights, const memory &dst) {
3541         mkldnn_primitive_t result;
3542         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3543         const_mkldnn_primitive_t outputs[] = { dst.get() };
3544         check_num_parameters(aprimitive_desc.get(), 2, 1,
3545             "binary convolution forward");
3546         error::wrap_c_api(mkldnn_primitive_create(&result,
3547                     aprimitive_desc.get(), inputs, outputs),
3548                 "could not create a binary convolution forward primitive");
3549         reset(result);
3550     }
3551 };
3552
3553 /// @}
3554
3555 /// @addtogroup cpp_api_binarization Binarization
3556 /// @{
3557
3558 struct binarization_forward : public primitive {
3559     struct desc {
3560         mkldnn_binarization_desc_t data;
3561
3562         desc(prop_kind aprop_kind, algorithm alg_kind,
3563              const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &output_mask_desc,
3564              const memory::desc &dst_desc) {
3565             error::wrap_c_api(mkldnn_binarization_forward_desc_init(&data,
3566                                                                  mkldnn::convert_to_c(aprop_kind),
3567                                                                  mkldnn::convert_to_c(alg_kind),
3568                                                                  &src_desc.data, &dst_desc.data,
3569                                                                  &weights_desc.data, &output_mask_desc.data),
3570                               "could not create a binarization forward descriptor");
3571         }
3572     };
3573
3574     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
3575         primitive_desc(const desc &adesc, const engine &aengine) {
3576             mkldnn_primitive_desc_t result;
3577             error::wrap_c_api(mkldnn_primitive_desc_create(
3578                     &result, &adesc.data, aengine.get(), nullptr),
3579                               "could not create a binarization forward primitive descriptor");
3580             reset(result);
3581         }
3582
3583         engine get_engine() { return engine::query(*this); }
3584     };
3585
3586     binarization_forward(const primitive_desc &aprimitive_desc,
3587                       const primitive::at &src, const primitive::at &weights, const primitive::at &output_mask,
3588                       const memory &dst) {
3589         mkldnn_primitive_t result;
3590         mkldnn_primitive_at_t inputs[] = { src.data, weights.data, output_mask.data};
3591         const_mkldnn_primitive_t outputs[] = { dst.get() };
3592         error::wrap_c_api(mkldnn_primitive_create(&result, aprimitive_desc.get(), inputs, outputs),
3593                           "could not create a binarization forward primitive");
3594         reset(result);
3595     }
3596 };
3597
3598 /// @}
3599
3600 /// @addtogroup cpp_api_deformable_convolution Deformable convolution
3601 /// A primitive to compute deformable convolution.
3602 ///
3603 /// @sa @ref c_api_deformable_convolution in @ref c_api
3604 /// @{
3605
3606 struct deformable_convolution_forward: public primitive {
3607     struct desc {
3608         mkldnn_deformable_convolution_desc_t data;
3609         std::vector<mkldnn_memory_desc_t> c_api_inputs;
3610
3611         desc(prop_kind aprop_kind, algorithm aalgorithm, std::vector<memory::desc> inputs,
3612              const memory::desc &weights_desc,
3613              const memory::desc &bias_desc,
3614              const memory::desc &dst_desc,
3615              const memory::dims strides,
3616              const memory::dims dilates,
3617              const memory::dims padding_l,
3618              const memory::dims padding_r,
3619              const padding_kind apadding_kind,
3620              const int deformable_group) {
3621
3622             for (size_t i = 0; i < inputs.size(); i++) {
3623                 c_api_inputs.push_back(inputs[i].data);
3624             }
3625
3626             memory::validate_dims(strides);
3627             memory::validate_dims(dilates);
3628             memory::validate_dims(padding_l);
3629             memory::validate_dims(padding_r);
3630             error::wrap_c_api(mkldnn_deformable_convolution_forward_desc_init(&data,
3631                                                                    mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
3632                                                                    &c_api_inputs[0], c_api_inputs.size(), &weights_desc.data, &bias_desc.data,
3633                                                                    &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
3634                                                                    mkldnn::convert_to_c(apadding_kind), deformable_group),
3635                               "could not create a deformable convolution forward descriptor");
3636         }
3637         desc(prop_kind aprop_kind, algorithm aalgorithm, std::vector<memory::desc> inputs,
3638              const memory::desc &weights_desc,
3639              const memory::desc &dst_desc,
3640              const memory::dims strides,
3641              const memory::dims dilates,
3642              const memory::dims padding_l,
3643              const memory::dims padding_r,
3644              const padding_kind apadding_kind,
3645              const int deformable_group) {
3646
3647             for (size_t i = 0; i < inputs.size(); i++) {
3648                 c_api_inputs.push_back(inputs[i].data);
3649             }
3650
3651             memory::validate_dims(strides);
3652             memory::validate_dims(dilates);
3653             memory::validate_dims(padding_l);
3654             memory::validate_dims(padding_r);
3655             error::wrap_c_api(mkldnn_deformable_convolution_forward_desc_init(&data,
3656                                                                    mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
3657                                                                    &c_api_inputs[0], c_api_inputs.size(), &weights_desc.data, nullptr,
3658                                                                    &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
3659                                                                    mkldnn::convert_to_c(apadding_kind), deformable_group),
3660                               "could not create a deformable convolution forward descriptor");
3661         }
3662     };
3663
3664     struct primitive_desc : public mkldnn::primitive_desc {
3665         primitive_desc(const desc &desc, const engine &e)
3666                 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3667
3668         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3669                 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3670
3671         REG_QUERY_MPD(src, src, 0);
3672         REG_QUERY_MPD(weights, weights, 0);
3673         REG_QUERY_MPD(bias, weights, 1);
3674         REG_QUERY_MPD(dst, dst, 0);
3675     };
3676
3677     deformable_convolution_forward(const primitive_desc &aprimitive_desc,
3678                         const std::vector<primitive::at> &inputs, const primitive::at &weights,
3679                         const primitive::at &bias, const memory &dst) {
3680         mkldnn_primitive_t result;
3681
3682         mkldnn_primitive_at_t p_inputs[] = { inputs[0].data, inputs[1].data, weights.data,
3683                                            bias.data };
3684         const_mkldnn_primitive_t outputs[] = { dst.get() };
3685         error::wrap_c_api(mkldnn_primitive_create(&result,
3686                                                   aprimitive_desc.get(), &p_inputs[0], outputs),
3687                           "could not create a deformable convolution forward bias primitive");
3688         reset(result);
3689     }
3690
3691     deformable_convolution_forward(const primitive_desc &aprimitive_desc,
3692                         const std::vector<primitive::at> &inputs, const primitive::at &weights, const memory &dst) {
3693         mkldnn_primitive_t result;
3694         mkldnn_primitive_at_t p_inputs[] = { inputs[0].data, inputs[1].data, weights.data, };
3695         const_mkldnn_primitive_t outputs[] = { dst.get() };
3696         check_num_parameters(aprimitive_desc.get(), 3, 1,
3697                              "deformable convolution forward");
3698         error::wrap_c_api(mkldnn_primitive_create(&result,
3699                                                   aprimitive_desc.get(), &p_inputs[0], outputs),
3700                           "could not create a deformable convolution forward primitive");
3701         reset(result);
3702     }
3703 };
3704
3705 /// @}
3706
3707 /// @} Primitives
3708
3709 /// @addtogroup cpp_api_stream Stream
3710 /// Execution stream operations.
3711 ///
3712 /// @sa @ref c_api_stream in @ref c_api
3713 /// @{
3714
3715 #ifndef DOXYGEN_SHOULD_SKIP_THIS
3716 template <> struct handle_traits<mkldnn_stream_t> {
3717     static constexpr auto destructor = &mkldnn_stream_destroy;
3718 };
3719 #endif
3720
3721 struct stream: public handle<mkldnn_stream_t> {
3722     using handle::handle;
3723
3724     enum kind { any = mkldnn_stream_kind_t::mkldnn_any_stream,
3725         eager = mkldnn_stream_kind_t::mkldnn_eager,
3726         lazy = mkldnn_stream_kind_t::mkldnn_lazy };
3727
3728     static mkldnn_stream_kind_t convert_to_c(kind akind) {
3729         return static_cast<mkldnn_stream_kind_t>(akind);
3730     }
3731     /// Constructs a stream.
3732     stream(kind akind) {
3733         mkldnn_stream_t astream;
3734         error::wrap_c_api(mkldnn_stream_create(&astream,
3735                     convert_to_c(akind)),
3736                 "could not create a stream");
3737         reset(astream);
3738     }
3739
3740     /// Submits a vector of primitives to a stream for computations.
3741     ///
3742     /// @param primitives The vector of primitives to submit.
3743     /// @returns The stream.
3744     stream &submit(std::vector<primitive> primitives) {
3745         // TODO: find a proper way to convert vector<primitive> to
3746         // vector<mkldnn_primitive_t>
3747         if (primitives.size() == 0) return *this;
3748         std::vector<mkldnn_primitive_t> c_api_primitives;
3749         c_api_primitives.reserve(primitives.size());
3750         auto convert_to_c = [](primitive p) { return p.get(); };
3751         std::transform(primitives.begin(), primitives.end(),
3752                 std::back_inserter(c_api_primitives), convert_to_c);
3753
3754         mkldnn_primitive_t c_api_error_primitive;
3755         error::wrap_c_api(
3756                 mkldnn_stream_submit(get(),
3757                     c_api_primitives.size(), &c_api_primitives[0],
3758                     &c_api_error_primitive),
3759                 "could not submit primitives to a stream",
3760                 &c_api_error_primitive);
3761
3762         return *this;
3763     }
3764
3765     /// Waits for all computations submitted to the stream to complete.
3766     ///
3767     /// @param block Specifies whether the operation should wait indefinitely or
3768     ///              return immediately.
3769     /// @returns @c true if all computations completed.
3770     /// @returns @c false if not all computations completed.
3771     bool wait(bool block = true) {
3772         mkldnn_primitive_t c_api_error_primitive;
3773         mkldnn_status_t status = mkldnn_stream_wait(get(),
3774                 block, &c_api_error_primitive);
3775         if (status != mkldnn_success
3776                 && status != mkldnn_try_again)
3777             error::wrap_c_api(status, "could not wait on a stream",
3778                     &c_api_error_primitive);
3779         return (status == mkldnn_success);
3780     }
3781
3782     stream &rerun() {
3783         mkldnn_primitive_t c_api_error_primitive;
3784         error::wrap_c_api(
3785                 mkldnn_stream_rerun(get(), &c_api_error_primitive),
3786                 "could not rerun a stream", &c_api_error_primitive);
3787         return *this;
3788     }
3789 };
3790
3791 #undef REG_QUERY_MPD
3792
3793 /// @}
3794
3795 /// @} C++ API
3796
3797 } // namespace mkldnn
3798
3799 #endif