Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / include / mkldnn.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef MKLDNN_HPP
18 #define MKLDNN_HPP
19
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
21 #include <stdlib.h>
22 #include <memory>
23 #include <vector>
24 #include <algorithm>
25 #include <iterator>
26 #include <string>
27
28 #include "mkldnn.h"
29 #endif
30
31 namespace mkldnn {
32
33 /// @addtogroup cpp_api C++ API
34 /// @{
35
36 /// @addtogroup cpp_api_utils Utils
37 /// @{
38
/// Traits class naming the function that destroys an Intel(R) MKL-DNN C
/// handle of type @p T. The primary template is empty; handle types that can
/// be wrapped provide a specialization with a static @p destructor member.
template <typename T> class handle_traits {};

/// A copyable wrapper around an Intel(R) MKL-DNN C handle. It serves as the
/// base class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t),
/// and stream (#mkldnn_stream_t) handles and can be passed by value.
///
/// Two ownership modes are supported:
///  - Owning (default): the C handle is stored in a @p std::shared_ptr whose
///    deleter is @p traits::destructor, so all copies share one reference
///    count.
///  - Weak (@p weak == true): the C handle is stored with a deleter that does
///    nothing, for handles that are owned elsewhere (for example, the ones
///    returned by #mkldnn_primitive_get_output()).
template <typename T, typename traits = handle_traits<T>>
class handle {
private:
    /// Shared storage for the wrapped C handle.
    std::shared_ptr<typename std::remove_pointer<T>::type> _data;

    // The const-rvalue overloads are explicitly deleted.
    handle(const handle &&) = delete;
    handle &operator=(const handle &&other) = delete;

protected:
    /// Compares the wrapped C handle with a raw C handle.
    bool operator==(const T other) const { return _data.get() == other; }
    bool operator!=(const T other) const { return _data.get() != other; }

public:
    /// Constructs a C handle wrapper.
    /// @param t The C handle to wrap.
    /// @param weak A flag to specify whether to construct a weak wrapper.
    handle(T t = 0, bool weak = false): _data(nullptr) {
        reset(t, weak);
    }

    /// Copies share the same underlying C handle (and reference count).
    handle(const handle &other): _data(other._data) {}

    handle &operator=(const handle &other) {
        _data = other._data;
        return *this;
    }

    /// Resets the value of a C handle.
    /// @param t The new value of the C handle.
    /// @param weak A flag to specify whether the wrapper should be weak.
    void reset(T t, bool weak = false) {
        if (!weak) {
            _data.reset(t, traits::destructor);
            return;
        }
        // Weak wrapper: the deleter ignores the handle and just produces a
        // zero value of the destructor's return type.
        auto keep_alive = [](T) { return decltype(traits::destructor(0))(0); };
        _data.reset(t, keep_alive);
    }

    /// Returns the value of the underlying C handle.
    T get() const { return _data.get(); }

    /// Two wrappers are equal when they refer to the same C handle.
    bool operator==(const handle &other) const {
        return _data.get() == other._data.get();
    }
    bool operator!=(const handle &other) const {
        return _data.get() != other._data.get();
    }
};
90
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
/// Destructor trait for primitive descriptor handles: wrappers of
/// #mkldnn_primitive_desc_t use mkldnn_primitive_desc_destroy() as deleter.
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93     static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
94 };
95
/// Destructor trait for primitive handles: wrappers of #mkldnn_primitive_t
/// use mkldnn_primitive_destroy() as deleter.
96 template <> struct handle_traits<mkldnn_primitive_t> {
97     static constexpr auto destructor = &mkldnn_primitive_destroy;
98 };
99
/// Destructor trait for primitive descriptor iterator handles: wrappers of
/// #mkldnn_primitive_desc_iterator_t use
/// mkldnn_primitive_desc_iterator_destroy() as deleter.
100 template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
101     static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
102 };
103 #endif
104
105 /// Base class for all computational primitives.
106 class primitive: public handle<mkldnn_primitive_t> {
107     friend struct error;
108     friend struct stream;
109     friend class primitive_at;
110     using handle::handle;
111 public:
112     /// A proxy to C primitive kind enum
113     enum class kind {
114         undefined_primitive = mkldnn_undefined_primitive,
115         memory = mkldnn_memory,
116         view = mkldnn_view,
117         reorder = mkldnn_reorder,
118         concat = mkldnn_concat,
119         concat_inplace = mkldnn_concat_inplace,
120         sum = mkldnn_sum,
121         convolution = mkldnn_convolution,
122         deconvolution = mkldnn_deconvolution,
123         shuffle = mkldnn_shuffle,
124         eltwise = mkldnn_eltwise,
125         depthwise = mkldnn_depthwise,
126         softmax = mkldnn_softmax,
127         pooling = mkldnn_pooling,
128         lrn = mkldnn_lrn,
129         batch_normalization = mkldnn_batch_normalization,
130         inner_product = mkldnn_inner_product,
131         rnn = mkldnn_rnn,
132         binary_convolution = mkldnn_binary_convolution,
133         binarization = mkldnn_binarization,
134     };
135
136     /// A wrapper structure to specify a particular output of a primitive.
137     struct at {
138         /// The underlying C API structure.
139         mkldnn_primitive_at_t data;
140         /// Constructs a wrapper specifying @p aprimitive output with index @p
141         /// at.
142         ///
143         /// @param aprimitive The target primitive.
144         /// @param at The output index.
145
146         at(const primitive &aprimitive, size_t at = 0)
147             : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
148         /// Returns the specified output.
149         inline operator primitive() const;
150     };
151
152     /// Returns the descriptor of the underlying C API primitive.
153     inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
154     // TODO: use the C++ API wrapper structure.
155 };
156
157 inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) {
158     return static_cast<mkldnn_primitive_kind_t>(akind);
159 }
160 /// Intel(R) MKL-DNN exception class.
161 ///
162 /// This class captures the status returned by the failed C API function, error
163 /// message, and, optionally, handle of the primitive that caused the error.
164 struct error: public std::exception {
165     mkldnn_status_t status;
166     std::string message;
167     primitive error_primitive;
168
169     /// Constructs an error instance.
170     ///
171     /// @param astatus The error status returned by the C API.
172     /// @param amessage The error message.
173     /// @param aerror_primitive (optional) A C handle of the primitive that
174     ///                         caused the error.
175
176     error(mkldnn_status_t astatus, std::string amessage,
177             mkldnn_primitive_t aerror_primitive = 0)
178         : status(astatus)
179         , message(amessage)
180         , error_primitive(aerror_primitive, true)
181     {}
182
183     /// A convenience function for wrapping calls to the C API. Checks the
184     /// return status and throws an #error in case of failure.
185     ///
186     /// @param status The error status returned by the C API.
187     /// @param message The error message.
188     /// @param error_primitive (optional) A C handle of the primitive that
189     ///                        caused the error.
190
191     static void wrap_c_api(mkldnn_status_t status,
192             const std::string &message,
193             mkldnn_primitive_t *error_primitive = 0)
194     {
195         if (status != mkldnn_success) {
196             if (nullptr != error_primitive)
197                 throw error(status, message, *error_primitive);
198             else
199                 throw error(status, message, nullptr);
200         }
201     }
202 };
203
204 inline primitive::at::operator primitive() const {
205     const_mkldnn_primitive_t output;
206     error::wrap_c_api(
207             mkldnn_primitive_get_output(data.primitive,
208                 data.output_index, &output),
209             "could not get an output primitive");
210     return primitive(const_cast<mkldnn_primitive_t>(output), true);
211 }
212
213 const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const {
214     const_mkldnn_primitive_desc_t pd;
215     error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd),
216             "could not get primitive descriptor by primitive");
217     return pd;
218 }
219 /// @}
220
221 /// @addtogroup cpp_api_enums Common data types and enumerations
222 /// A proxy to @ref c_api_types in @ref c_api.
223 ///
224 /// @{
225
/// Rounding modes. Each enumerator has the value of the like-named C API
/// constant; use convert_to_c() to obtain a #mkldnn_round_mode_t.
226 enum round_mode {
227     round_nearest = mkldnn_round_nearest,
228     round_down = mkldnn_round_down,
229 };
230
231 inline mkldnn_round_mode_t convert_to_c(round_mode mode) {
232     return static_cast<mkldnn_round_mode_t>(mode);
233 }
234
/// Padding kinds. The only enumerator, @p zero, has the value of the C API
/// constant mkldnn_padding_zero; use convert_to_c() to obtain a
/// #mkldnn_padding_kind_t.
235 enum padding_kind {
236     zero = mkldnn_padding_zero
237 };
238
239 inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) {
240     return static_cast<mkldnn_padding_kind_t>(kind);
241 }
242
/// Propagation kinds. Each enumerator has the value of the like-named C API
/// constant; use convert_to_c() to obtain a #mkldnn_prop_kind_t.
243 enum prop_kind {
244     forward_training = mkldnn_forward_training,
245     forward_scoring = mkldnn_forward_scoring,
246     forward_inference = mkldnn_forward_inference,
247     forward = mkldnn_forward,
248     backward = mkldnn_backward,
249     backward_data = mkldnn_backward_data,
250     backward_weights = mkldnn_backward_weights,
251     backward_bias = mkldnn_backward_bias
252 };
253
254 inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) {
255     return static_cast<mkldnn_prop_kind_t>(kind);
256 }
257
/// Algorithm kinds for all primitives (convolution, deconvolution, eltwise,
/// depthwise, lrn, pooling, rnn, roi pooling, binary convolution and
/// binarization variants). Each enumerator has the value of the C API
/// constant on its right-hand side; use convert_to_c() to obtain a
/// #mkldnn_alg_kind_t.
258 enum algorithm {
259     algorithm_undef = mkldnn_alg_kind_undef,
260     convolution_auto = mkldnn_convolution_auto,
261     convolution_direct = mkldnn_convolution_direct,
262     convolution_winograd = mkldnn_convolution_winograd,
263     deconvolution_direct = mkldnn_deconvolution_direct,
264     deconvolution_winograd = mkldnn_deconvolution_winograd,
265     eltwise_relu = mkldnn_eltwise_relu,
266     eltwise_tanh = mkldnn_eltwise_tanh,
267     eltwise_elu = mkldnn_eltwise_elu,
268     eltwise_square = mkldnn_eltwise_square,
269     eltwise_abs = mkldnn_eltwise_abs,
270     eltwise_sqrt = mkldnn_eltwise_sqrt,
271     eltwise_linear = mkldnn_eltwise_linear,
272     eltwise_bounded_relu = mkldnn_eltwise_bounded_relu,
273     eltwise_soft_relu = mkldnn_eltwise_soft_relu,
274     eltwise_logistic = mkldnn_eltwise_logistic,
275     eltwise_clamp = mkldnn_eltwise_clamp,
276     eltwise_exp = mkldnn_eltwise_exp,
277     eltwise_not = mkldnn_eltwise_not,
278     depthwise_scale_shift = mkldnn_depthwise_scale_shift,
279     depthwise_prelu = mkldnn_depthwise_prelu,
280     lrn_across_channels = mkldnn_lrn_across_channels,
281     lrn_within_channel  = mkldnn_lrn_within_channel,
282     pooling_max = mkldnn_pooling_max,
283     pooling_avg = mkldnn_pooling_avg,
284     pooling_avg_include_padding = mkldnn_pooling_avg_include_padding,
285     pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding,
286     vanilla_rnn = mkldnn_vanilla_rnn,
287     vanilla_lstm = mkldnn_vanilla_lstm,
288     vanilla_gru = mkldnn_vanilla_gru,
289     gru_linear_before_reset = mkldnn_gru_linear_before_reset,
290     roi_pooling_max = mkldnn_roi_pooling_max,
291     roi_pooling_bilinear = mkldnn_roi_pooling_bilinear,
292     binary_convolution_direct = mkldnn_binary_convolution_direct,
293     binarization_depthwise = mkldnn_binarization_depthwise,
294 };
295
296 inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
297     return static_cast<mkldnn_alg_kind_t>(aalgorithm);
298 }
299
/// Batch normalization flags. Each enumerator has the value of the C API
/// constant on its right-hand side (note the spelling difference:
/// use_scale_shift maps to mkldnn_use_scaleshift); use convert_to_c() to
/// obtain a #mkldnn_batch_normalization_flag_t.
300 enum batch_normalization_flag {
301     use_global_stats = mkldnn_use_global_stats,
302     use_scale_shift = mkldnn_use_scaleshift,
303     fuse_bn_relu = mkldnn_fuse_bn_relu
304 };
305
306 inline mkldnn_batch_normalization_flag_t convert_to_c(
307         batch_normalization_flag aflag) {
308     return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
309 }
310
/// RNN execution directions. Each enumerator has the value of the like-named
/// C API constant; use convert_to_c() to obtain a #mkldnn_rnn_direction_t.
311 enum rnn_direction {
312     unidirectional_left2right = mkldnn_unidirectional_left2right,
313     unidirectional_right2left = mkldnn_unidirectional_right2left,
314     unidirectional = mkldnn_unidirectional,
315     bidirectional_concat = mkldnn_bidirectional_concat,
316     bidirectional_sum = mkldnn_bidirectional_sum,
317 };
318
319 inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) {
320     return static_cast<mkldnn_rnn_direction_t>(adir);
321 }
322
/// Primitive descriptor query kinds. Each enumerator has the value of the
/// mkldnn_query_* constant on its right-hand side; use convert_to_c() to
/// obtain a #mkldnn_query_t.
// NOTE(review): the engine query is spelled "eengine", presumably to avoid a
// clash with the mkldnn::engine type declared later in this header - confirm.
323 enum query {
324     undef = mkldnn_query_undef,
325
326     eengine = mkldnn_query_engine,
327     primitive_kind = mkldnn_query_primitive_kind,
328
329     num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32,
330     num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32,
331
332     time_estimate_f64 = mkldnn_query_time_estimate_f64,
333     memory_consumption_s64 = mkldnn_query_memory_consumption_s64,
334
335     impl_info_str = mkldnn_query_impl_info_str,
336
337     op_d = mkldnn_query_op_d,
338     memory_d = mkldnn_query_memory_d,
339     convolution_d = mkldnn_query_convolution_d,
340     deconvolution_d = mkldnn_query_deconvolution_d,
341     shuffle_d = mkldnn_query_shuffle_d,
342     eltwise_d = mkldnn_query_eltwise_d,
343     depthwise_d = mkldnn_query_depthwise_d,
344     softmax_d = mkldnn_query_softmax_d,
345     pooling_d = mkldnn_query_pooling_d,
346     lrn_d = mkldnn_query_lrn_d,
347     batch_normalization_d = mkldnn_query_batch_normalization_d,
348     inner_product_d = mkldnn_query_inner_product_d,
349     rnn_d = mkldnn_query_rnn_d,
350     binary_convolution_d = mkldnn_query_binary_convolution_d,
351     binarization_d = mkldnn_query_binarization_d,
352
353     input_pd = mkldnn_query_input_pd,
354     output_pd = mkldnn_query_output_pd,
355     src_pd = mkldnn_query_src_pd,
356     diff_src_pd = mkldnn_query_diff_src_pd,
357     weights_pd = mkldnn_query_weights_pd,
358     diff_weights_pd = mkldnn_query_diff_weights_pd,
359     dst_pd = mkldnn_query_dst_pd,
360     diff_dst_pd = mkldnn_query_diff_dst_pd,
361     workspace_pd = mkldnn_query_workspace_pd,
362 };
363
364 inline mkldnn_query_t convert_to_c(query aquery) {
365     return static_cast<mkldnn_query_t>(aquery);
366 }
367
368 /// @}
369
370 /// @addtogroup cpp_api_attr Attributes
371 /// An extension for controlling primitive behavior.
372 ///
373 /// @sa @ref c_api_attributes in @ref c_api
374 /// @{
375
376 #ifndef DOXYGEN_SHOULD_SKIP_THIS
/// Destructor trait for post-operation sequence handles: wrappers of
/// #mkldnn_post_ops_t use mkldnn_post_ops_destroy() as deleter.
377 template <> struct handle_traits<mkldnn_post_ops_t> {
378     static constexpr auto destructor = &mkldnn_post_ops_destroy;
379 };
380 #endif
381
382 struct post_ops: public handle<mkldnn_post_ops_t> {
383     post_ops() {
384         mkldnn_post_ops_t result;
385         error::wrap_c_api(mkldnn_post_ops_create(&result),
386                 "could not create post operation sequence");
387         reset(result);
388     }
389
390     int len() const { return mkldnn_post_ops_len(get()); }
391
392     primitive::kind kind(int index) const {
393         error::wrap_c_api(
394                 index < len() ? mkldnn_success : mkldnn_invalid_arguments,
395                 "post_ops index is out of range");
396         return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
397                     index));
398     }
399
400     void append_sum(float scale = 1.) {
401         error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale),
402                 "could not append sum");
403     }
404
405     void get_params_sum(int index, float &scale) const {
406         error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale),
407                 "could not get sum params");
408     }
409
410     void append_eltwise(float scale, algorithm alg, float alpha,
411             float beta) {
412         error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale,
413                     convert_to_c(alg), alpha, beta),
414                 "could not append eltwise");
415     }
416
417     void get_params_eltwise(int index, float &scale, algorithm &alg,
418             float &alpha, float &beta) const {
419         mkldnn_alg_kind_t c_alg;
420         error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index,
421                     &scale, &c_alg, &alpha, &beta),
422                 "could not get eltwise params");
423         alg = static_cast<algorithm>(c_alg);
424     }
425
426     void append_depthwise(algorithm alg, const float* weights_data,
427             const float* biases_data) {
428         error::wrap_c_api(mkldnn_post_ops_append_depthwise(get(),
429                     convert_to_c(alg), weights_data, biases_data),
430                 "could not append depthwise");
431     }
432
433     void get_params_depthwise(int index, algorithm &alg,
434             const float** weights_data, const float** biases_data) const {
435         mkldnn_alg_kind_t c_alg;
436         error::wrap_c_api(mkldnn_post_ops_get_params_depthwise(get(), index,
437                     &c_alg, weights_data, biases_data),
438                 "could not get depthwise params");
439         alg = static_cast<algorithm>(c_alg);
440     }
441
442     void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
443             const float* weights_data, const float* biases_data) {
444         error::wrap_c_api(mkldnn_post_ops_append_dw_conv(get(),
445                 in_h, in_w, ker_h, ker_w, str_h, str_w, weights_data, biases_data),
446                           "could not append dw conv");
447     }
448
449     void get_params_dw_conv(int index, int &in_h, int &in_w, int &ker_h, int &ker_w, int &str_h, int &str_w,
450             const float** weights_data, const float** biases_data) const {
451         error::wrap_c_api(mkldnn_post_ops_get_params_dw_conv(get(), index,
452                 &in_h, &in_w, &ker_h, &ker_w, &str_h, &str_w, weights_data, biases_data),
453                           "could not get dw conv params");
454     }
455
456     void append_binarization(algorithm alg, const float* weights_data, const float* output_mask) {
457         error::wrap_c_api(mkldnn_post_ops_append_binarization(get(), convert_to_c(alg), weights_data, output_mask),
458                 "could not append binarization");
459     }
460
461     void get_params_binarization(int index, algorithm &alg, const float** weights_data, const float** output_mask) const {
462         mkldnn_alg_kind_t c_alg;
463         error::wrap_c_api(mkldnn_post_ops_get_params_binarization(get(), index, &c_alg, weights_data, output_mask),
464                 "could not get binarization params");
465         alg = static_cast<algorithm>(c_alg);
466     }
467 };
468
469 #ifndef DOXYGEN_SHOULD_SKIP_THIS
/// Destructor trait for primitive attribute handles: wrappers of
/// #mkldnn_primitive_attr_t use mkldnn_primitive_attr_destroy() as deleter.
470 template <> struct handle_traits<mkldnn_primitive_attr_t> {
471     static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
472 };
473 #endif
474
475 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
476     primitive_attr() {
477         mkldnn_primitive_attr_t result;
478         error::wrap_c_api(mkldnn_primitive_attr_create(&result),
479                 "could not create a primitive attr");
480         reset(result);
481     }
482
483     round_mode get_int_output_round_mode() const {
484         mkldnn_round_mode_t result;
485         error::wrap_c_api(mkldnn_primitive_attr_get_int_output_round_mode(
486                     get(), &result), "could not get int output round mode");
487         return round_mode(result);
488     }
489
490     void set_int_output_round_mode(round_mode mode) {
491         error::wrap_c_api(mkldnn_primitive_attr_set_int_output_round_mode(
492                     get(), mkldnn::convert_to_c(mode)),
493                 "could not set int output round mode");
494     }
495
496     void get_output_scales(int &mask, std::vector<float> &scales) const
497     {
498         int count, c_mask;
499         const float *c_scales;
500         error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(),
501                     &count, &c_mask, &c_scales),
502                 "could not get int output scales");
503         scales.resize(count);
504
505         mask = c_mask;
506         for (int c = 0; c < count; ++c)
507             scales[c] = c_scales[c];
508     }
509
510     void set_output_scales(int mask, const std::vector<float> &scales)
511     {
512         error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(),
513                     (int)scales.size(), mask, &scales[0]),
514                 "could not set int output scales");
515     }
516
517     const post_ops get_post_ops() const {
518         post_ops result;
519         const_mkldnn_post_ops_t c_result;
520         error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result),
521                 "could not get post operation sequence");
522         result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
523         return result;
524     }
525
526     void set_post_ops(post_ops ops) {
527         error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()),
528                 "could not set post operation sequence");
529     }
530
531     void set_rnn_data_qparams(const float scale, const float shift)
532     {
533         error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(),
534                     scale, shift), "could not set rnn data int scale/shift");
535     }
536
537     void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
538     {
539         error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(),
540                     (int)scales.size(), mask, &scales[0]),
541                 "could not set rnn weights int scales");
542     }
543 };
544
545 /// @}
546
547 /// @addtogroup cpp_api_engine Engine
548 /// Engine operations.
549 ///
550 /// @sa @ref c_api_engine in @ref c_api
551 /// @{
552
553 #ifndef DOXYGEN_SHOULD_SKIP_THIS
/// Destructor trait for engine handles: wrappers of #mkldnn_engine_t use
/// mkldnn_engine_destroy() as deleter.
554 template <> struct handle_traits<mkldnn_engine_t> {
555     static constexpr auto destructor = &mkldnn_engine_destroy;
556 };
557 #endif
558
559 /// An execution engine.
560 struct engine: public handle<mkldnn_engine_t> {
561     friend class primitive;
562     // gcc bug??? using handle::handle;
563
564     /// Kinds of engines.
565     enum kind {
566         /// An unspecified engine
567         any = mkldnn_any_engine,
568         /// CPU engine
569         cpu = mkldnn_cpu,
570     };
571
572     /// Returns the number of engines of a certain kind.
573     ///
574     /// @param akind The kind of engines to count.
575
576     static size_t get_count(kind akind) {
577         return mkldnn_engine_get_count(convert_to_c(akind));
578     }
579
580     /// Constructs an engine.
581     ///
582     /// @param akind The kind of engine to construct.
583     /// @param index The index of the engine. Must be less than the value
584     ///              returned by #get_count() for this particular kind of engine.
585
586     engine(kind akind, size_t index) {
587         mkldnn_engine_t aengine;
588         error::wrap_c_api(
589                 mkldnn_engine_create(&aengine,
590                     convert_to_c(akind), index),
591                 "could not create an engine");
592         reset(aengine);
593     }
594
595     explicit engine(const mkldnn_engine_t& aengine)
596         : handle(aengine, true) {}
597
598     engine(const handle<mkldnn_primitive_desc_t> &pd) {
599         mkldnn_engine_t engine_q;
600         error::wrap_c_api(
601                 mkldnn_primitive_desc_query(pd.get(),
602                     mkldnn::convert_to_c(eengine), 0, &engine_q),
603                 "could not get engine from primitive_desc");
604         reset(engine_q, true);
605     }
606
607     template <class primitive_desc>
608     static engine query(const primitive_desc &pd) {
609         mkldnn_engine_t engine_q;
610         error::wrap_c_api(
611                 mkldnn_primitive_desc_query(pd.get(),
612                     mkldnn::convert_to_c(eengine), 0, &engine_q),
613                 "could not get engine from primitive_desc");
614
615         return engine(engine_q);
616     }
617
618 private:
619     static mkldnn_engine_kind_t convert_to_c(kind akind) {
620         return static_cast<mkldnn_engine_kind_t>(akind);
621     }
622 };
623
624 /// @}
625
626 /// @addtogroup cpp_api_memory_related Memory and memory related operations
627 /// @{
628
629 /// @addtogroup cpp_api_memory Memory
630 /// A primitive to describe and store data.
631 ///
632 /// For more information, refer to @ref c_api_memory in @ref c_api.
633 /// @{
634
635 /// Memory primitive that describes the data.
636 struct memory: public primitive  {
637     private:
638     std::shared_ptr<char> _handle;
639
640     public:
641     typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
642
643     template <typename T> static void validate_dims(std::vector<T> v) {
644         if (v.size() > TENSOR_MAX_DIMS)
645             throw error(mkldnn_invalid_arguments,
646                     "invalid dimensions");
647     }
648
649     /// Data type specification. See #mkldnn_data_type_t for a detailed
650     /// description. Each enumerator has the value of the C API constant on
        /// its right-hand side.
651     enum data_type {
652         data_undef = mkldnn_data_type_undef,
653         f32 = mkldnn_f32,
654         s32 = mkldnn_s32,
655         s16 = mkldnn_s16,
656         s8 = mkldnn_s8,
657         u8 = mkldnn_u8,
658         bin = mkldnn_bin,
659     };
660
661     /// Memory format specification. See #mkldnn_memory_format_t
662     /// for a detailed description. Each enumerator has the value of the
        /// like-named mkldnn_* constant.
        // NOTE(review): the names appear to encode dimension order and
        // blocking (e.g. nChw8c) - confirm against the C API documentation.
663     enum format {
664         format_undef = mkldnn_format_undef,
665         any = mkldnn_any,
666         blocked = mkldnn_blocked,
667         x = mkldnn_x,
668         nc = mkldnn_nc,
669         ncw = mkldnn_ncw,
670         nwc = mkldnn_nwc,
671         nCw16c = mkldnn_nCw16c,
672         nchw = mkldnn_nchw,
673         nhwc = mkldnn_nhwc,
674         chwn = mkldnn_chwn,
675         nCw4c = mkldnn_nCw4c,
676         nCw8c = mkldnn_nCw8c,
677         nChw4c = mkldnn_nChw4c,
678         nChw8c = mkldnn_nChw8c,
679         nChw16c = mkldnn_nChw16c,
680         ncdhw = mkldnn_ncdhw,
681         ndhwc = mkldnn_ndhwc,
682         nCdhw4c = mkldnn_nCdhw4c,
683         nCdhw8c = mkldnn_nCdhw8c,
684         nCdhw16c = mkldnn_nCdhw16c,
685         oi = mkldnn_oi,
686         io = mkldnn_io,
687         oiw = mkldnn_oiw,
688         wio = mkldnn_wio,
689         Owi4o = mkldnn_Owi4o,
690         OIw4i4o = mkldnn_OIw4i4o,
691         Owi8o = mkldnn_Owi8o,
692         OIw8o8i = mkldnn_OIw8o8i,
693         OIw8i8o = mkldnn_OIw8i8o,
694         OIw16i16o = mkldnn_OIw16i16o,
695         OIw16o16i = mkldnn_OIw16o16i,
696         Oiw4o = mkldnn_Oiw4o,
697         Oiw16o = mkldnn_Oiw16o,
698         Owi16o = mkldnn_Owi16o,
699         OIw8i16o2i = mkldnn_OIw8i16o2i,
700         OIw8o16i2o = mkldnn_OIw8o16i2o,
701         IOw16o16i = mkldnn_IOw16o16i,
702         oihw = mkldnn_oihw,
703         ihwo = mkldnn_ihwo,
704         hwio = mkldnn_hwio,
705         iohw = mkldnn_iohw,
706         hwio_s8s8 = mkldnn_hwio_s8s8,
707         dhwio = mkldnn_dhwio,
708         oidhw = mkldnn_oidhw,
709         OIdhw4i4o = mkldnn_OIdhw4i4o,
710         Odhwi4o = mkldnn_Odhwi4o,
711         OIdhw8i8o = mkldnn_OIdhw8i8o,
712         OIdhw8o8i = mkldnn_OIdhw8o8i,
713         Odhwi8o = mkldnn_Odhwi8o,
714         OIdhw16i16o = mkldnn_OIdhw16i16o,
715         OIdhw16o16i = mkldnn_OIdhw16o16i,
716         Oidhw4o = mkldnn_Oidhw4o,
717         Oidhw16o = mkldnn_Oidhw16o,
718         Odhwi16o = mkldnn_Odhwi16o,
719         oIhw8i = mkldnn_oIhw8i,
720         oIhw16i = mkldnn_oIhw16i,
721         oIdhw8i = mkldnn_oIdhw8i,
722         oIdhw16i = mkldnn_oIdhw16i,
723         OIhw4i4o = mkldnn_OIhw4i4o,
724         OIhw8i8o = mkldnn_OIhw8i8o,
725         OIhw16i16o = mkldnn_OIhw16i16o,
726         OIhw8o8i = mkldnn_OIhw8o8i,
727         OIhw16o16i = mkldnn_OIhw16o16i,
728         IOhw16o16i = mkldnn_IOhw16o16i,
729         OIhw8i16o2i = mkldnn_OIhw8i16o2i,
730         OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
731         OIhw8o16i2o = mkldnn_OIhw8o16i2o,
732         OIhw4i16o4i = mkldnn_OIhw4i16o4i,
733         OIhw4i16o4i_s8s8 = mkldnn_OIhw4i16o4i_s8s8,
734         Oihw8o = mkldnn_Oihw8o,
735         Oihw4o = mkldnn_Oihw4o,
736         Oihw16o = mkldnn_Oihw16o,
737         Ohwi8o = mkldnn_Ohwi8o,
738         Ohwi4o = mkldnn_Ohwi4o,
739         Ohwi16o = mkldnn_Ohwi16o,
740         OhIw16o4i = mkldnn_OhIw16o4i,
741         OhIw8o4i = mkldnn_OhIw8o4i,
742         OhIw8o32i = mkldnn_OhIw8o32i,
743         OhIw16o32i = mkldnn_OhIw16o32i,
744         OhIw8o4i_s8s8 = mkldnn_OhIw8o4i_s8s8,
745         goiw = mkldnn_goiw,
746         gOwi4o = mkldnn_gOwi4o,
747         gOIw4i4o = mkldnn_gOIw4i4o,
748         gOwi8o = mkldnn_gOwi8o,
749         gOIw8o8i = mkldnn_gOIw8o8i,
750         gOIw8i8o = mkldnn_gOIw8i8o,
751         gOIw16i16o = mkldnn_gOIw16i16o,
752         gOIw16o16i = mkldnn_gOIw16o16i,
753         gOiw4o = mkldnn_gOiw4o,
754         gOiw16o = mkldnn_gOiw16o,
755         gOwi16o = mkldnn_gOwi16o,
756         gOIw8i16o2i = mkldnn_gOIw8i16o2i,
757         gIOw16o16i = mkldnn_gIOw16o16i,
758         gOIw8o16i2o = mkldnn_gOIw8o16i2o,
759         goihw = mkldnn_goihw,
760         hwigo = mkldnn_hwigo,
761         giohw = mkldnn_giohw,
762         hwigo_s8s8 = mkldnn_hwigo_s8s8,
763         gOIdhw4i4o = mkldnn_gOIdhw4i4o,
764         gOdhwi4o = mkldnn_gOdhwi4o,
765         gOIdhw8i8o = mkldnn_gOIdhw8i8o,
766         gOIdhw8o8i = mkldnn_gOIdhw8o8i,
767         gOdhwi8o = mkldnn_gOdhwi8o,
768         gOIhw4i4o = mkldnn_gOIhw4i4o,
769         gOIhw8i8o = mkldnn_gOIhw8i8o,
770         gOIhw16i16o = mkldnn_gOIhw16i16o,
771         gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
772         gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
773         gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
774         gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
775         gOIhw4i16o4i_s8s8 = mkldnn_gOIhw4i16o4i_s8s8,
776         gOIhw2i8o4i = mkldnn_gOIhw2i8o4i,
777         gOIhw2i8o4i_s8s8 = mkldnn_gOIhw2i8o4i_s8s8,
778         gOihw8o = mkldnn_gOihw8o,
779         gOihw4o = mkldnn_gOihw4o,
780         gOihw16o = mkldnn_gOihw16o,
781         gOhwi4o = mkldnn_gOhwi4o,
782         gOhwi8o = mkldnn_gOhwi8o,
783         gOhwi16o = mkldnn_gOhwi16o,
784         Goihw8g = mkldnn_Goihw8g,
785         Goihw16g = mkldnn_Goihw16g,
786         Goihw16g_s8s8 = mkldnn_Goihw16g_s8s8,
787         gOIhw4o4i = mkldnn_gOIhw4o4i,
788         gOIhw4o4i_s8s8 = mkldnn_gOIhw4o4i_s8s8,
789         gOIhw8o8i = mkldnn_gOIhw8o8i,
790         gOIhw16o16i = mkldnn_gOIhw16o16i,
791         gIOhw16o16i = mkldnn_gIOhw16o16i,
792         gOhIw16o4i = mkldnn_gOhIw16o4i,
793         gOhIw8o4i = mkldnn_gOhIw8o4i,
794         gOhIw8o4i_s8s8 = mkldnn_gOhIw8o4i_s8s8,
795         goidhw = mkldnn_goidhw,
796         gOIdhw16i16o = mkldnn_gOIdhw16i16o,
797         gOIdhw16o16i = mkldnn_gOIdhw16o16i,
798         gOidhw4o = mkldnn_gOidhw4o,
799         gOidhw16o = mkldnn_gOidhw16o,
800         gOdhwi16o = mkldnn_gOdhwi16o,
801         ntc = mkldnn_ntc,
802         tnc = mkldnn_tnc,
803         ldsnc = mkldnn_ldsnc,
804         ldigo = mkldnn_ldigo,
805         ldgoi = mkldnn_ldgoi,
806         ldgo = mkldnn_ldgo,
807         rnn_packed = mkldnn_rnn_packed,
808         wino_fmt = mkldnn_wino_fmt,
809         format_last = mkldnn_format_last,
810     };
811
812     /// A memory descriptor.
813     struct desc {
814         friend struct memory;
815         /// The underlying C API data structure.
816         mkldnn_memory_desc_t data;
817
818         /// Constructs a memory descriptor.
819         ///
820         /// @param adims Data dimensions
821         /// @param adata_type Data precision/type.
822         /// @param aformat Data layout format.
823         desc(dims adims, data_type adata_type,
824                 format aformat) {
825             validate_dims(adims);
826             error::wrap_c_api(
827                     mkldnn_memory_desc_init(&data, (int)adims.size(),
828                         adims.size() == 0 ? nullptr : &adims[0],
829                         convert_to_c(adata_type), convert_to_c(aformat)),
830                     "could not initialize a memory descriptor");
831         }
832
833         /// Constructs a memory descriptor from a C API data structure.
834         ///
835         /// @param adata A C API #mkldnn_memory_desc_t structure.
836         desc(const mkldnn_memory_desc_t &adata): data(adata) {}
837     };
838
839     /// A memory primitive descriptor.
840     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
841         friend struct memory;
842
843         // TODO: make private
844         primitive_desc() {}
845
846         /// Constructs a memory primitive descriptor.
847         primitive_desc(const desc &adesc, const engine &aengine) {
848             mkldnn_primitive_desc_t result;
849             error::wrap_c_api(
850                     mkldnn_memory_primitive_desc_create(&result,
851                         &adesc.data, aengine.get()),
852                     "could not initialize a memory primitive descriptor");
853             reset(result);
854         }
855
856         /// Returns the memory primitive descriptor.
857         memory::desc desc() {
858             auto memory_d = mkldnn_primitive_desc_query_memory_d(get());
859             return memory::desc(*memory_d); }
860
861         /// Returns the number of bytes required to allocate the memory described
862         /// including the padding area.
863         size_t get_size() const {
864              return mkldnn_memory_primitive_desc_get_size(get());
865         }
866
867         bool operator==(const primitive_desc &other) const {
868             return (0 == mkldnn_memory_primitive_desc_equal(get(),
869                         other.get())) ? false : true;
870         }
871
872         bool operator!=(const primitive_desc &other) const {
873             return !operator==(other);
874         }
875
876         engine get_engine() { return engine::query(*this); }
877     };
878
879     /// Constructs a memory primitive from a generic primitive.
880     ///
881     /// @param aprimitive The primitive to treat as memory.
882     memory(const primitive &aprimitive): primitive(aprimitive) {}
883     /// Constructs a memory primitive.
884     ///
885     /// @param adesc Memory primitive descriptor.
886     memory(const primitive_desc &adesc) {
887         mkldnn_primitive_t result;
888         error::wrap_c_api(
889                 mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
890                 "could not create a memory primitive");
891         reset(result);
892         auto _malloc = [](size_t size, int alignment) {
893             void *ptr;
894 #ifdef _WIN32
895             ptr = _aligned_malloc(size, alignment);
896             int rc = ((ptr)? 0 : errno);
897 #else
898             int rc = ::posix_memalign(&ptr, alignment, size);
899 #endif /* _WIN32 */
900             return (rc == 0) ? (char*)ptr : nullptr;
901         };
902         auto _free = [](char* p) {
903 #ifdef _WIN32
904             _aligned_free((void*)p);
905 #else
906             ::free((void*)p);
907 #endif /* _WIN32 */
908         };
909         _handle.reset(_malloc(adesc.get_size(), 4096), _free);
910         set_data_handle(_handle.get());
911     }
912
913     memory(const primitive_desc &adesc, void *ahandle) {
914         mkldnn_primitive_t result;
915         error::wrap_c_api(
916                 mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
917                 "could not create a memory primitive");
918         reset(result);
919         set_data_handle(ahandle);
920     }
921
922     /// Returns the descriptor of the memory primitive.
923     primitive_desc get_primitive_desc() const {
924         primitive_desc adesc;
925         const_mkldnn_primitive_desc_t cdesc;
926         error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(),
927                     &cdesc),
928                 "could not get primitive descriptor from a memory primitive");
929         /* FIXME: no const_cast should be here */
930         adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
931         return adesc;
932     }
933
934     /// Returns a handle of the data contained in the memory primitive. On
935     /// the CPU engine, this is a pointer to the allocated memory.
936     inline void *get_data_handle() const {
937         void *handle;
938         error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle),
939                 "could not get native handle");
940         return handle;
941     }
942
943     inline void set_data_handle(void *handle) const {
944         error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle),
945                 "could not set native handle");
946     }
947
948     // Must go away or be private:
949     static mkldnn_data_type_t convert_to_c(data_type adata_type) {
950         return static_cast<mkldnn_data_type_t>(adata_type);
951     }
952     static mkldnn_memory_format_t convert_to_c(format aformat) {
953         return static_cast<mkldnn_memory_format_t>(aformat);
954     }
955 };
956
957 inline memory::desc zero_md() {
958     auto zero = mkldnn_memory_desc_t();
959     zero.primitive_kind = mkldnn_memory;
960     return memory::desc(zero);
961 }
962
963 inline memory null_memory(engine eng) {
964     mkldnn::memory::desc zero = zero_md();
965     return memory({zero, eng}, nullptr);
966 }
967
968 inline void check_num_parameters(const const_mkldnn_primitive_desc_t
969     &aprimitive_desc, int n_inputs, int n_outputs,
970     const std::string &prim_name) {
971     const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
972             aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
973     const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
974             aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
975     if (n_outputs_expected > n_outputs ) {
976         std::string message = "could not create " + prim_name +
977             " primitive, not enought output parameters";
978         throw error(mkldnn_invalid_arguments, message, nullptr);
979     }
980     if (n_inputs_expected > n_inputs ) {
981         std::string message = "could not create " + prim_name +
982             " primitive, not enought input parameters";
983         throw error(mkldnn_invalid_arguments, message, nullptr);
984     }
985 }
986
987
988 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
989     const_mkldnn_primitive_desc_t aprimitive_pd;
990     mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
991     const mkldnn_memory_desc_t *aprimitive_md = mkldnn_primitive_desc_query_memory_d(
992         aprimitive_pd);
993
994     return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
995 }
996
997 inline bool operator==(mkldnn_data_type_t a, memory::data_type b) {
998     return a == memory::convert_to_c(b);
999 }
1000 inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) {
1001     return !(a == b);
1002 }
1003 inline bool operator==(memory::data_type a, mkldnn_data_type_t b) {
1004     return b == a;
1005 }
1006 inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) {
1007     return !(a == b);
1008 }
1009
1010 inline bool operator==(mkldnn_memory_format_t a, memory::format b) {
1011     return a == memory::convert_to_c(b);
1012 }
1013 inline bool operator!=(mkldnn_memory_format_t a, memory::format b) {
1014     return !(a == b);
1015 }
1016 inline bool operator==(memory::format a, mkldnn_memory_format_t b) {
1017     return b == a;
1018 }
1019 inline bool operator!=(memory::format a, mkldnn_memory_format_t b) {
1020     return !(a == b);
1021 }
1022
1023 /// @}
1024
1025 /// @addtogroup cpp_api_reorder Reorder
1026 /// A primitive to copy data between memory formats.
1027 ///
1028 /// @sa @ref c_api_reorder in @ref c_api
1029 /// @{
1030
1031 struct reorder : public primitive {
1032     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1033         primitive_desc(const memory::primitive_desc &input,
1034                        const memory::primitive_desc &output) {
1035             mkldnn_primitive_desc_t result;
1036             error::wrap_c_api(mkldnn_reorder_primitive_desc_create(
1037                         &result, input.get(), output.get()),
1038                     "could not create a reorder primitive descriptor");
1039             reset(result);
1040         }
1041
1042         primitive_desc(const memory::primitive_desc &input,
1043                 const memory::primitive_desc &output,
1044                 const primitive_attr &aattr) {
1045             mkldnn_primitive_desc_t result;
1046             error::wrap_c_api(mkldnn_reorder_primitive_desc_create_v2(
1047                         &result, input.get(), output.get(), aattr.get()),
1048                     "could not create a reorder primitive descriptor");
1049             reset(result);
1050         }
1051
1052         engine get_engine() { return engine::query(*this); }
1053     };
1054
1055     reorder(const primitive_desc &aprimitive_desc,
1056             const primitive::at &input, const memory &output) {
1057         mkldnn_primitive_t result;
1058         mkldnn_primitive_at_t inputs[] = { input.data };
1059         const_mkldnn_primitive_t outputs[] = { output.get() };
1060         error::wrap_c_api(mkldnn_primitive_create(&result,
1061                     aprimitive_desc.get(), inputs, outputs),
1062                 "could not create a reorder primitive");
1063         reset(result);
1064     }
1065
1066     reorder(const primitive::at &input, const memory &output) {
1067         auto input_mpd = memory(input).get_primitive_desc();
1068         auto output_mpd = output.get_primitive_desc();
1069
1070         auto reorder_d = primitive_desc(input_mpd, output_mpd);
1071
1072         mkldnn_primitive_t result;
1073         mkldnn_primitive_at_t inputs[] = { input.data };
1074         const_mkldnn_primitive_t outputs[] = { output.get() };
1075         error::wrap_c_api(mkldnn_primitive_create(&result,
1076                     reorder_d.get(), inputs, outputs),
1077                 "could not create a reorder primitive");
1078         reset(result);
1079     }
1080 };
1081
1082 /// @}
1083
1084 /// @addtogroup cpp_api_view View
/// A primitive that provides a view on a memory primitive.
1086 ///
1087 /// @sa mkldnn_view_primitive_desc_create in @ref c_api
1088 /// @{
1089
1090 struct view : public primitive {
1091     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1092         primitive_desc(const memory::primitive_desc &input, memory::dims dims,
1093                 memory::dims offsets) {
1094             mkldnn_primitive_desc_t result;
1095
1096             error::wrap_c_api(mkldnn_view_primitive_desc_create(
1097                     &result, input.get(), &dims[0], &offsets[0]),
1098                 "could not create a view primitive descriptor");
1099             reset(result);
1100         }
1101
1102         memory::primitive_desc dst_primitive_desc() const {
1103             memory::primitive_desc adesc;
1104             mkldnn_primitive_desc_t cdesc;
1105             const_mkldnn_primitive_desc_t const_cdesc =
1106                 mkldnn_primitive_desc_query_pd(get(),
1107                                mkldnn::convert_to_c(dst_pd), 0);
1108             error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc,
1109                         const_cdesc),
1110                     "could not clone a dst primitive descriptor");
1111             adesc.reset(cdesc);
1112             return adesc;
1113         }
1114
1115         engine get_engine() { return engine::query(*this); }
1116     };
1117
1118     view(const primitive_desc &view_pd, primitive::at input) {
1119         mkldnn_primitive_t result;
1120         mkldnn_primitive_at_t inputs[] = { input.data };
1121         error::wrap_c_api(mkldnn_primitive_create(&result,
1122                     view_pd.get(), inputs, nullptr),
1123                 "could not create a view primitive");
1124         reset(result);
1125     }
1126
1127     view(memory input, memory::dims dims, memory::dims offsets) {
1128         mkldnn_primitive_t result;
1129         primitive_desc view_pd(input.get_primitive_desc(), dims,
1130                 offsets);
1131         mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
1132         error::wrap_c_api(mkldnn_primitive_create(&result,
1133                     view_pd.get(), inputs, nullptr),
1134                 "could not create a view primitive");
1135         reset(result);
1136     }
1137 };
1138
1139 /// @}
1140
1141 /// @addtogroup cpp_api_concat Concat
1142 /// A primitive to concatenate data by arbitrary dimension.
1143 ///
1144 /// @sa @ref c_api_concat in @ref c_api
1145 /// @{
1146
1147 struct concat : public primitive {
1148     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1149         std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1150                 std::vector<memory::primitive_desc> inputs) {
1151             std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1152             c_api_inputs.reserve(inputs.size());
1153             auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1154             std::transform(inputs.begin(), inputs.end(),
1155                     std::back_inserter(c_api_inputs), convert_to_c);
1156             return c_api_inputs;
1157         }
1158
1159         primitive_desc(const memory::desc &output, int concat_dimension,
1160                 std::vector<memory::primitive_desc> inputs) {
1161             mkldnn_primitive_desc_t result;
1162
1163             auto c_api_inputs = cpp_to_c(inputs);
1164
1165             error::wrap_c_api(mkldnn_concat_primitive_desc_create(
1166                     &result, &output.data, (int)c_api_inputs.size(),
1167                     concat_dimension, &c_api_inputs[0]),
1168                 "could not create a concat primitive descriptor");
1169             reset(result);
1170         }
1171
1172         primitive_desc(int concat_dimension,
1173                 std::vector<memory::primitive_desc> inputs) {
1174             mkldnn_primitive_desc_t result;
1175
1176             auto c_api_inputs = cpp_to_c(inputs);
1177
1178             error::wrap_c_api(mkldnn_concat_primitive_desc_create(
1179                     &result, nullptr, (int)c_api_inputs.size(),
1180                     concat_dimension, &c_api_inputs[0]),
1181                 "could not create a concat primitive descriptor");
1182             reset(result);
1183         }
1184
1185         memory::primitive_desc dst_primitive_desc() const {
1186             memory::primitive_desc adesc;
1187             mkldnn_primitive_desc_t cdesc;
1188             const_mkldnn_primitive_desc_t const_cdesc =
1189                 mkldnn_primitive_desc_query_pd(get(),
1190                                mkldnn::convert_to_c(dst_pd), 0);
1191             error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1192                     "could not clone a dst primitive descriptor");
1193             adesc.reset(cdesc);
1194             return adesc;
1195         }
1196
1197         engine get_engine() { return engine::query(*this); }
1198     };
1199
1200     concat(const primitive_desc &concat_pd,
1201             std::vector<primitive::at> &inputs, const memory &output) {
1202         mkldnn_primitive_t result;
1203
1204         std::vector<mkldnn_primitive_at_t> p_inputs;
1205         for (size_t i = 0; i < inputs.size(); i++)
1206             p_inputs.push_back(inputs[i].data);
1207         const_mkldnn_primitive_t outputs[] = { output.get() };
1208
1209         error::wrap_c_api(mkldnn_primitive_create(&result,
1210                     concat_pd.get(), &p_inputs[0], outputs),
1211                 "could not create a concat primitive");
1212         reset(result);
1213     }
1214 };
1215
1216 /// @}
1217
1218 /// @addtogroup cpp_api_sum Sum
1219 /// A primitive to sum data.
1220 ///
1221 /// @sa @ref c_api_sum in @ref c_api
1222 /// @{
1223
1224 struct sum : public primitive {
1225     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1226         std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1227                 std::vector<memory::primitive_desc> inputs) {
1228             std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1229             c_api_inputs.reserve(inputs.size());
1230             auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1231             std::transform(inputs.begin(), inputs.end(),
1232                     std::back_inserter(c_api_inputs), convert_to_c);
1233             return c_api_inputs;
1234         }
1235
1236         primitive_desc(const memory::desc &output,
1237                 const std::vector<float> &scales,
1238                 std::vector<memory::primitive_desc> inputs) {
1239             mkldnn_primitive_desc_t result;
1240
1241             auto c_api_inputs = cpp_to_c(inputs);
1242
1243             error::wrap_c_api(
1244                 scales.size() == inputs.size() ? mkldnn_success
1245                                                : mkldnn_invalid_arguments,
1246                 "number of scales not equal to number of inputs");
1247
1248             error::wrap_c_api(mkldnn_sum_primitive_desc_create(
1249                     &result, &output.data, (int)c_api_inputs.size(),
1250                     &scales[0], &c_api_inputs[0]),
1251                 "could not create a sum primitive descriptor");
1252             reset(result);
1253         }
1254
1255         primitive_desc(const std::vector<float> &scales,
1256                 std::vector<memory::primitive_desc> inputs) {
1257             mkldnn_primitive_desc_t result;
1258
1259             auto c_api_inputs = cpp_to_c(inputs);
1260
1261             error::wrap_c_api(
1262                 scales.size() == inputs.size() ? mkldnn_success
1263                                                : mkldnn_invalid_arguments,
1264                 "number of scales not equal to number of inputs");
1265
1266             error::wrap_c_api(mkldnn_sum_primitive_desc_create(
1267                     &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1268                     &c_api_inputs[0]),
1269                 "could not create a sum primitive descriptor");
1270             reset(result);
1271         }
1272
1273         memory::primitive_desc dst_primitive_desc() const {
1274             memory::primitive_desc adesc;
1275             mkldnn_primitive_desc_t cdesc;
1276             const_mkldnn_primitive_desc_t const_cdesc =
1277                 mkldnn_primitive_desc_query_pd(get(),
1278                                mkldnn::convert_to_c(dst_pd), 0);
1279             error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc,
1280                     const_cdesc),
1281                     "could not clone a dst primitive descriptor");
1282             adesc.reset(cdesc);
1283             return adesc;
1284         }
1285
1286         engine get_engine() { return engine::query(*this); }
1287     };
1288
1289     sum(const primitive_desc &sum_pd,
1290             std::vector<primitive::at> &inputs, const memory &output) {
1291         mkldnn_primitive_t result;
1292
1293         std::vector<mkldnn_primitive_at_t> p_inputs;
1294         for (size_t i = 0; i < inputs.size(); i++)
1295             p_inputs.push_back(inputs[i].data);
1296         const_mkldnn_primitive_t outputs[] = { output.get() };
1297
1298         error::wrap_c_api(mkldnn_primitive_create(&result,
1299                     sum_pd.get(), &p_inputs[0], outputs),
1300                 "could not create a sum primitive");
1301         reset(result);
1302     }
1303 };
1304
1305 /// @}
1306
1307 /// @}
1308
1309 /// @addtogroup cpp_api_primitives Primitives
1310 /// @{
1311
1312 /// @addtogroup cpp_api_primitive_descriptors Primitive descriptors
1313 /// @{
1314
1315 /// A base class for all primitive descriptors.
1316 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1317     primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr,
1318             const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1319         mkldnn_primitive_desc_iterator_t iterator = nullptr;
1320         mkldnn_status_t status = mkldnn_primitive_desc_iterator_create_v2(
1321                 &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1322                 hint_fwd_pd);
1323         error::wrap_c_api(status,
1324                 "could not create a primitive descriptor iterator");
1325         pd_iterator.reset(iterator);
1326         fetch_impl();
1327     }
1328
1329     engine get_engine() { return engine::query(*this); }
1330
1331     primitive_attr get_primitive_attr() const {
1332         const_mkldnn_primitive_attr_t const_cattr;
1333         error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr),
1334                 "could not get attributes");
1335         mkldnn_primitive_attr_t cattr;
1336         error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
1337                 "could not clone attributes");
1338
1339         primitive_attr attr;
1340         attr.reset(cattr);
1341         return attr;
1342     }
1343
1344     /// Returns implementation name
1345     const char *impl_info_str() const {
1346         const char *res;
1347         error::wrap_c_api(mkldnn_primitive_desc_query(get(),
1348                     mkldnn_query_impl_info_str, 0, &res),
1349                 "could not query implementation info string");
1350         return res;
1351     }
1352
1353     /// Advances the next implementation for the given op descriptor.
1354     ///
1355     /// Returns:
1356     /// - @c true on success
1357     /// - @c false if the last implementation reached, and
1358     ///   the primitive descriptor itself is kept unchanged
1359     bool next_impl() {
1360         mkldnn_status_t status = mkldnn_primitive_desc_iterator_next(
1361                 pd_iterator.get());
1362         if (status == mkldnn_iterator_ends) return false;
1363         error::wrap_c_api(status, "primitive descriptor iterator next failed");
1364
1365         fetch_impl();
1366         return true;
1367     }
1368
1369     /// Queries and returns requested memory primitive descriptor.
1370     memory::primitive_desc query_mpd(query what, int idx = 0) const {
1371         std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
1372             weights_pd, diff_weights_pd, dst_pd, diff_dst_pd, workspace_pd};
1373         if (!std::any_of(valid_w.cbegin(), valid_w.cend(),
1374                     [=](query q) { return what == q; }))
1375             throw error(mkldnn_invalid_arguments, "invalid memory query");
1376
1377         const_mkldnn_primitive_desc_t const_cdesc
1378             = mkldnn_primitive_desc_query_pd(get(),
1379                     mkldnn::convert_to_c(what), idx);
1380
1381         // TODO: is there a better way to inform about this?
1382         if (const_cdesc == nullptr)
1383             throw error(mkldnn_not_required, "queried memory is not required");
1384
1385         mkldnn_primitive_desc_t cdesc;
1386         error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1387                 "could not clone a memory primitive descriptor");
1388
1389         memory::primitive_desc ret;
1390         ret.reset(cdesc);
1391         return ret;
1392     }
1393
1394     // register specialized queries, e.g. src_primitive_desc()
1395 #   define REG_QUERY_MPD(name, what, idx) \
1396     memory::primitive_desc name ## _primitive_desc() const \
1397     { return query_mpd(what ## _pd, idx); }
1398
1399   private:
1400     handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1401     void fetch_impl() {
1402         mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
1403                 pd_iterator.get());
1404         error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error,
1405                 "could not fetch a primitive descriptor from the iterator");
1406         reset(pd);
1407     }
1408 };
1409
1410 /// @}
1411
1412 /// @addtogroup cpp_api_convolution Convolution
1413 /// A primitive to compute convolution using different algorithms.
1414 ///
1415 /// @sa @ref c_api_convolution in @ref c_api
1416 /// @{
1417
1418 struct convolution_forward: public primitive {
1419     struct desc {
1420         mkldnn_convolution_desc_t data;
1421         desc(prop_kind aprop_kind, algorithm aalgorithm,
1422                 const memory::desc &src_desc,
1423                 const memory::desc &weights_desc,
1424                 const memory::desc &bias_desc,
1425                 const memory::desc &dst_desc,
1426                 const memory::dims strides,
1427                 const memory::dims padding_l,
1428                 const memory::dims padding_r,
1429                 const padding_kind apadding_kind) {
1430             memory::validate_dims(strides);
1431             memory::validate_dims(padding_l);
1432             memory::validate_dims(padding_r);
1433             error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
1434                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1435                         &src_desc.data, &weights_desc.data, &bias_desc.data,
1436                         &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1437                         mkldnn::convert_to_c(apadding_kind)),
1438                     "could not create a convolution forward descriptor");
1439         }
1440         desc(prop_kind aprop_kind, algorithm aalgorithm,
1441                 const memory::desc &src_desc,
1442                 const memory::desc &weights_desc,
1443                 const memory::desc &dst_desc,
1444                 const memory::dims strides,
1445                 const memory::dims padding_l,
1446                 const memory::dims padding_r,
1447                 const padding_kind apadding_kind) {
1448             memory::validate_dims(strides);
1449             memory::validate_dims(padding_l);
1450             memory::validate_dims(padding_r);
1451             error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
1452                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1453                         &src_desc.data, &weights_desc.data, nullptr,
1454                         &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1455                         mkldnn::convert_to_c(apadding_kind)),
1456                     "could not create a convolution forward descriptor");
1457         }
1458         desc(prop_kind aprop_kind, algorithm aalgorithm,
1459                 const memory::desc &src_desc,
1460                 const memory::desc &weights_desc,
1461                 const memory::desc &bias_desc,
1462                 const memory::desc &dst_desc,
1463                 const memory::dims strides,
1464                 const memory::dims dilates,
1465                 const memory::dims padding_l,
1466                 const memory::dims padding_r,
1467                 const padding_kind apadding_kind) {
1468             memory::validate_dims(strides);
1469             memory::validate_dims(dilates);
1470             memory::validate_dims(padding_l);
1471             memory::validate_dims(padding_r);
1472             error::wrap_c_api(
1473                 mkldnn_dilated_convolution_forward_desc_init(&data,
1474                     mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1475                         &src_desc.data, &weights_desc.data, &bias_desc.data,
1476                         &dst_desc.data, &strides[0], &dilates[0],
1477                         &padding_l[0], &padding_r[0],
1478                         mkldnn::convert_to_c(apadding_kind)),
1479                     "could not create a dilated convolution forward descriptor");
1480         }
1481         desc(prop_kind aprop_kind, algorithm aalgorithm,
1482                 const memory::desc &src_desc,
1483                 const memory::desc &weights_desc,
1484                 const memory::desc &dst_desc,
1485                 const memory::dims strides,
1486                 const memory::dims dilates,
1487                 const memory::dims padding_l,
1488                 const memory::dims padding_r,
1489                 const padding_kind apadding_kind) {
1490             memory::validate_dims(strides);
1491             memory::validate_dims(dilates);
1492             memory::validate_dims(padding_l);
1493             memory::validate_dims(padding_r);
1494             error::wrap_c_api(
1495                 mkldnn_dilated_convolution_forward_desc_init(&data,
1496                     mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1497                         &src_desc.data, &weights_desc.data, nullptr,
1498                         &dst_desc.data, &strides[0], &dilates[0],
1499                         &padding_l[0], &padding_r[0],
1500                         mkldnn::convert_to_c(apadding_kind)),
1501                     "could not create a dilated convolution forward descriptor");
1502         }
1503     };
1504
1505     struct primitive_desc : public mkldnn::primitive_desc {
1506         primitive_desc(const desc &desc, const engine &e)
1507             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1508
1509         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1510             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1511
1512         REG_QUERY_MPD(src, src, 0);
1513         REG_QUERY_MPD(weights, weights, 0);
1514         REG_QUERY_MPD(bias, weights, 1);
1515         REG_QUERY_MPD(dst, dst, 0);
1516     };
1517
1518     convolution_forward(const primitive_desc &aprimitive_desc,
1519             const primitive::at &src, const primitive::at &weights,
1520             const primitive::at &bias, const memory &dst) {
1521         mkldnn_primitive_t result;
1522         mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1523                     bias.data };
1524         const_mkldnn_primitive_t outputs[] = { dst.get() };
1525         error::wrap_c_api(mkldnn_primitive_create(&result,
1526                     aprimitive_desc.get(), inputs, outputs),
1527                 "could not create a convolution forward bias primitive");
1528         reset(result);
1529     }
1530
1531     convolution_forward(const primitive_desc &aprimitive_desc,
1532             const primitive::at &src, const primitive::at &weights,
1533             const memory &dst) {
1534         mkldnn_primitive_t result;
1535         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1536         const_mkldnn_primitive_t outputs[] = { dst.get() };
1537         check_num_parameters(aprimitive_desc.get(), 2, 1,
1538             "convolution forward");
1539         error::wrap_c_api(mkldnn_primitive_create(&result,
1540                     aprimitive_desc.get(), inputs, outputs),
1541                 "could not create a convolution forward primitive");
1542         reset(result);
1543     }
1544 };
1545
1546 struct convolution_backward_data : public primitive {
1547     struct desc {
1548         mkldnn_convolution_desc_t data;
1549         desc(algorithm aalgorithm,
1550                 const memory::desc &diff_src_desc,
1551                 const memory::desc &weights_desc,
1552                 const memory::desc &diff_dst_desc,
1553                 const memory::dims strides,
1554                 const memory::dims padding_l,
1555                 const memory::dims padding_r,
1556                 const padding_kind apadding_kind) {
1557             memory::validate_dims(strides);
1558             memory::validate_dims(padding_l);
1559             memory::validate_dims(padding_r);
1560             error::wrap_c_api(mkldnn_convolution_backward_data_desc_init(
1561                         &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1562                         &weights_desc.data, &diff_dst_desc.data,
1563                         &strides[0], &padding_l[0], &padding_r[0],
1564                         mkldnn::convert_to_c(apadding_kind)),
1565                     "could not create a convolution backward data descriptor");
1566         }
1567         desc(algorithm aalgorithm,
1568                 const memory::desc &diff_src_desc,
1569                 const memory::desc &weights_desc,
1570                 const memory::desc &diff_dst_desc,
1571                 const memory::dims strides,
1572                 const memory::dims dilates,
1573                 const memory::dims padding_l,
1574                 const memory::dims padding_r,
1575                 const padding_kind apadding_kind) {
1576             memory::validate_dims(strides);
1577             memory::validate_dims(dilates);
1578             memory::validate_dims(padding_l);
1579             memory::validate_dims(padding_r);
1580             error::wrap_c_api(
1581                 mkldnn_dilated_convolution_backward_data_desc_init(
1582                     &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1583                     &weights_desc.data, &diff_dst_desc.data,
1584                     &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1585                     mkldnn::convert_to_c(apadding_kind)),
1586                     "could not create a convolution backward data descriptor");
1587         }
1588     };
1589
1590     struct primitive_desc : public mkldnn::primitive_desc {
1591         primitive_desc(const desc &desc, const engine &e,
1592                 const convolution_forward::primitive_desc &hint_fwd_pd)
1593             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1594
1595         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1596                 const convolution_forward::primitive_desc &hint_fwd_pd)
1597             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1598
1599         REG_QUERY_MPD(diff_src, diff_src, 0);
1600         REG_QUERY_MPD(weights, weights, 0);
1601         REG_QUERY_MPD(diff_dst, diff_dst, 0);
1602     };
1603
1604     convolution_backward_data(const primitive_desc &aprimitive_desc,
1605             const primitive::at &diff_dst, const primitive::at &weights,
1606             const memory &diff_src) {
1607         mkldnn_primitive_t result;
1608         mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data  };
1609         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1610         check_num_parameters(aprimitive_desc.get(), 2, 1,
1611             "convolution backward data");
1612         error::wrap_c_api(mkldnn_primitive_create(&result,
1613                     aprimitive_desc.get(), inputs, outputs),
1614                 "could not create a convolution backward data primitive");
1615         reset(result);
1616     }
1617 };
1618
1619 struct convolution_backward_weights : public primitive {
1620     struct desc {
1621         mkldnn_convolution_desc_t data;
1622         desc(algorithm aalgorithm,
1623                 const memory::desc &src_desc,
1624                 const memory::desc &diff_weights_desc,
1625                 const memory::desc &diff_bias_desc,
1626                 const memory::desc &diff_dst_desc,
1627                 const memory::dims strides,
1628                 const memory::dims padding_l,
1629                 const memory::dims padding_r,
1630                 const padding_kind apadding_kind) {
1631             memory::validate_dims(strides);
1632             memory::validate_dims(padding_l);
1633             memory::validate_dims(padding_r);
1634             error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
1635                         &data, convert_to_c(aalgorithm), &src_desc.data,
1636                         &diff_weights_desc.data, &diff_bias_desc.data,
1637                         &diff_dst_desc.data,
1638                         &strides[0], &padding_l[0], &padding_r[0],
1639                         mkldnn::convert_to_c(apadding_kind)),
1640                     "could not create a convolution backward weights descriptor");
1641         }
1642         desc(algorithm aalgorithm,
1643                 const memory::desc &src_desc,
1644                 const memory::desc &diff_weights_desc,
1645                 const memory::desc &diff_dst_desc,
1646                 const memory::dims strides,
1647                 const memory::dims padding_l,
1648                 const memory::dims padding_r,
1649                 const padding_kind apadding_kind) {
1650             memory::validate_dims(strides);
1651             memory::validate_dims(padding_l);
1652             memory::validate_dims(padding_r);
1653             error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
1654                         &data, convert_to_c(aalgorithm), &src_desc.data,
1655                         &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1656                         &strides[0], &padding_l[0], &padding_r[0],
1657                         mkldnn::convert_to_c(apadding_kind)),
1658                     "could not create a convolution backward weights descriptor");
1659         }
1660         desc(algorithm aalgorithm,
1661                 const memory::desc &src_desc,
1662                 const memory::desc &diff_weights_desc,
1663                 const memory::desc &diff_bias_desc,
1664                 const memory::desc &diff_dst_desc,
1665                 const memory::dims strides,
1666                 const memory::dims dilates,
1667                 const memory::dims padding_l,
1668                 const memory::dims padding_r,
1669                 const padding_kind apadding_kind) {
1670             memory::validate_dims(strides);
1671             memory::validate_dims(dilates);
1672             memory::validate_dims(padding_l);
1673             memory::validate_dims(padding_r);
1674             error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
1675                         &data, convert_to_c(aalgorithm), &src_desc.data,
1676                         &diff_weights_desc.data, &diff_bias_desc.data,
1677                         &diff_dst_desc.data,
1678                         &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1679                         mkldnn::convert_to_c(apadding_kind)),
1680                     "could not create a convolution backward weights descriptor");
1681         }
1682         desc(algorithm aalgorithm,
1683                 const memory::desc &src_desc,
1684                 const memory::desc &diff_weights_desc,
1685                 const memory::desc &diff_dst_desc,
1686                 const memory::dims strides,
1687                 const memory::dims dilates,
1688                 const memory::dims padding_l,
1689                 const memory::dims padding_r,
1690                 const padding_kind apadding_kind) {
1691             memory::validate_dims(strides);
1692             memory::validate_dims(dilates);
1693             memory::validate_dims(padding_l);
1694             memory::validate_dims(padding_r);
1695             error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
1696                         &data, convert_to_c(aalgorithm), &src_desc.data,
1697                         &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1698                         &strides[0], &dilates[0],  &padding_l[0], &padding_r[0],
1699                         mkldnn::convert_to_c(apadding_kind)),
1700                     "could not create a convolution backward weights descriptor");
1701         }
1702
1703     };
1704
1705     struct primitive_desc : public mkldnn::primitive_desc {
1706         primitive_desc(const desc &desc, const engine &e,
1707                 const convolution_forward::primitive_desc &hint_fwd_pd)
1708             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1709
1710         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1711                 const convolution_forward::primitive_desc &hint_fwd_pd)
1712             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1713
1714         REG_QUERY_MPD(src, src, 0);
1715         REG_QUERY_MPD(diff_weights, diff_weights, 0);
1716         REG_QUERY_MPD(diff_bias, diff_weights, 1);
1717         REG_QUERY_MPD(diff_dst, diff_dst, 0);
1718     };
1719
1720     convolution_backward_weights(const primitive_desc &aprimitive_desc,
1721             const primitive::at &src, const primitive::at &diff_dst,
1722             const memory &diff_weights, const memory &diff_bias) {
1723         mkldnn_primitive_t result;
1724         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1725         const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1726                     diff_bias.get() };
1727         check_num_parameters(aprimitive_desc.get(), 2, 2,
1728             "convolution backward weights");
1729         error::wrap_c_api(mkldnn_primitive_create(&result,
1730                     aprimitive_desc.get(), inputs, outputs),
1731                 "could not create a convolution backward weights primitive");
1732         reset(result);
1733     }
1734     convolution_backward_weights(const primitive_desc &aprimitive_desc,
1735             const primitive::at &src, const primitive::at &diff_dst,
1736             const memory &diff_weights) {
1737         mkldnn_primitive_t result;
1738         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1739         const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1740         check_num_parameters(aprimitive_desc.get(), 2, 1,
1741             "convolution backward weights");
1742         error::wrap_c_api(mkldnn_primitive_create(&result,
1743                     aprimitive_desc.get(), inputs, outputs),
1744                 "could not create a convolution backward weights primitive");
1745         reset(result);
1746     }
1747 };
1748
1749 /// @}
1750
1751 /// @addtogroup cpp_api_deconvolution Deconvolution
1752 /// A primitive to compute deconvolution using different algorithms.
1753 ///
1754 /// @sa @ref c_api_deconvolution in @ref c_api
1755 /// @{
1756
1757 struct deconvolution_forward: public primitive {
1758     struct desc {
1759         mkldnn_deconvolution_desc_t data;
1760         desc(prop_kind aprop_kind, algorithm aalgorithm,
1761                 const memory::desc &src_desc,
1762                 const memory::desc &weights_desc,
1763                 const memory::desc &bias_desc,
1764                 const memory::desc &dst_desc,
1765                 const memory::dims strides,
1766                 const memory::dims padding_l,
1767                 const memory::dims padding_r,
1768                 const padding_kind apadding_kind) {
1769             memory::validate_dims(strides);
1770             memory::validate_dims(padding_l);
1771             memory::validate_dims(padding_r);
1772             error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
1773                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1774                         &src_desc.data, &weights_desc.data, &bias_desc.data,
1775                         &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1776                         mkldnn::convert_to_c(apadding_kind)),
1777                     "could not create a deconvolution forward descriptor");
1778         }
1779         desc(prop_kind aprop_kind, algorithm aalgorithm,
1780                 const memory::desc &src_desc,
1781                 const memory::desc &weights_desc,
1782                 const memory::desc &dst_desc,
1783                 const memory::dims strides,
1784                 const memory::dims padding_l,
1785                 const memory::dims padding_r,
1786                 const padding_kind apadding_kind) {
1787             memory::validate_dims(strides);
1788             memory::validate_dims(padding_l);
1789             memory::validate_dims(padding_r);
1790             error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
1791                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1792                         &src_desc.data, &weights_desc.data, nullptr,
1793                         &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1794                         mkldnn::convert_to_c(apadding_kind)),
1795                     "could not create a deconvolution forward descriptor");
1796         }
1797         desc(prop_kind aprop_kind, algorithm aalgorithm,
1798                 const memory::desc &src_desc,
1799                 const memory::desc &weights_desc,
1800                 const memory::desc &bias_desc,
1801                 const memory::desc &dst_desc,
1802                 const memory::dims strides,
1803                 const memory::dims dilates,
1804                 const memory::dims padding_l,
1805                 const memory::dims padding_r,
1806                 const padding_kind apadding_kind) {
1807             memory::validate_dims(strides);
1808             memory::validate_dims(dilates);
1809             memory::validate_dims(padding_l);
1810             memory::validate_dims(padding_r);
1811             error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
1812                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1813                         &src_desc.data, &weights_desc.data, &bias_desc.data,
1814                         &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1815                         &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1816                     "could not create a dilated deconvolution forward descriptor");
1817         }
1818         desc(prop_kind aprop_kind, algorithm aalgorithm,
1819                 const memory::desc &src_desc,
1820                 const memory::desc &weights_desc,
1821                 const memory::desc &dst_desc,
1822                 const memory::dims strides,
1823                 const memory::dims dilates,
1824                 const memory::dims padding_l,
1825                 const memory::dims padding_r,
1826                 const padding_kind apadding_kind) {
1827             memory::validate_dims(strides);
1828             memory::validate_dims(dilates);
1829             memory::validate_dims(padding_l);
1830             memory::validate_dims(padding_r);
1831             error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
1832                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1833                         &src_desc.data, &weights_desc.data, nullptr,
1834                         &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1835                         &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1836                     "could not create a dilated deconvolution forward descriptor");
1837         }
1838     };
1839
1840     struct primitive_desc : public mkldnn::primitive_desc {
1841         primitive_desc(const desc &desc, const engine &e)
1842             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1843
1844         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1845             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1846
1847         REG_QUERY_MPD(src, src, 0);
1848         REG_QUERY_MPD(weights, weights, 0);
1849         REG_QUERY_MPD(bias, weights, 1);
1850         REG_QUERY_MPD(dst, dst, 0);
1851     };
1852
1853     deconvolution_forward(const primitive_desc &aprimitive_desc,
1854             const primitive::at &src, const primitive::at &weights,
1855             const primitive::at &bias, const memory &dst) {
1856         mkldnn_primitive_t result;
1857         mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1858                     bias.data };
1859         const_mkldnn_primitive_t outputs[] = { dst.get() };
1860         check_num_parameters(aprimitive_desc.get(), 3, 1,
1861             "deconvolution forward");
1862         error::wrap_c_api(mkldnn_primitive_create(&result,
1863                     aprimitive_desc.get(), inputs, outputs),
1864                 "could not create a deconvolution forward bias primitive");
1865         reset(result);
1866     }
1867
1868     deconvolution_forward(const primitive_desc &aprimitive_desc,
1869             const primitive::at &src, const primitive::at &weights,
1870             const memory &dst) {
1871         mkldnn_primitive_t result;
1872         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1873         const_mkldnn_primitive_t outputs[] = { dst.get() };
1874         check_num_parameters(aprimitive_desc.get(), 2, 1,
1875             "deconvolution forward");
1876         error::wrap_c_api(mkldnn_primitive_create(&result,
1877                     aprimitive_desc.get(), inputs, outputs),
1878                 "could not create a deconvolution forward primitive");
1879         reset(result);
1880     }
1881 };
1882
1883 struct deconvolution_backward_data : public primitive {
1884     struct desc {
1885         mkldnn_deconvolution_desc_t data;
1886         desc(algorithm aalgorithm,
1887                 const memory::desc &diff_src_desc,
1888                 const memory::desc &weights_desc,
1889                 const memory::desc &diff_dst_desc,
1890                 const memory::dims strides,
1891                 const memory::dims padding_l,
1892                 const memory::dims padding_r,
1893                 const padding_kind apadding_kind) {
1894             memory::validate_dims(strides);
1895             memory::validate_dims(padding_l);
1896             memory::validate_dims(padding_r);
1897             error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init(
1898                         &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1899                         &weights_desc.data, &diff_dst_desc.data,
1900                         &strides[0], &padding_l[0], &padding_r[0],
1901                         mkldnn::convert_to_c(apadding_kind)),
1902                     "could not create a deconvolution backward data descriptor");
1903         }
1904         desc(algorithm aalgorithm,
1905                 const memory::desc &diff_src_desc,
1906                 const memory::desc &weights_desc,
1907                 const memory::desc &diff_dst_desc,
1908                 const memory::dims strides,
1909                 const memory::dims dilates,
1910                 const memory::dims padding_l,
1911                 const memory::dims padding_r,
1912                 const padding_kind apadding_kind) {
1913             memory::validate_dims(strides);
1914             memory::validate_dims(dilates);
1915             memory::validate_dims(padding_l);
1916             memory::validate_dims(padding_r);
1917             error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init(
1918                         &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1919                         &weights_desc.data, &diff_dst_desc.data,
1920                         &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1921                         mkldnn::convert_to_c(apadding_kind)),
1922                     "could not create a dilated deconvolution backward data descriptor");
1923         }
1924     };
1925
1926     struct primitive_desc : public mkldnn::primitive_desc {
1927         primitive_desc(const desc &desc, const engine &e,
1928                 const deconvolution_forward::primitive_desc &hint_fwd_pd)
1929             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1930
1931         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1932                 const deconvolution_forward::primitive_desc &hint_fwd_pd)
1933             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1934
1935         REG_QUERY_MPD(diff_src, diff_src, 0);
1936         REG_QUERY_MPD(weights, weights, 0);
1937         REG_QUERY_MPD(diff_dst, diff_dst, 0);
1938     };
1939
1940     deconvolution_backward_data(const primitive_desc &aprimitive_desc,
1941             const primitive::at &diff_dst, const primitive::at &weights,
1942             const memory &diff_src) {
1943         mkldnn_primitive_t result;
1944         mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data  };
1945         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1946         check_num_parameters(aprimitive_desc.get(), 2, 1,
1947             "deconvolution backward data");
1948         error::wrap_c_api(mkldnn_primitive_create(&result,
1949                     aprimitive_desc.get(), inputs, outputs),
1950                 "could not create a deconvolution backward data primitive");
1951         reset(result);
1952     }
1953 };
1954
1955 struct deconvolution_backward_weights : public primitive {
1956     struct desc {
1957         mkldnn_deconvolution_desc_t data;
1958         desc(algorithm aalgorithm,
1959                 const memory::desc &src_desc,
1960                 const memory::desc &diff_weights_desc,
1961                 const memory::desc &diff_bias_desc,
1962                 const memory::desc &diff_dst_desc,
1963                 const memory::dims strides,
1964                 const memory::dims padding_l,
1965                 const memory::dims padding_r,
1966                 const padding_kind apadding_kind) {
1967             memory::validate_dims(strides);
1968             memory::validate_dims(padding_l);
1969             memory::validate_dims(padding_r);
1970             error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
1971                         &data, convert_to_c(aalgorithm), &src_desc.data,
1972                         &diff_weights_desc.data, &diff_bias_desc.data,
1973                         &diff_dst_desc.data,
1974                         &strides[0], &padding_l[0], &padding_r[0],
1975                         mkldnn::convert_to_c(apadding_kind)),
1976                     "could not create a deconvolution backward weights descriptor");
1977         }
1978         desc(algorithm aalgorithm,
1979                 const memory::desc &src_desc,
1980                 const memory::desc &diff_weights_desc,
1981                 const memory::desc &diff_dst_desc,
1982                 const memory::dims strides,
1983                 const memory::dims padding_l,
1984                 const memory::dims padding_r,
1985                 const padding_kind apadding_kind) {
1986             memory::validate_dims(strides);
1987             memory::validate_dims(padding_l);
1988             memory::validate_dims(padding_r);
1989             error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
1990                         &data, convert_to_c(aalgorithm), &src_desc.data,
1991                         &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1992                         &strides[0], &padding_l[0], &padding_r[0],
1993                         mkldnn::convert_to_c(apadding_kind)),
1994                     "could not create a deconvolution backward weights descriptor");
1995         }
1996         desc(algorithm aalgorithm,
1997                 const memory::desc &src_desc,
1998                 const memory::desc &diff_weights_desc,
1999                 const memory::desc &diff_bias_desc,
2000                 const memory::desc &diff_dst_desc,
2001                 const memory::dims strides,
2002                 const memory::dims dilates,
2003                 const memory::dims padding_l,
2004                 const memory::dims padding_r,
2005                 const padding_kind apadding_kind) {
2006             memory::validate_dims(strides);
2007             memory::validate_dims(dilates);
2008             memory::validate_dims(padding_l);
2009             memory::validate_dims(padding_r);
2010             error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
2011                         &data, convert_to_c(aalgorithm), &src_desc.data,
2012                         &diff_weights_desc.data, &diff_bias_desc.data,
2013                         &diff_dst_desc.data,
2014                         &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2015                         mkldnn::convert_to_c(apadding_kind)),
2016                     "could not create a dilated  deconvolution backward weights descriptor");
2017         }
2018         desc(algorithm aalgorithm,
2019                 const memory::desc &src_desc,
2020                 const memory::desc &diff_weights_desc,
2021                 const memory::desc &diff_dst_desc,
2022                 const memory::dims strides,
2023                 const memory::dims dilates,
2024                 const memory::dims padding_l,
2025                 const memory::dims padding_r,
2026                 const padding_kind apadding_kind) {
2027             memory::validate_dims(strides);
2028             memory::validate_dims(dilates);
2029             memory::validate_dims(padding_l);
2030             memory::validate_dims(padding_r);
2031             error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
2032                         &data, convert_to_c(aalgorithm), &src_desc.data,
2033                         &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
2034                         &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2035                         mkldnn::convert_to_c(apadding_kind)),
2036                     "could not create a dilated deconvolution backward weights descriptor");
2037         }
2038     };
2039
2040     struct primitive_desc : public mkldnn::primitive_desc {
2041         primitive_desc(const desc &desc, const engine &e,
2042                 const deconvolution_forward::primitive_desc &hint_fwd_pd)
2043             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2044
2045         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2046                 const deconvolution_forward::primitive_desc &hint_fwd_pd)
2047             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2048
2049         REG_QUERY_MPD(src, src, 0);
2050         REG_QUERY_MPD(diff_weights, diff_weights, 0);
2051         REG_QUERY_MPD(diff_bias, diff_weights, 1);
2052         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2053     };
2054
2055     deconvolution_backward_weights(const primitive_desc &aprimitive_desc,
2056             const primitive::at &src, const primitive::at &diff_dst,
2057             const memory &diff_weights, const memory &diff_bias) {
2058         mkldnn_primitive_t result;
2059         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2060         const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
2061                     diff_bias.get() };
2062         check_num_parameters(aprimitive_desc.get(), 2, 2,
2063             "deconvolution backward weights");
2064         error::wrap_c_api(mkldnn_primitive_create(&result,
2065                     aprimitive_desc.get(), inputs, outputs),
2066                 "could not create a deconvolution backward weights primitive");
2067         reset(result);
2068     }
2069     deconvolution_backward_weights(const primitive_desc &aprimitive_desc,
2070             const primitive::at &src, const primitive::at &diff_dst,
2071             const memory &diff_weights) {
2072         mkldnn_primitive_t result;
2073         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2074         const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2075         check_num_parameters(aprimitive_desc.get(), 2, 1,
2076             "deconvolution backward weights");
2077         error::wrap_c_api(mkldnn_primitive_create(&result,
2078                     aprimitive_desc.get(), inputs, outputs),
2079                 "could not create a deconvolution backward weights primitive");
2080         reset(result);
2081     }
2082 };
2083
2084 /// @}
2085
2086 /// @addtogroup cpp_api_roi_pooling ROIPooling
2087 /// @{
2088
2089 struct roi_pooling_forward : public primitive {
2090     struct desc {
2091         mkldnn_roi_pooling_desc_t data;
2092         std::vector<mkldnn_memory_desc_t> c_api_inputs;
2093
2094         desc(prop_kind aprop_kind, algorithm aalgorithm, std::vector<memory::desc> inputs,
2095              const memory::desc &dst_desc, int pooled_h, int pooled_w, double spatial_scale) {
2096
2097             for(size_t i = 0; i < inputs.size(); i++) {
2098                 c_api_inputs.push_back(inputs[i].data);
2099             }
2100
2101             error::wrap_c_api(mkldnn_roi_pooling_forward_desc_init(&data,
2102                         mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &c_api_inputs[0],
2103                         c_api_inputs.size(),
2104                         &dst_desc.data, pooled_h, pooled_w, spatial_scale),
2105                     "could not create a roi pooling forward descriptor");
2106         }
2107     };
2108
2109     struct primitive_desc : public handle<mkldnn_primitive_desc_t>{
2110         primitive_desc(const desc &adesc, const engine &aengine) {
2111             mkldnn_primitive_desc_t result;
2112             error::wrap_c_api(mkldnn_primitive_desc_create(
2113                         &result, &adesc.data, aengine.get(), nullptr),
2114                     "could not create a roi pooling forward primitive descriptor");
2115             reset(result);
2116         }
2117     };
2118
2119     roi_pooling_forward(const primitive_desc &aprimitive_desc,
2120             std::vector<primitive::at> &inputs, const memory &dst) {
2121         mkldnn_primitive_t result;
2122
2123         std::vector<mkldnn_primitive_at_t> p_inputs;
2124         for (size_t i = 0; i < inputs.size(); i++) {
2125             p_inputs.push_back(inputs[i].data);
2126         }
2127
2128         const_mkldnn_primitive_t outputs[] = { dst.get() };
2129         error::wrap_c_api(mkldnn_primitive_create(&result,
2130                     aprimitive_desc.get(), &p_inputs[0], outputs),
2131                 "could not create a roi pooling forward primitive");
2132         reset(result);
2133     }
2134 };
2135
2136 /// @}
2137
2138 /// @addtogroup cpp_api_lrn LRN
2139 /// A primitive to perform local response normalization (LRN) across or within
2140 /// channels.
2141 ///
2142 /// @sa @ref c_api_lrn in @ref c_api
2143 /// @{
2144
2145 struct lrn_forward : public primitive {
2146     struct desc {
2147         mkldnn_lrn_desc_t data;
2148         desc(prop_kind aprop_kind, algorithm aalgorithm,
2149             const memory::desc &src_desc,
2150             int local_size, float alpha, float beta, float k)
2151         {
2152             error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
2153                 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2154                 &src_desc.data, local_size, alpha, beta, k),
2155                 "could not create a lrn forward descriptor");
2156         }
2157         desc(prop_kind aprop_kind, algorithm aalgorithm,
2158             const memory::desc &src_desc,
2159             int local_size, float alpha, float beta)
2160         {
2161             error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
2162                 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2163                 &src_desc.data, local_size, alpha, beta, float(1.0)),
2164                 "could not create a lrn forward descriptor");
2165         }
2166     };
2167
2168     struct primitive_desc : public mkldnn::primitive_desc {
2169         primitive_desc(const desc &desc, const engine &e)
2170             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2171
2172         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2173             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2174
2175         REG_QUERY_MPD(src, src, 0);
2176         REG_QUERY_MPD(dst, dst, 0);
2177         REG_QUERY_MPD(workspace, workspace, 0);
2178     };
2179
2180     lrn_forward(const primitive_desc &aprimitive_desc,
2181             const primitive::at &src, const memory &workspace,
2182             const memory &dst) {
2183         mkldnn_primitive_t result;
2184         mkldnn_primitive_at_t inputs[] = { src.data };
2185         const_mkldnn_primitive_t outputs[] = { dst.get(),
2186                 workspace.get() };
2187         check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
2188         error::wrap_c_api(mkldnn_primitive_create(&result,
2189                 aprimitive_desc.get(), inputs, outputs),
2190             "could not create a lrn forward primitive");
2191         reset(result);
2192     }
2193
2194     lrn_forward(const primitive_desc &aprimitive_desc,
2195             const primitive::at &src, const memory &dst) {
2196         mkldnn_primitive_t result;
2197         mkldnn_primitive_at_t inputs[] = { src.data };
2198         const_mkldnn_primitive_t outputs[] = { dst.get() };
2199         check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
2200         error::wrap_c_api(mkldnn_primitive_create(&result,
2201                 aprimitive_desc.get(), inputs, outputs),
2202             "could not create a lrn forward primitive");
2203         reset(result);
2204     }
2205 };
2206
2207 struct lrn_backward : public primitive {
2208     struct desc {
2209         mkldnn_lrn_desc_t data;
2210         desc(algorithm aalgorithm,
2211             const memory::desc &data_desc,
2212             const memory::desc &diff_data_desc,
2213             int local_size, float alpha, float beta, float k)
2214         {
2215             error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
2216                 convert_to_c(aalgorithm), &diff_data_desc.data,
2217                 &data_desc.data, local_size, alpha, beta, k),
2218                 "could not create a lrn backward descriptor");
2219         }
2220         desc(algorithm aalgorithm,
2221             const memory::desc &data_desc,
2222             const memory::desc &diff_data_desc,
2223             int local_size, float alpha, float beta)
2224         {
2225             error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
2226                 convert_to_c(aalgorithm), &diff_data_desc.data,
2227                 &data_desc.data, local_size, alpha, beta, float(1.0)),
2228                 "could not create a lrn backward descriptor");
2229         }
2230     };
2231
2232     struct primitive_desc : public mkldnn::primitive_desc {
2233         primitive_desc(const desc &desc, const engine &e,
2234                 const lrn_forward::primitive_desc &hint_fwd_pd)
2235             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2236
2237         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2238                 const lrn_forward::primitive_desc &hint_fwd_pd)
2239             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2240
2241         REG_QUERY_MPD(diff_src, diff_src, 0);
2242         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2243         REG_QUERY_MPD(workspace, workspace, 0);
2244     };
2245
2246     lrn_backward(const primitive_desc &aprimitive_desc,
2247             const primitive::at &src, const primitive::at &diff_dst,
2248             const primitive::at &workspace, const memory &diff_src) {
2249         mkldnn_primitive_t result;
2250         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
2251                 workspace.data };
2252         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2253         check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
2254         error::wrap_c_api(mkldnn_primitive_create(&result,
2255                 aprimitive_desc.get(), inputs, outputs),
2256             "could not create a lrn backward primitive");
2257         reset(result);
2258     }
2259
2260     lrn_backward(const primitive_desc &aprimitive_desc,
2261             const primitive::at &src, const primitive::at &diff_dst,
2262             const memory &diff_src) {
2263         mkldnn_primitive_t result;
2264         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2265         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2266         check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
2267         error::wrap_c_api(mkldnn_primitive_create(&result,
2268                 aprimitive_desc.get(), inputs, outputs),
2269             "could not create a lrn backward primitive");
2270         reset(result);
2271     }
2272 };
2273
2274 /// @}
2275
2276 /// @addtogroup cpp_api_pooling Pooling
2277 /// A primitive to perform max or average pooling.
2278 ///
2279 /// @sa @ref c_api_pooling in @ref c_api
2280 /// @{
2281
2282 struct pooling_forward : public primitive {
2283     struct desc {
2284         mkldnn_pooling_desc_t data;
2285         desc(prop_kind aprop_kind, algorithm aalgorithm,
2286                 const memory::desc &src_desc,
2287                 const memory::desc &dst_desc,
2288                 const memory::dims strides,
2289                 const memory::dims kernel,
2290                 const memory::dims padding_l,
2291                 const memory::dims padding_r,
2292                 const padding_kind apadding_kind) {
2293             memory::validate_dims(strides);
2294             memory::validate_dims(kernel);
2295             memory::validate_dims(padding_l);
2296             memory::validate_dims(padding_r);
2297             error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data,
2298                         mkldnn::convert_to_c(aprop_kind),
2299                         convert_to_c(aalgorithm),
2300                         &src_desc.data, &dst_desc.data,
2301                         &strides[0], &kernel[0],
2302                         &padding_l[0], &padding_r[0],
2303                         mkldnn::convert_to_c(apadding_kind)),
2304                     "could not init a forward pooling descriptor");
2305         }
2306     };
2307
2308     struct primitive_desc : public mkldnn::primitive_desc {
2309         primitive_desc(const desc &desc, const engine &e)
2310             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2311
2312         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2313             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2314
2315         REG_QUERY_MPD(src, src, 0);
2316         REG_QUERY_MPD(dst, dst, 0);
2317         REG_QUERY_MPD(workspace, workspace, 0);
2318     };
2319
2320     pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2321             const memory &dst) {
2322         mkldnn_primitive_t result;
2323         mkldnn_primitive_at_t inputs[] = { src.data };
2324         const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
2325         check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
2326         error::wrap_c_api(mkldnn_primitive_create(&result,
2327                     aprimitive_desc.get(), inputs, outputs),
2328                 "could not create a pooling forward primitive");
2329         reset(result);
2330     }
2331
2332     pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2333             const memory &dst, const memory &workspace) {
2334         mkldnn_primitive_t result;
2335         mkldnn_primitive_at_t inputs[] = { src.data };
2336         const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
2337         check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
2338         error::wrap_c_api(mkldnn_primitive_create(&result,
2339                     aprimitive_desc.get(), inputs, outputs),
2340                 "could not create a pooling forward primitive");
2341         reset(result);
2342     }
2343 };
2344
2345 struct pooling_backward : public primitive {
2346     struct desc {
2347         mkldnn_pooling_desc_t data;
2348         desc(algorithm aalgorithm,
2349                 const memory::desc &diff_src_desc,
2350                 const memory::desc &diff_dst_desc,
2351                 const memory::dims &strides,
2352                 const memory::dims &kernel,
2353                 const memory::dims &padding_l,
2354                 const memory::dims &padding_r,
2355                 const padding_kind apadding_kind) {
2356             memory::validate_dims(strides);
2357             memory::validate_dims(kernel);
2358             memory::validate_dims(padding_l);
2359             memory::validate_dims(padding_r);
2360             error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data,
2361                         convert_to_c(aalgorithm),
2362                         &diff_src_desc.data, &diff_dst_desc.data,
2363                         &strides[0], &kernel[0],
2364                         &padding_l[0], &padding_r[0],
2365                         mkldnn::convert_to_c(apadding_kind)),
2366                     "could not init a backward pooling descriptor");
2367         }
2368     };
2369
2370     struct primitive_desc : public mkldnn::primitive_desc {
2371         primitive_desc(const desc &desc, const engine &e,
2372                 const pooling_forward::primitive_desc &hint_fwd_pd)
2373             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2374
2375         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2376                 const pooling_forward::primitive_desc &hint_fwd_pd)
2377             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2378
2379         REG_QUERY_MPD(diff_src, diff_src, 0);
2380         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2381         REG_QUERY_MPD(workspace, workspace, 0);
2382     };
2383
2384     pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2385             const memory &diff_src) {
2386         mkldnn_primitive_t result;
2387         mkldnn_primitive_at_t inputs[] = { diff_dst.data };
2388         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2389         check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
2390         error::wrap_c_api(mkldnn_primitive_create(&result,
2391                     aprimitive_desc.get(), inputs, outputs),
2392                 "could not create a pooling backward primitive");
2393         reset(result);
2394     }
2395
2396     pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2397             const primitive::at &workspace, const memory &diff_src) {
2398         mkldnn_primitive_t result;
2399         mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
2400         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2401         check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
2402         error::wrap_c_api(mkldnn_primitive_create(&result,
2403                     aprimitive_desc.get(), inputs, outputs),
2404                 "could not create a pooling backward primitive");
2405         reset(result);
2406     }
2407 };
2408
2409 /// @}
2410
2411 /// @addtogroup cpp_api_eltwise Eltwise
/// A primitive to compute element-wise operations like parametric rectified
/// linear unit (ReLU).
2414 ///
2415 /// @sa @ref c_api_eltwise in @ref c_api
2416 /// @{
2417
2418 struct eltwise_forward : public primitive {
2419     struct desc {
2420         mkldnn_eltwise_desc_t data;
2421         template <typename T>
2422         desc(prop_kind aprop_kind, algorithm alg_kind,
2423                 const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2424             error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data,
2425                         mkldnn::convert_to_c(aprop_kind),
2426                         mkldnn::convert_to_c(alg_kind), &src_desc.data,
2427                         static_cast<float>(alpha), static_cast<float>(beta)),
2428                     "could not create a eltwise forward descriptor");
2429         }
2430     };
2431
2432     struct primitive_desc : public mkldnn::primitive_desc {
2433         primitive_desc(const desc &desc, const engine &e)
2434             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2435
2436         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2437             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2438
2439         REG_QUERY_MPD(src, src, 0);
2440         REG_QUERY_MPD(dst, dst, 0);
2441     };
2442
2443     eltwise_forward(const primitive_desc &aprimitive_desc,
2444             const primitive::at &src, const memory &dst) {
2445         mkldnn_primitive_t result;
2446         mkldnn_primitive_at_t inputs[] = { src.data };
2447         const_mkldnn_primitive_t outputs[] = { dst.get() };
2448         check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
2449         error::wrap_c_api(mkldnn_primitive_create(&result,
2450                 aprimitive_desc.get(), inputs, outputs),
2451             "could not create a eltwise forward primitive");
2452         reset(result);
2453     }
2454 };
2455
2456 struct eltwise_backward : public primitive {
2457     struct desc {
2458         mkldnn_eltwise_desc_t data;
2459
2460         template <typename T>
2461         desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2462                 const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2463             error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data,
2464                         mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2465                         &data_desc.data, static_cast<float>(alpha),
2466                         static_cast<float>(beta)),
2467                     "could not create a eltwise backward descriptor");
2468         }
2469     };
2470
2471     struct primitive_desc : public mkldnn::primitive_desc {
2472         primitive_desc(const desc &desc, const engine &e,
2473                 const eltwise_forward::primitive_desc &hint_fwd_pd)
2474             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2475
2476         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2477                 const eltwise_forward::primitive_desc &hint_fwd_pd)
2478             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2479
2480         REG_QUERY_MPD(src, src, 0);
2481         REG_QUERY_MPD(diff_src, diff_src, 0);
2482         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2483     };
2484
2485     eltwise_backward(const primitive_desc &aprimitive_desc,
2486             const primitive::at &src, const primitive::at &diff_dst,
2487             const memory &diff_src) {
2488         mkldnn_primitive_t result;
2489         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2490         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2491         check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
2492         error::wrap_c_api(mkldnn_primitive_create(&result,
2493                 aprimitive_desc.get(), inputs, outputs),
2494             "could not create a eltwise backward primitive");
2495         reset(result);
2496     }
2497 };
2498
2499 /// @}
2500
2501 /// @addtogroup cpp_api_depthwise Depthwise
2502 /// @{
2503
2504 struct depthwise_forward : public primitive {
2505     struct desc {
2506         mkldnn_depthwise_desc_t data;
2507
2508         desc(prop_kind aprop_kind, algorithm alg_kind,
2509              const memory::desc &src_desc, const memory::desc &dst_desc, const memory::desc &weights_desc,
2510              const memory::desc &bias_desc) {
2511             error::wrap_c_api(mkldnn_depthwise_forward_desc_init(&data,
2512                                                                  mkldnn::convert_to_c(aprop_kind),
2513                                                                      mkldnn::convert_to_c(alg_kind),
2514                                                                      &src_desc.data, &dst_desc.data,
2515                                                                  &weights_desc.data, &bias_desc.data),
2516                               "could not create a depthwise forward descriptor");
2517         }
2518
2519         desc(prop_kind aprop_kind, algorithm alg_kind,
2520              const memory::desc &src_desc, const memory::desc &dst_desc, const memory::desc &weights_desc) {
2521             error::wrap_c_api(mkldnn_depthwise_forward_desc_init(&data,
2522                                                                  mkldnn::convert_to_c(aprop_kind),
2523                                                                  mkldnn::convert_to_c(alg_kind),
2524                                                                  &src_desc.data, &dst_desc.data,
2525                                                                  &weights_desc.data, nullptr),
2526                               "could not create a depthwise forward descriptor");
2527         }
2528     };
2529
2530     struct primitive_desc : public mkldnn::primitive_desc {
2531         primitive_desc(const desc &desc, const engine &e)
2532             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2533
2534         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2535             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2536
2537         REG_QUERY_MPD(src, src, 0);
2538         REG_QUERY_MPD(dst, dst, 0);
2539     };
2540
2541     depthwise_forward(const primitive_desc &aprimitive_desc,
2542                       const primitive::at &src, const primitive::at &weights,
2543                       const primitive::at &bias, const memory &dst) {
2544         mkldnn_primitive_t result;
2545         mkldnn_primitive_at_t inputs[] = { src.data, weights.data, bias.data };
2546         const_mkldnn_primitive_t outputs[] = { dst.get() };
2547         error::wrap_c_api(mkldnn_primitive_create(&result,
2548                                                   aprimitive_desc.get(), inputs, outputs),
2549                           "could not create a depthwise forward primitive");
2550         reset(result);
2551     }
2552
2553     depthwise_forward(const primitive_desc &aprimitive_desc,
2554                       const primitive::at &src, const primitive::at &weights,
2555                       const memory &dst) {
2556         mkldnn_primitive_t result;
2557         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2558         const_mkldnn_primitive_t outputs[] = { dst.get() };
2559         error::wrap_c_api(mkldnn_primitive_create(&result,
2560                                                   aprimitive_desc.get(), inputs, outputs),
2561                           "could not create a depthwise forward primitive");
2562         reset(result);
2563     }
2564 };
2565
2566 /// @}
2567
2568 /// @addtogroup cpp_api_softmax Softmax
2569 /// A primitive to perform softmax.
2570 ///
2571 /// @sa @ref c_api_softmax in @ref c_api
2572 /// @{
2573
2574 struct softmax_forward : public primitive {
2575     struct desc {
2576         mkldnn_softmax_desc_t data;
2577         desc(prop_kind aprop_kind, const memory::desc &data_desc,
2578              int softmax_axis) {
2579             error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data,
2580                     mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2581                     softmax_axis),
2582                 "could not create a softmax forward descriptor");
2583         }
2584     };
2585
2586     struct primitive_desc : public mkldnn::primitive_desc {
2587         primitive_desc(const desc &desc, const engine &e)
2588             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2589
2590         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2591             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2592
2593         REG_QUERY_MPD(src, src, 0);
2594         REG_QUERY_MPD(dst, dst, 0);
2595     };
2596
2597     softmax_forward(const primitive_desc &aprimitive_desc,
2598             const primitive::at &src, const memory &dst) {
2599         mkldnn_primitive_t result;
2600         mkldnn_primitive_at_t inputs[] = { src.data };
2601         const_mkldnn_primitive_t outputs[] = { dst.get() };
2602         check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
2603         error::wrap_c_api(mkldnn_primitive_create(&result,
2604                 aprimitive_desc.get(), inputs, outputs),
2605             "could not create a softmax forward primitive");
2606         reset(result);
2607     }
2608 };
2609
2610 struct softmax_backward : public primitive {
2611     struct desc {
2612         mkldnn_softmax_desc_t data;
2613         desc(const memory::desc &diff_desc, const memory::desc &data_desc,
2614                 int softmax_axis) {
2615             error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data,
2616                         &diff_desc.data, &data_desc.data, softmax_axis),
2617                     "could not init a backward softmax descriptor");
2618         }
2619     };
2620
2621     struct primitive_desc : public mkldnn::primitive_desc {
2622         primitive_desc(const desc &desc, const engine &e,
2623                 const softmax_forward::primitive_desc &hint_fwd_pd)
2624             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2625
2626         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2627                 const softmax_forward::primitive_desc &hint_fwd_pd)
2628             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2629
2630         REG_QUERY_MPD(dst, dst, 0);
2631         REG_QUERY_MPD(diff_src, diff_src, 0);
2632         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2633         REG_QUERY_MPD(workspace, workspace, 0);
2634     };
2635
2636     softmax_backward(const primitive_desc &aprimitive_desc,
2637             const primitive::at &dst, const primitive::at &diff_dst,
2638             const memory &diff_src) {
2639         mkldnn_primitive_t result;
2640         mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
2641         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2642         error::wrap_c_api(mkldnn_primitive_create(&result,
2643                     aprimitive_desc.get(), inputs, outputs),
2644                 "could not create a softmax backward primitive");
2645         reset(result);
2646     }
2647 };
2648
2649 /// @}
2650
2651 /// @addtogroup cpp_api_batch_norm Batch normalization
2652 /// A primitive to perform batch normalization.
2653 ///
2654 /// @sa @ref c_api_batch_normalization in @ref c_api
2655 /// @{
2656
2657 struct batch_normalization_forward : public primitive {
2658     struct desc {
2659         mkldnn_batch_normalization_desc_t data;
2660         template <typename T>
2661         desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2662                 unsigned flags) {
2663             error::wrap_c_api(
2664                     mkldnn_batch_normalization_forward_desc_init(&data,
2665                         mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2666                         static_cast<float>(epsilon), flags),
2667                 "could not create a batch normalization forward descriptor");
2668         }
2669     };
2670
2671     struct primitive_desc : public mkldnn::primitive_desc {
2672         primitive_desc(const desc &desc, const engine &e)
2673             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2674
2675         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2676             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2677
2678         REG_QUERY_MPD(src, src, 0);
2679         REG_QUERY_MPD(weights, weights, 0);
2680         REG_QUERY_MPD(dst, dst, 0);
2681         REG_QUERY_MPD(workspace, workspace, 0);
2682
2683         memory::primitive_desc mean_primitive_desc() const
2684         { return stat_primitive_desc(mean); }
2685         memory::primitive_desc variance_primitive_desc() const
2686         { return stat_primitive_desc(var); }
2687
2688     private:
2689         enum { mean = 1, var = 2, };
2690         memory::primitive_desc stat_primitive_desc(int kind) const {
2691             mkldnn_batch_normalization_desc_t *p;
2692             error::wrap_c_api(mkldnn_primitive_desc_query(
2693                     get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
2694                     "could not get a batch-normalization descriptor");
2695             return query_mpd(p->flags & use_global_stats ? src_pd : dst_pd, kind);
2696         }
2697     };
2698
2699     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2700             const primitive::at &src, const primitive::at &mean,
2701             const primitive::at &variance, const primitive::at &weights,
2702             const memory &dst) {
2703         mkldnn_primitive_t result;
2704         mkldnn_primitive_at_t inputs[] = { src.data,
2705             mean.data, variance.data, weights.data };
2706         const_mkldnn_primitive_t outputs[] = { dst.get() };
2707         check_num_parameters(aprimitive_desc.get(), 4, 1,
2708             "batch normalization forward");
2709         error::wrap_c_api(mkldnn_primitive_create(&result,
2710                 aprimitive_desc.get(), inputs, outputs),
2711             "could not create a batch normalization forward primitive");
2712         reset(result);
2713     }
2714
2715     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2716             const primitive::at &src, const primitive::at &mean,
2717             const primitive::at &variance, const memory &dst) {
2718         mkldnn_primitive_t result;
2719         mkldnn_primitive_at_t inputs[] = { src.data,
2720             mean.data, variance.data };
2721         const_mkldnn_primitive_t outputs[] = { dst.get() };
2722         check_num_parameters(aprimitive_desc.get(), 3, 1,
2723             "batch normalization forward");
2724         error::wrap_c_api(mkldnn_primitive_create(&result,
2725                 aprimitive_desc.get(), inputs, outputs),
2726             "could not create a batch normalization forward primitive");
2727         reset(result);
2728     }
2729
2730     /// @warning batch_normalization_forward has two constructors with very
2731     ///          similar signatures:
2732     ///           - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
2733     ///           - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
2734     ///          The only way to distinguish between them is to explicitly
2735     ///          cast all input parameters to their type; that is, to
2736     ///          const primitive:at &.
2737     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2738             const primitive::at &src, const primitive::at &weights,
2739             const memory &dst, const memory &mean, const memory &variance) {
2740         mkldnn_primitive_t result;
2741         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2742         const_mkldnn_primitive_t outputs[] = { dst.get(),
2743             mean.get(), variance.get() };
2744         check_num_parameters(aprimitive_desc.get(), 2, 3,
2745             "batch normalization forward");
2746         error::wrap_c_api(mkldnn_primitive_create(&result,
2747                 aprimitive_desc.get(), inputs, outputs),
2748             "could not create a batch normalization forward primitive");
2749         reset(result);
2750     }
2751
2752     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2753             const primitive::at &src, const primitive::at &weights,
2754             const memory &dst, const memory &mean, const memory &variance,
2755             const memory &workspace) {
2756         mkldnn_primitive_t result;
2757         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2758         const_mkldnn_primitive_t outputs[] = { dst.get(),
2759             mean.get(), variance.get(), workspace.get() };
2760         check_num_parameters(aprimitive_desc.get(), 2, 4,
2761             "batch normalization forward");
2762         error::wrap_c_api(mkldnn_primitive_create(&result,
2763                 aprimitive_desc.get(), inputs, outputs),
2764             "could not create a batch normalization forward primitive");
2765         reset(result);
2766     }
2767
2768     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2769             const primitive::at &src, const memory &dst, const memory &mean,
2770             const memory &variance) {
2771         mkldnn_primitive_t result;
2772         mkldnn_primitive_at_t inputs[] = { src.data };
2773         const_mkldnn_primitive_t outputs[] = { dst.get(),
2774             mean.get(), variance.get() };
2775         check_num_parameters(aprimitive_desc.get(), 1, 3,
2776             "batch normalization forward");
2777         error::wrap_c_api(mkldnn_primitive_create(&result,
2778                 aprimitive_desc.get(), inputs, outputs),
2779             "could not create a batch normalization forward primitive");
2780         reset(result);
2781     }
2782
2783     /// @warning batch_normalization_forward has two constructors with very
2784     ///          similar signatures:
2785     ///           - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
2786     ///           - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
2787     ///          The only way to distinguish between them is to explicitly
2788     ///          cast all input parameters to their type; that is, to
2789     ///          const primitive:at &.
2790     /// @note To make users' experience a little better, this constructor
2791     ///       checks whether parameters match the corresponding primitive
2792     ///       descriptor, and if not, calls the other (proper) constructor.
2793     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2794             const primitive::at &src, const memory &dst, const memory &mean,
2795             const memory &variance, const memory &workspace) {
2796         mkldnn_primitive_t result;
2797         mkldnn_primitive_at_t inputs[2] = { src.data };
2798         const_mkldnn_primitive_t outputs[4] = { dst.get(),
2799             mean.get(), variance.get(), workspace.get() };
2800
2801         if (1) { // check whether this is the `wrong` constructor
2802             const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
2803                     aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
2804             const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
2805                     aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
2806             if (n_inputs_expected == 2 && n_outputs_expected == 3) {
2807                 // shift parameters, get rid of workspace, and add weights...
2808                 auto _weights = dst;
2809                 inputs[1] = {_weights.get(), 0};
2810
2811                 auto _dst = mean, _mean = variance, _variance = workspace;
2812                 outputs[0] = _dst.get();
2813                 outputs[1] = _mean.get();
2814                 outputs[2] = _variance.get();
2815                 outputs[3] = nullptr;
2816             }
2817         }
2818         error::wrap_c_api(mkldnn_primitive_create(&result,
2819                 aprimitive_desc.get(), inputs, outputs),
2820             "could not create a batch normalization forward primitive");
2821         reset(result);
2822     }
2823
2824     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2825             const primitive::at &src, const primitive::at &weights,
2826             const memory &dst) {
2827         mkldnn_primitive_t result;
2828         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2829         const_mkldnn_primitive_t outputs[] = { dst.get() };
2830         check_num_parameters(aprimitive_desc.get(), 2, 1,
2831             "batch normalization forward");
2832         error::wrap_c_api(mkldnn_primitive_create(&result,
2833                 aprimitive_desc.get(), inputs, outputs),
2834             "could not create a batch normalization forward primitive");
2835         reset(result);
2836     }
2837
2838     batch_normalization_forward(const primitive_desc &aprimitive_desc,
2839             const primitive::at &src, const memory &dst) {
2840         mkldnn_primitive_t result;
2841         mkldnn_primitive_at_t inputs[] = { src.data };
2842         const_mkldnn_primitive_t outputs[] = { dst.get() };
2843         check_num_parameters(aprimitive_desc.get(), 1, 1,
2844             "batch normalization forward");
2845         error::wrap_c_api(mkldnn_primitive_create(&result,
2846                 aprimitive_desc.get(), inputs, outputs),
2847             "could not create a batch normalization forward primitive");
2848         reset(result);
2849     }
2850 };
2851
2852 struct batch_normalization_backward : public primitive {
2853     struct desc {
2854         mkldnn_batch_normalization_desc_t data;
2855         template <typename T>
2856         desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2857                 const memory::desc &data_desc, T epsilon, unsigned flags) {
2858             error::wrap_c_api(
2859                     mkldnn_batch_normalization_backward_desc_init(&data,
2860                         mkldnn::convert_to_c(aprop_kind),
2861                         &diff_data_desc.data, &data_desc.data,
2862                         static_cast<float>(epsilon), flags),
2863                 "could not create a batch normalization backward descriptor");
2864         }
2865     };
2866
2867     struct primitive_desc : public mkldnn::primitive_desc {
2868         primitive_desc(const desc &desc, const engine &e,
2869                 const batch_normalization_forward::primitive_desc &hint_fwd_pd)
2870             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2871
2872         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2873                 const batch_normalization_forward::primitive_desc &hint_fwd_pd)
2874             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2875
2876         REG_QUERY_MPD(src, src, 0);
2877         REG_QUERY_MPD(mean, src, 1);
2878         REG_QUERY_MPD(variance, src, 2);
2879         REG_QUERY_MPD(weights, weights, 0);
2880         REG_QUERY_MPD(dst, dst, 0);
2881         REG_QUERY_MPD(diff_dst, diff_dst, 0);
2882         REG_QUERY_MPD(workspace, workspace, 0);
2883
2884         REG_QUERY_MPD(diff_src, diff_src, 0);
2885         REG_QUERY_MPD(diff_weights, diff_weights, 0);
2886     };
2887
2888     // Prop_kind == backward
2889     batch_normalization_backward(const primitive_desc &aprimitive_desc,
2890             const primitive::at &src, const primitive::at &mean,
2891             const primitive::at &variance, const primitive::at &diff_dst,
2892             const primitive::at &weights, const memory &diff_src,
2893             const memory &diff_weights) {
2894         mkldnn_primitive_t result;
2895         mkldnn_primitive_at_t inputs[] = { src.data,
2896             mean.data, variance.data, diff_dst.data, weights.data };
2897         const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2898                 diff_weights.get() };
2899         check_num_parameters(aprimitive_desc.get(), 5, 2,
2900             "batch normalization backward");
2901         error::wrap_c_api(mkldnn_primitive_create(&result,
2902                 aprimitive_desc.get(), inputs, outputs),
2903             "could not create a batch normalization backward primitive");
2904         reset(result);
2905     }
2906
2907     // Prop_kind == backward (+ws)
2908     batch_normalization_backward(const primitive_desc &aprimitive_desc,
2909             const primitive::at &src, const primitive::at &mean,
2910             const primitive::at &variance, const primitive::at &diff_dst,
2911             const primitive::at &weights, const primitive::at &workspace,
2912             const memory &diff_src, const memory &diff_weights) {
2913         mkldnn_primitive_t result;
2914         mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2915             diff_dst.data, weights.data, workspace.data };
2916         const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2917                 diff_weights.get() };
2918         check_num_parameters(aprimitive_desc.get(), 6, 2,
2919             "batch normalization backward");
2920         error::wrap_c_api(mkldnn_primitive_create(&result,
2921                 aprimitive_desc.get(), inputs, outputs),
2922             "could not create a batch normalization backward primitive");
2923         reset(result);
2924     }
2925
2926     // Prop_kind == backward_data (+ws or +weights)
2927     /// @warning This constructor works for backward_data propagation
2928     ///          - w/ weights but w/o workspace, or
2929     ///          - w/ workspace but w/o weights
2930     batch_normalization_backward(const primitive_desc &aprimitive_desc,
2931             const primitive::at &src, const primitive::at &mean,
2932             const primitive::at &variance,const primitive::at &diff_dst,
2933             const primitive::at &weights_or_workspace, const memory &diff_src) {
2934         mkldnn_primitive_t result;
2935         mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2936             diff_dst.data, weights_or_workspace.data };
2937         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2938         check_num_parameters(aprimitive_desc.get(), 5, 1,
2939             "batch normalization backward");
2940         error::wrap_c_api(mkldnn_primitive_create(&result,
2941                 aprimitive_desc.get(), inputs, outputs),
2942             "could not create a batch normalization backward primitive");
2943         reset(result);
2944     }
2945
2946     // Prop_kind == backward_data
2947     batch_normalization_backward(const primitive_desc &aprimitive_desc,
2948             const primitive::at &src, const primitive::at &mean,
2949             const primitive::at &variance, const primitive::at &diff_dst,
2950             const memory &diff_src) {
2951         mkldnn_primitive_t result;
2952         mkldnn_primitive_at_t inputs[] = { src.data,
2953             mean.data, variance.data, diff_dst.data };
2954         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2955         check_num_parameters(aprimitive_desc.get(), 4, 1,
2956             "batch normalization backward");
2957         error::wrap_c_api(mkldnn_primitive_create(&result,
2958                 aprimitive_desc.get(), inputs, outputs),
2959             "could not create a batch normalization backward primitive");
2960         reset(result);
2961     }
2962 };
2963
2964 /// @}
2965
2966 /// @addtogroup cpp_api_inner_product Inner Product
2967 /// A primitive to compute an inner product.
2968 ///
2969 /// @sa @ref c_api_inner_product in @ref c_api
2970 /// @{
2971
2972 struct inner_product_forward: public primitive {
2973     struct desc {
2974         mkldnn_inner_product_desc_t data;
2975         desc(prop_kind aprop_kind, const memory::desc &src_desc,
2976                 const memory::desc &weights_desc,
2977                 const memory::desc &bias_desc,
2978                 const memory::desc &dst_desc) {
2979             error::wrap_c_api(
2980                     mkldnn_inner_product_forward_desc_init(&data,
2981                         mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2982                         &weights_desc.data, &bias_desc.data, &dst_desc.data),
2983                     "could not create a inner product forward descriptor");
2984         }
2985
2986         desc(prop_kind aprop_kind, const memory::desc &src_desc,
2987                 const memory::desc &weights_desc,
2988                 const memory::desc &dst_desc) {
2989             error::wrap_c_api(
2990                     mkldnn_inner_product_forward_desc_init(&data,
2991                         mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2992                         &weights_desc.data, nullptr, &dst_desc.data),
2993                     "could not create a inner product forward descriptor");
2994         }
2995     };
2996
2997     struct primitive_desc : public mkldnn::primitive_desc {
2998         primitive_desc(const desc &desc, const engine &e)
2999             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3000
3001         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3002             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3003
3004         REG_QUERY_MPD(src, src, 0);
3005         REG_QUERY_MPD(weights, weights, 0);
3006         REG_QUERY_MPD(bias, weights, 1);
3007         REG_QUERY_MPD(dst, dst, 0);
3008     };
3009
3010     inner_product_forward(const primitive_desc &aprimitive_desc,
3011             const primitive::at &src, const primitive::at weights,
3012             const primitive::at &bias, const memory &dst) {
3013         mkldnn_primitive_t result;
3014         mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
3015                 bias.data };
3016         const_mkldnn_primitive_t outputs[] = { dst.get() };
3017         check_num_parameters(aprimitive_desc.get(), 3, 1,
3018             "inner product forward");
3019         error::wrap_c_api(mkldnn_primitive_create(&result,
3020                 aprimitive_desc.get(), inputs, outputs),
3021             "could not create a inner product forward primitive");
3022         reset(result);
3023     }
3024
3025     inner_product_forward(const primitive_desc &aprimitive_desc,
3026             const primitive::at &src, const primitive::at weights,
3027             const memory &dst) {
3028         mkldnn_primitive_t result;
3029         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3030         const_mkldnn_primitive_t outputs[] = { dst.get() };
3031         check_num_parameters(aprimitive_desc.get(), 2, 1,
3032             "inner product forward");
3033         error::wrap_c_api(mkldnn_primitive_create(&result,
3034                 aprimitive_desc.get(), inputs, outputs),
3035             "could not create a inner product forward primitive");
3036         reset(result);
3037     }
3038 };
3039
3040 struct inner_product_backward_data: public primitive {
3041     struct desc {
3042         mkldnn_inner_product_desc_t data;
3043         desc(const memory::desc &diff_src_desc,
3044                 const memory::desc &weights_desc,
3045                 const memory::desc &diff_dst_desc) {
3046             error::wrap_c_api(
3047                     mkldnn_inner_product_backward_data_desc_init(&data,
3048                         &diff_src_desc.data, &weights_desc.data,
3049                         &diff_dst_desc.data),
3050                 "could not create a inner product backward data descriptor");
3051         }
3052     };
3053
3054     struct primitive_desc : public mkldnn::primitive_desc {
3055         primitive_desc(const desc &desc, const engine &e,
3056                 const inner_product_forward::primitive_desc &hint_fwd_pd)
3057             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3058
3059         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3060                 const inner_product_forward::primitive_desc &hint_fwd_pd)
3061             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3062
3063         REG_QUERY_MPD(diff_src, diff_src, 0);
3064         REG_QUERY_MPD(weights, weights, 0);
3065         REG_QUERY_MPD(diff_dst, diff_dst, 0);
3066     };
3067
3068     inner_product_backward_data(const primitive_desc &aprimitive_desc,
3069             const primitive::at &diff_dst, const primitive::at weights,
3070             const memory &diff_src) {
3071         mkldnn_primitive_t result;
3072         mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
3073         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3074         check_num_parameters(aprimitive_desc.get(), 2, 1,
3075             "inner product backward data");
3076         error::wrap_c_api(mkldnn_primitive_create(&result,
3077                 aprimitive_desc.get(), inputs, outputs),
3078             "could not create a inner product backward data primitive");
3079         reset(result);
3080     }
3081 };
3082
3083 struct inner_product_backward_weights: public primitive {
3084     struct desc {
3085         mkldnn_inner_product_desc_t data;
3086         desc(const memory::desc &src_desc,
3087                 const memory::desc &diff_weights_desc,
3088                 const memory::desc &diff_bias_desc,
3089                 const memory::desc &diff_dst_desc) {
3090             error::wrap_c_api(
3091                     mkldnn_inner_product_backward_weights_desc_init(
3092                         &data, &src_desc.data, &diff_weights_desc.data,
3093                         &diff_bias_desc.data, &diff_dst_desc.data),
3094                 "could not create a inner product backward weights descriptor");
3095         }
3096         desc(const memory::desc &src_desc,
3097                 const memory::desc &diff_weights_desc,
3098                 const memory::desc &diff_dst_desc) {
3099             error::wrap_c_api(
3100                     mkldnn_inner_product_backward_weights_desc_init(
3101                         &data, &src_desc.data, &diff_weights_desc.data,
3102                         nullptr, &diff_dst_desc.data),
3103                 "could not create a inner product backward weights descriptor");
3104         }
3105     };
3106
3107     struct primitive_desc : public mkldnn::primitive_desc {
3108         primitive_desc(const desc &desc, const engine &e,
3109                 const inner_product_forward::primitive_desc &hint_fwd_pd)
3110             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3111
3112         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3113                 const inner_product_forward::primitive_desc &hint_fwd_pd)
3114             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3115
3116         REG_QUERY_MPD(src, src, 0);
3117         REG_QUERY_MPD(diff_weights, diff_weights, 0);
3118         REG_QUERY_MPD(diff_bias, diff_weights, 1);
3119         REG_QUERY_MPD(diff_dst, diff_dst, 0);
3120     };
3121
3122     inner_product_backward_weights(const primitive_desc &aprimitive_desc,
3123             const primitive::at &src, const primitive::at diff_dst,
3124             const memory &diff_weights) {
3125         mkldnn_primitive_t result;
3126         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3127         const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
3128         check_num_parameters(aprimitive_desc.get(), 2, 1,
3129             "inner product backward weights");
3130         error::wrap_c_api(mkldnn_primitive_create(&result,
3131                 aprimitive_desc.get(), inputs, outputs),
3132             "could not create a inner product backward weights primitive");
3133         reset(result);
3134     }
3135
3136     inner_product_backward_weights(const primitive_desc &aprimitive_desc,
3137             const primitive::at &src, const primitive::at diff_dst,
3138             const memory &diff_weights, const memory &diff_bias) {
3139         mkldnn_primitive_t result;
3140         mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3141         const_mkldnn_primitive_t outputs[] =
3142                 { diff_weights.get(), diff_bias.get()};
3143         check_num_parameters(aprimitive_desc.get(), 2, 2,
3144             "inner product backward weights");
3145         error::wrap_c_api(mkldnn_primitive_create(&result,
3146                 aprimitive_desc.get(), inputs, outputs),
3147             "could not create a inner product backward weights primitive");
3148         reset(result);
3149     }
3150 };
3151
3152 /// @}
3153
3154 /// @addtogroup cpp_api_rnn RNN
3155 /// A primitive to compute a common recurrent layer.
3156 ///
3157 /// @sa @ref c_api_rnn in @ref c_api
3158 /// @{
3159
3160 struct rnn_cell {
3161     struct desc {
3162         mkldnn_rnn_cell_desc_t c_rnn_cell_;
3163
3164         desc(algorithm kind, algorithm activation_f) {
3165             error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_,
3166                         mkldnn::convert_to_c(kind),
3167                         mkldnn::convert_to_c(activation_f), 0U, 0, 0),
3168                     "could not init an rnn cell descriptor");
3169         }
3170         desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {}
3171
3172         operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
3173
3174         algorithm get_cell_kind() const
3175         { return algorithm(c_rnn_cell_.cell_kind); }
3176         algorithm get_activation() const
3177         { return algorithm(c_rnn_cell_.activation_kind); }
3178
3179         float get_alpha() const { return c_rnn_cell_.alpha; }
3180         void set_alpha(float alpha) {
3181             c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
3182             c_rnn_cell_.alpha = alpha;
3183         }
3184
3185         float get_clipping() const { return c_rnn_cell_.clipping; }
3186         void set_clipping(float clipping) {
3187             c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
3188             c_rnn_cell_.clipping = clipping;
3189         }
3190
3191         int get_gates_count() const {
3192             return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
3193         }
3194         int get_state_count() const {
3195             return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
3196         }
3197     };
3198 };
3199
3200 struct rnn_forward : public primitive {
3201     struct desc {
3202         mkldnn_rnn_desc_t data;
3203         desc(prop_kind aprop_kind, rnn_cell::desc cell,
3204                 const rnn_direction direction,
3205                 const memory::desc &src_layer_desc,
3206                 const memory::desc &src_iter_desc,
3207                 const memory::desc &weights_layer_desc,
3208                 const memory::desc &weights_iter_desc,
3209                 const memory::desc &bias_desc,
3210                 const memory::desc &dst_layer_desc,
3211                 const memory::desc &dst_iter_desc
3212             ) {
3213             error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data,
3214                         mkldnn::convert_to_c(aprop_kind), cell,
3215                         mkldnn::convert_to_c(direction),
3216                         &src_layer_desc.data, &src_iter_desc.data,
3217                         &weights_layer_desc.data, &weights_iter_desc.data,
3218                         &bias_desc.data,
3219                         &dst_layer_desc.data, &dst_iter_desc.data),
3220                     "could not create an RNN forward descriptor");
3221         }
3222
3223     };
3224
3225     struct primitive_desc : public mkldnn::primitive_desc {
3226         primitive_desc(const desc &desc, const engine &e)
3227             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3228
3229         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3230             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3231
3232         REG_QUERY_MPD(src_layer, src, 0);
3233         REG_QUERY_MPD(src_iter, src, 1);
3234         REG_QUERY_MPD(weights_layer, weights, 0);
3235         REG_QUERY_MPD(weights_iter, weights, 1);
3236         REG_QUERY_MPD(bias, weights, 2);
3237         REG_QUERY_MPD(dst_layer, dst, 0);
3238         REG_QUERY_MPD(dst_iter, dst, 1);
3239         REG_QUERY_MPD(workspace, workspace, 0);
3240     };
3241
3242     rnn_forward(const primitive_desc &aprimitive_desc,
3243             const primitive::at &src_layer, const primitive::at &src_iter,
3244             const primitive::at &weights_layer,
3245             const primitive::at &weights_iter, const primitive::at &bias,
3246             const memory &dst_layer, const memory &dst_iter,
3247             const memory &workspace) {
3248         mkldnn_primitive_t result;
3249         mkldnn_primitive_at_t inputs[5];
3250         const_mkldnn_primitive_t outputs[3];
3251         int idx=0;
3252         inputs[idx++] = src_layer.data;
3253         if (!is_null_memory(src_iter.data.primitive))
3254             inputs[idx++] = src_iter.data;
3255         inputs[idx++] = weights_layer.data;
3256         inputs[idx++] = weights_iter.data;
3257         if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;
3258
3259         idx=0;
3260         outputs[idx++] = dst_layer.get();
3261         if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
3262         if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();
3263
3264         error::wrap_c_api(mkldnn_primitive_create(&result,
3265                     aprimitive_desc.get(), inputs, outputs),
3266                 "could not create an RNN forward primitive");
3267         reset(result);
3268     }
3269 };
3270
3271 struct rnn_backward : public primitive {
3272     struct desc {
3273         mkldnn_rnn_desc_t data;
3274         desc(prop_kind aprop_kind, rnn_cell::desc cell,
3275                 const rnn_direction direction,
3276                 const memory::desc &src_layer_desc,
3277                 const memory::desc &src_iter_desc,
3278                 const memory::desc &weights_layer_desc,
3279                 const memory::desc &weights_iter_desc,
3280                 const memory::desc &bias_desc,
3281                 const memory::desc &dst_layer_desc,
3282                 const memory::desc &dst_iter_desc,
3283                 const memory::desc &diff_src_layer_desc,
3284                 const memory::desc &diff_src_iter_desc,
3285                 const memory::desc &diff_weights_layer_desc,
3286                 const memory::desc &diff_weights_iter_desc,
3287                 const memory::desc &diff_bias_desc,
3288                 const memory::desc &diff_dst_layer_desc,
3289                 const memory::desc &diff_dst_iter_desc) {
3290             error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data,
3291                         mkldnn::convert_to_c(aprop_kind), cell,
3292                         mkldnn::convert_to_c(direction),
3293                         &src_layer_desc.data, &src_iter_desc.data,
3294                         &weights_layer_desc.data, &weights_iter_desc.data,
3295                         &bias_desc.data,
3296                         &dst_layer_desc.data, &dst_iter_desc.data,
3297                         &diff_src_layer_desc.data, &diff_src_iter_desc.data,
3298                         &diff_weights_layer_desc.data,
3299                         &diff_weights_iter_desc.data, &diff_bias_desc.data,
3300                         &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
3301                     "could not create an RNN backward descriptor");
3302         }
3303
3304     };
3305
3306     struct primitive_desc : public mkldnn::primitive_desc {
3307         primitive_desc(const desc &desc, const engine &e,
3308                 const rnn_forward::primitive_desc &hint_fwd_pd)
3309             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3310
3311         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3312                 const rnn_forward::primitive_desc &hint_fwd_pd)
3313             : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3314
3315         REG_QUERY_MPD(src_layer, src, 0);
3316         REG_QUERY_MPD(src_iter, src, 1);
3317         REG_QUERY_MPD(weights_layer, weights, 0);
3318         REG_QUERY_MPD(weights_iter, weights, 1);
3319         REG_QUERY_MPD(bias, weights, 2);
3320         REG_QUERY_MPD(dst_layer, dst, 0);
3321         REG_QUERY_MPD(dst_iter, dst, 1);
3322         REG_QUERY_MPD(workspace, workspace, 0);
3323
3324         REG_QUERY_MPD(diff_src_layer, diff_src, 0);
3325         REG_QUERY_MPD(diff_src_iter, diff_src, 1);
3326         REG_QUERY_MPD(diff_weights_layer, diff_weights, 0);
3327         REG_QUERY_MPD(diff_weights_iter, diff_weights, 1);
3328         REG_QUERY_MPD(diff_bias, diff_weights, 2);
3329         REG_QUERY_MPD(diff_dst_layer, diff_dst, 0);
3330         REG_QUERY_MPD(diff_dst_iter, diff_dst, 1);
3331     };
3332
3333     // With last iteration (with and without input src_iter)
3334     rnn_backward(const primitive_desc &aprimitive_desc,
3335                  const primitive::at &src_layer,
3336                  const primitive::at &src_iter,
3337                  const primitive::at &weights_layer,
3338                  const primitive::at &weights_iter,
3339                  const primitive::at &bias,
3340                  const primitive::at &dst_layer,
3341                  const primitive::at &dst_iter,
3342                  const memory &diff_src_layer,
3343                  const memory &diff_src_iter,
3344                  const memory &diff_weights_layer,
3345                  const memory &diff_weights_iter,
3346                  const memory &diff_bias,
3347                  const primitive::at &diff_dst_layer,
3348                  const primitive::at &diff_dst_iter,
3349                  const primitive::at &workspace) {
3350         mkldnn_primitive_t result;
3351         mkldnn_primitive_at_t inputs[10];
3352         const_mkldnn_primitive_t outputs[5];
3353         int idx=0;
3354         inputs[idx++] = src_layer.data;
3355         if (!is_null_memory(src_iter.data.primitive))
3356             inputs[idx++] = src_iter.data;
3357         inputs[idx++] = weights_layer.data;
3358         inputs[idx++] = weights_iter.data;
3359         if (!is_null_memory(bias.data.primitive))
3360             inputs[idx++] = bias.data;
3361         inputs[idx++] = dst_layer.data;
3362         if (!is_null_memory(dst_iter.data.primitive))
3363             inputs[idx++] = dst_iter.data;
3364         inputs[idx++] = diff_dst_layer.data;
3365         if (!is_null_memory(diff_dst_iter.data.primitive))
3366             inputs[idx++] = diff_dst_iter.data;
3367         inputs[idx++] = workspace.data;
3368
3369         idx = 0;
3370         outputs[idx++] = diff_src_layer.get();
3371         if (!is_null_memory(diff_src_iter.get()))
3372             outputs[idx++] = diff_src_iter.get();
3373         outputs[idx++] = diff_weights_layer.get();
3374         outputs[idx++] = diff_weights_iter.get();
3375         if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
3376         error::wrap_c_api(mkldnn_primitive_create(&result,
3377                     aprimitive_desc.get(), inputs, outputs),
3378                 "could not create an RNN backward primitive");
3379         reset(result);
3380     }
3381 };
3382
3383 /// @}
3384
3385 /// @addtogroup cpp_api_shuffle Shuffle
3386 /// A primitive to shuffle data along the axis.
3387 ///
3388 /// @sa @ref c_api_shuffle in @ref c_api
3389 /// @{
3390
3391 struct shuffle_forward : public primitive {
3392     struct desc {
3393         mkldnn_shuffle_desc_t data;
3394         desc(prop_kind aprop_kind, const memory::desc &data_desc,
3395                 int axis, int group_size) {
3396             error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data,
3397                         mkldnn::convert_to_c(aprop_kind), &data_desc.data,
3398                         axis, group_size),
3399                     "could not create a shuffle forward descriptor");
3400         }
3401     };
3402
3403     struct primitive_desc : public mkldnn::primitive_desc {
3404         primitive_desc(const desc &desc, const engine &e)
3405             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3406
3407         REG_QUERY_MPD(src, src, 0);
3408         REG_QUERY_MPD(dst, dst, 0);
3409     };
3410
3411     shuffle_forward(const primitive_desc &aprimitive_desc,
3412             const primitive::at &src, const memory &dst) {
3413         mkldnn_primitive_t result;
3414         mkldnn_primitive_at_t inputs[] = { src.data };
3415         const_mkldnn_primitive_t outputs[] = { dst.get() };
3416         check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle forward");
3417         error::wrap_c_api(mkldnn_primitive_create(&result,
3418             aprimitive_desc.get(), inputs, outputs),
3419             "could not create a shuffle forward primitive");
3420         reset(result);
3421     }
3422 };
3423
3424 struct shuffle_backward : public primitive {
3425     struct desc {
3426         mkldnn_shuffle_desc_t data;
3427         desc(const memory::desc &diff_data_desc, int axis, int group_size) {
3428             error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data,
3429                         &diff_data_desc.data, axis, group_size),
3430                     "could not create a shuffle backward descriptor");
3431         }
3432     };
3433
3434     struct primitive_desc : public mkldnn::primitive_desc {
3435         primitive_desc(const desc &desc, const engine &e,
3436                 const shuffle_forward::primitive_desc &hint_fwd_pd)
3437             : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3438
3439         REG_QUERY_MPD(diff_src, diff_src, 0);
3440         REG_QUERY_MPD(diff_dst, diff_dst, 0);
3441     };
3442
3443     shuffle_backward(const primitive_desc &aprimitive_desc,
3444             const primitive::at &diff_dst, const memory &diff_src) {
3445         mkldnn_primitive_t result;
3446         mkldnn_primitive_at_t inputs[] = { diff_dst.data};
3447         const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3448         check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle backward");
3449         error::wrap_c_api(mkldnn_primitive_create(&result,
3450             aprimitive_desc.get(), inputs, outputs),
3451             "could not create a shuffle backward primitive");
3452         reset(result);
3453     }
3454 };
3455
3456 /// @}
3457
3458 /// @addtogroup cpp_api_binary_convolution Binary convolution
3459 /// A primitive to compute binary convolution using different algorithms.
3460 ///
3461 /// @sa @ref c_api_binary_convolution in @ref c_api
3462 /// @{
3463
3464 struct binary_convolution_forward: public primitive {
3465     struct desc {
3466         mkldnn_binary_convolution_desc_t data;
3467         desc(prop_kind aprop_kind, algorithm aalgorithm,
3468                 const memory::desc &src_desc,
3469                 const memory::desc &weights_desc,
3470                 const memory::desc &dst_desc,
3471                 const memory::dims strides,
3472                 const memory::dims dilates,
3473                 const memory::dims padding_l,
3474                 const memory::dims padding_r,
3475                 const float pad_value) {
3476             memory::validate_dims(strides);
3477             memory::validate_dims(dilates);
3478             memory::validate_dims(padding_l);
3479             memory::validate_dims(padding_r);
3480             error::wrap_c_api(
3481                 mkldnn_dilated_binary_convolution_forward_desc_init(&data,
3482                     mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
3483                         &src_desc.data, &weights_desc.data, &dst_desc.data,
3484                         &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
3485                         pad_value),
3486                     "could not create a dilated binary convolution forward descriptor");
3487         }
3488     };
3489
3490     struct primitive_desc : public mkldnn::primitive_desc {
3491         primitive_desc(const desc &desc, const engine &e)
3492             : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3493
3494         primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3495             : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3496
3497         REG_QUERY_MPD(src, src, 0);
3498         REG_QUERY_MPD(weights, weights, 0);
3499         REG_QUERY_MPD(dst, dst, 0);
3500     };
3501
3502     binary_convolution_forward(const primitive_desc &aprimitive_desc,
3503             const primitive::at &src, const primitive::at &weights, const memory &dst) {
3504         mkldnn_primitive_t result;
3505         mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
3506         const_mkldnn_primitive_t outputs[] = { dst.get() };
3507         check_num_parameters(aprimitive_desc.get(), 2, 1,
3508             "binary convolution forward");
3509         error::wrap_c_api(mkldnn_primitive_create(&result,
3510                     aprimitive_desc.get(), inputs, outputs),
3511                 "could not create a binary convolution forward primitive");
3512         reset(result);
3513     }
3514 };
3515
3516 /// @}
3517
3518 /// @addtogroup cpp_api_binarization Binarization
3519 /// @{
3520
3521 struct binarization_forward : public primitive {
3522     struct desc {
3523         mkldnn_binarization_desc_t data;
3524
3525         desc(prop_kind aprop_kind, algorithm alg_kind,
3526              const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &output_mask_desc,
3527              const memory::desc &dst_desc) {
3528             error::wrap_c_api(mkldnn_binarization_forward_desc_init(&data,
3529                                                                  mkldnn::convert_to_c(aprop_kind),
3530                                                                  mkldnn::convert_to_c(alg_kind),
3531                                                                  &src_desc.data, &dst_desc.data,
3532                                                                  &weights_desc.data, &output_mask_desc.data),
3533                               "could not create a binarization forward descriptor");
3534         }
3535     };
3536
3537     struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
3538         primitive_desc(const desc &adesc, const engine &aengine) {
3539             mkldnn_primitive_desc_t result;
3540             error::wrap_c_api(mkldnn_primitive_desc_create(
3541                     &result, &adesc.data, aengine.get(), nullptr),
3542                               "could not create a binarization forward primitive descriptor");
3543             reset(result);
3544         }
3545
3546         engine get_engine() { return engine::query(*this); }
3547     };
3548
3549     binarization_forward(const primitive_desc &aprimitive_desc,
3550                       const primitive::at &src, const primitive::at &weights, const primitive::at &output_mask,
3551                       const memory &dst) {
3552         mkldnn_primitive_t result;
3553         mkldnn_primitive_at_t inputs[] = { src.data, weights.data, output_mask.data};
3554         const_mkldnn_primitive_t outputs[] = { dst.get() };
3555         error::wrap_c_api(mkldnn_primitive_create(&result, aprimitive_desc.get(), inputs, outputs),
3556                           "could not create a binarization forward primitive");
3557         reset(result);
3558     }
3559 };
3560
3561 /// @}
3562
3563 /// @} Primitives
3564
3565 /// @addtogroup cpp_api_stream Stream
3566 /// Execution stream operations.
3567 ///
3568 /// @sa @ref c_api_stream in @ref c_api
3569 /// @{
3570
3571 #ifndef DOXYGEN_SHOULD_SKIP_THIS
3572 template <> struct handle_traits<mkldnn_stream_t> {
3573     static constexpr auto destructor = &mkldnn_stream_destroy;
3574 };
3575 #endif
3576
3577 struct stream: public handle<mkldnn_stream_t> {
3578     using handle::handle;
3579
3580     enum kind { any = mkldnn_stream_kind_t::mkldnn_any_stream,
3581         eager = mkldnn_stream_kind_t::mkldnn_eager,
3582         lazy = mkldnn_stream_kind_t::mkldnn_lazy };
3583
3584     static mkldnn_stream_kind_t convert_to_c(kind akind) {
3585         return static_cast<mkldnn_stream_kind_t>(akind);
3586     }
3587     /// Constructs a stream.
3588     stream(kind akind) {
3589         mkldnn_stream_t astream;
3590         error::wrap_c_api(mkldnn_stream_create(&astream,
3591                     convert_to_c(akind)),
3592                 "could not create a stream");
3593         reset(astream);
3594     }
3595
3596     /// Submits a vector of primitives to a stream for computations.
3597     ///
3598     /// @param primitives The vector of primitives to submit.
3599     /// @returns The stream.
3600     stream &submit(std::vector<primitive> primitives) {
3601         // TODO: find a proper way to convert vector<primitive> to
3602         // vector<mkldnn_primitive_t>
3603         if (primitives.size() == 0) return *this;
3604         std::vector<mkldnn_primitive_t> c_api_primitives;
3605         c_api_primitives.reserve(primitives.size());
3606         auto convert_to_c = [](primitive p) { return p.get(); };
3607         std::transform(primitives.begin(), primitives.end(),
3608                 std::back_inserter(c_api_primitives), convert_to_c);
3609
3610         mkldnn_primitive_t c_api_error_primitive;
3611         error::wrap_c_api(
3612                 mkldnn_stream_submit(get(),
3613                     c_api_primitives.size(), &c_api_primitives[0],
3614                     &c_api_error_primitive),
3615                 "could not submit primitives to a stream",
3616                 &c_api_error_primitive);
3617
3618         return *this;
3619     }
3620
3621     /// Waits for all computations submitted to the stream to complete.
3622     ///
3623     /// @param block Specifies whether the operation should wait indefinitely or
3624     ///              return immediately.
3625     /// @returns @c true if all computations completed.
3626     /// @returns @c false if not all computations completed.
3627     bool wait(bool block = true) {
3628         mkldnn_primitive_t c_api_error_primitive;
3629         mkldnn_status_t status = mkldnn_stream_wait(get(),
3630                 block, &c_api_error_primitive);
3631         if (status != mkldnn_success
3632                 && status != mkldnn_try_again)
3633             error::wrap_c_api(status, "could not wait on a stream",
3634                     &c_api_error_primitive);
3635         return (status == mkldnn_success);
3636     }
3637
3638     stream &rerun() {
3639         mkldnn_primitive_t c_api_error_primitive;
3640         error::wrap_c_api(
3641                 mkldnn_stream_rerun(get(), &c_api_error_primitive),
3642                 "could not rerun a stream", &c_api_error_primitive);
3643         return *this;
3644     }
3645 };
3646
3647 #undef REG_QUERY_MPD
3648
3649 /// @}
3650
3651 /// @} C++ API
3652
3653 } // namespace mkldnn
3654
3655 #endif