shuffle = mkldnn_shuffle,
eltwise = mkldnn_eltwise,
depthwise = mkldnn_depthwise,
- relu = mkldnn_relu,
softmax = mkldnn_softmax,
pooling = mkldnn_pooling,
lrn = mkldnn_lrn,
batch_normalization = mkldnn_batch_normalization,
inner_product = mkldnn_inner_product,
- convolution_relu = mkldnn_convolution_relu,
rnn = mkldnn_rnn,
+ binary_convolution = mkldnn_binary_convolution,
+ binarization = mkldnn_binarization,
};
/// A wrapper structure to specify a particular output of a primitive.
inline operator primitive() const;
};
- /// Returns the descriptor of the underlying C API primitive
+ /// Returns the descriptor of the underlying C API primitive.
inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
// TODO: use the C++ API wrapper structure.
};
enum algorithm {
algorithm_undef = mkldnn_alg_kind_undef,
+ convolution_auto = mkldnn_convolution_auto,
convolution_direct = mkldnn_convolution_direct,
convolution_winograd = mkldnn_convolution_winograd,
deconvolution_direct = mkldnn_deconvolution_direct,
eltwise_soft_relu = mkldnn_eltwise_soft_relu,
eltwise_logistic = mkldnn_eltwise_logistic,
eltwise_clamp = mkldnn_eltwise_clamp,
+ eltwise_exp = mkldnn_eltwise_exp,
+ eltwise_not = mkldnn_eltwise_not,
depthwise_scale_shift = mkldnn_depthwise_scale_shift,
depthwise_prelu = mkldnn_depthwise_prelu,
lrn_across_channels = mkldnn_lrn_across_channels,
vanilla_gru = mkldnn_vanilla_gru,
gru_linear_before_reset = mkldnn_gru_linear_before_reset,
roi_pooling_max = mkldnn_roi_pooling_max,
- roi_pooling_bilinear = mkldnn_roi_pooling_bilinear
+ roi_pooling_bilinear = mkldnn_roi_pooling_bilinear,
+ binary_convolution_direct = mkldnn_binary_convolution_direct,
+ binarization_depthwise = mkldnn_binarization_depthwise
};
inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
enum batch_normalization_flag {
use_global_stats = mkldnn_use_global_stats,
use_scale_shift = mkldnn_use_scaleshift,
- omit_stats = mkldnn_omit_stats,
fuse_bn_relu = mkldnn_fuse_bn_relu
};
shuffle_d = mkldnn_query_shuffle_d,
eltwise_d = mkldnn_query_eltwise_d,
depthwise_d = mkldnn_query_depthwise_d,
- relu_d = mkldnn_query_relu_d,
softmax_d = mkldnn_query_softmax_d,
pooling_d = mkldnn_query_pooling_d,
lrn_d = mkldnn_query_lrn_d,
batch_normalization_d = mkldnn_query_batch_normalization_d,
inner_product_d = mkldnn_query_inner_product_d,
- convolution_relu_d = mkldnn_query_convolution_relu_d,
rnn_d = mkldnn_query_rnn_d,
+ binary_convolution_d = mkldnn_query_binary_convolution_d,
+ binarization_d = mkldnn_query_binarization_d,
input_pd = mkldnn_query_input_pd,
output_pd = mkldnn_query_output_pd,
&in_h, &in_w, &ker_h, &ker_w, &str_h, &str_w, weights_data, biases_data),
"could not get dw conv params");
}
+
+ void append_binarization(algorithm alg, const float* weights_data) {
+ error::wrap_c_api(mkldnn_post_ops_append_binarization(get(), convert_to_c(alg), weights_data),
+ "could not append binarization");
+ }
+
+ void get_params_binarization(int index, algorithm &alg, const float** weights_data) const {
+ mkldnn_alg_kind_t c_alg;
+ error::wrap_c_api(mkldnn_post_ops_get_params_binarization(get(), index, &c_alg, weights_data),
+ "could not get binarization params");
+ alg = static_cast<algorithm>(c_alg);
+ }
};
#ifndef DOXYGEN_SHOULD_SKIP_THIS
error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()),
"could not set post operation sequence");
}
+
+ void set_rnn_data_qparams(const float scale, const float shift)
+ {
+ error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(),
+ scale, shift), "could not set rnn data int scale/shift");
+ }
+
+ void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
+ {
+ error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(),
+ (int)scales.size(), mask, &scales[0]),
+ "could not set rnn weights int scales");
+ }
};
/// @}
/// @addtogroup cpp_api_engine Engine
-/// Engine operations
+/// Engine operations.
///
/// @sa @ref c_api_engine in @ref c_api
/// @{
friend class primitive;
// gcc bug??? using handle::handle;
- /// Kinds of engines
+ /// Kinds of engines.
enum kind {
/// An unspecified engine
any = mkldnn_any_engine,
/// @addtogroup cpp_api_memory Memory
/// A primitive to describe and store data.
///
-/// For more information please refer to @ref c_api_memory in @ref c_api
+/// For more information, refer to @ref c_api_memory in @ref c_api.
/// @{
/// Memory primitive that describes the data.
s16 = mkldnn_s16,
s8 = mkldnn_s8,
u8 = mkldnn_u8,
+ bin = mkldnn_bin,
};
/// Memory format specification. See #mkldnn_memory_format_t
nchw = mkldnn_nchw,
nhwc = mkldnn_nhwc,
chwn = mkldnn_chwn,
+ nCw4c = mkldnn_nCw4c,
nCw8c = mkldnn_nCw8c,
+ nChw4c = mkldnn_nChw4c,
nChw8c = mkldnn_nChw8c,
nChw16c = mkldnn_nChw16c,
ncdhw = mkldnn_ncdhw,
ndhwc = mkldnn_ndhwc,
+ nCdhw4c = mkldnn_nCdhw4c,
nCdhw8c = mkldnn_nCdhw8c,
nCdhw16c = mkldnn_nCdhw16c,
oi = mkldnn_oi,
io = mkldnn_io,
oiw = mkldnn_oiw,
wio = mkldnn_wio,
+ Owi4o = mkldnn_Owi4o,
+ OIw4i4o = mkldnn_OIw4i4o,
Owi8o = mkldnn_Owi8o,
OIw8o8i = mkldnn_OIw8o8i,
OIw8i8o = mkldnn_OIw8i8o,
OIw16i16o = mkldnn_OIw16i16o,
OIw16o16i = mkldnn_OIw16o16i,
+ Oiw4o = mkldnn_Oiw4o,
Oiw16o = mkldnn_Oiw16o,
Owi16o = mkldnn_Owi16o,
OIw8i16o2i = mkldnn_OIw8i16o2i,
oihw = mkldnn_oihw,
ihwo = mkldnn_ihwo,
hwio = mkldnn_hwio,
+ iohw = mkldnn_iohw,
hwio_s8s8 = mkldnn_hwio_s8s8,
dhwio = mkldnn_dhwio,
oidhw = mkldnn_oidhw,
+ OIdhw4i4o = mkldnn_OIdhw4i4o,
+ Odhwi4o = mkldnn_Odhwi4o,
OIdhw8i8o = mkldnn_OIdhw8i8o,
OIdhw8o8i = mkldnn_OIdhw8o8i,
Odhwi8o = mkldnn_Odhwi8o,
OIdhw16i16o = mkldnn_OIdhw16i16o,
OIdhw16o16i = mkldnn_OIdhw16o16i,
+ Oidhw4o = mkldnn_Oidhw4o,
Oidhw16o = mkldnn_Oidhw16o,
Odhwi16o = mkldnn_Odhwi16o,
oIhw8i = mkldnn_oIhw8i,
oIhw16i = mkldnn_oIhw16i,
oIdhw8i = mkldnn_oIdhw8i,
oIdhw16i = mkldnn_oIdhw16i,
+ OIhw4i4o = mkldnn_OIhw4i4o,
OIhw8i8o = mkldnn_OIhw8i8o,
OIhw16i16o = mkldnn_OIhw16i16o,
OIhw8o8i = mkldnn_OIhw8o8i,
OIhw4i16o4i = mkldnn_OIhw4i16o4i,
OIhw4i16o4i_s8s8 = mkldnn_OIhw4i16o4i_s8s8,
Oihw8o = mkldnn_Oihw8o,
+ Oihw4o = mkldnn_Oihw4o,
Oihw16o = mkldnn_Oihw16o,
Ohwi8o = mkldnn_Ohwi8o,
+ Ohwi4o = mkldnn_Ohwi4o,
Ohwi16o = mkldnn_Ohwi16o,
OhIw16o4i = mkldnn_OhIw16o4i,
OhIw8o4i = mkldnn_OhIw8o4i,
+ OhIw8o32i = mkldnn_OhIw8o32i,
+ OhIw16o32i = mkldnn_OhIw16o32i,
OhIw8o4i_s8s8 = mkldnn_OhIw8o4i_s8s8,
goiw = mkldnn_goiw,
+ gOwi4o = mkldnn_gOwi4o,
+ gOIw4i4o = mkldnn_gOIw4i4o,
gOwi8o = mkldnn_gOwi8o,
gOIw8o8i = mkldnn_gOIw8o8i,
gOIw8i8o = mkldnn_gOIw8i8o,
gOIw16i16o = mkldnn_gOIw16i16o,
gOIw16o16i = mkldnn_gOIw16o16i,
+ gOiw4o = mkldnn_gOiw4o,
gOiw16o = mkldnn_gOiw16o,
gOwi16o = mkldnn_gOwi16o,
gOIw8i16o2i = mkldnn_gOIw8i16o2i,
gOIw8o16i2o = mkldnn_gOIw8o16i2o,
goihw = mkldnn_goihw,
hwigo = mkldnn_hwigo,
+ giohw = mkldnn_giohw,
hwigo_s8s8 = mkldnn_hwigo_s8s8,
+ gOIdhw4i4o = mkldnn_gOIdhw4i4o,
+ gOdhwi4o = mkldnn_gOdhwi4o,
gOIdhw8i8o = mkldnn_gOIdhw8i8o,
gOIdhw8o8i = mkldnn_gOIdhw8o8i,
gOdhwi8o = mkldnn_gOdhwi8o,
+ gOIhw4i4o = mkldnn_gOIhw4i4o,
gOIhw8i8o = mkldnn_gOIhw8i8o,
gOIhw16i16o = mkldnn_gOIhw16i16o,
gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
gOIhw4i16o4i_s8s8 = mkldnn_gOIhw4i16o4i_s8s8,
+ gOIhw2i8o4i = mkldnn_gOIhw2i8o4i,
+ gOIhw2i8o4i_s8s8 = mkldnn_gOIhw2i8o4i_s8s8,
gOihw8o = mkldnn_gOihw8o,
+ gOihw4o = mkldnn_gOihw4o,
gOihw16o = mkldnn_gOihw16o,
+ gOhwi4o = mkldnn_gOhwi4o,
gOhwi8o = mkldnn_gOhwi8o,
gOhwi16o = mkldnn_gOhwi16o,
Goihw8g = mkldnn_Goihw8g,
Goihw16g = mkldnn_Goihw16g,
+ Goihw16g_s8s8 = mkldnn_Goihw16g_s8s8,
+ gOIhw4o4i = mkldnn_gOIhw4o4i,
+ gOIhw4o4i_s8s8 = mkldnn_gOIhw4o4i_s8s8,
gOIhw8o8i = mkldnn_gOIhw8o8i,
gOIhw16o16i = mkldnn_gOIhw16o16i,
gIOhw16o16i = mkldnn_gIOhw16o16i,
goidhw = mkldnn_goidhw,
gOIdhw16i16o = mkldnn_gOIdhw16i16o,
gOIdhw16o16i = mkldnn_gOIdhw16o16i,
+ gOidhw4o = mkldnn_gOidhw4o,
gOidhw16o = mkldnn_gOidhw16o,
gOdhwi16o = mkldnn_gOdhwi16o,
ntc = mkldnn_ntc,
tnc = mkldnn_tnc,
ldsnc = mkldnn_ldsnc,
ldigo = mkldnn_ldigo,
- ldigo_p = mkldnn_ldigo_p,
ldgoi = mkldnn_ldgoi,
- ldgoi_p = mkldnn_ldgoi_p,
ldgo = mkldnn_ldgo,
+ rnn_packed = mkldnn_rnn_packed,
wino_fmt = mkldnn_wino_fmt,
format_last = mkldnn_format_last,
};
/// @}
/// @addtogroup cpp_api_concat Concat
-/// A primitive to concatenate data by arbitrary dimension
+/// A primitive to concatenate data by arbitrary dimension.
///
/// @sa @ref c_api_concat in @ref c_api
/// @{
/// @}
/// @addtogroup cpp_api_sum Sum
-/// A primitive to sum data
+/// A primitive to sum data.
///
/// @sa @ref c_api_sum in @ref c_api
/// @{
reset(result);
}
- /** @deprecated: api backwards compatibility for double scales type */
- MKLDNN_DEPRECATED
- primitive_desc(const memory::desc &output, std::vector<double> scale,
- std::vector<memory::primitive_desc> inputs) {
- mkldnn_primitive_desc_t result;
-
- auto c_api_inputs = cpp_to_c(inputs);
- auto scale_f = scale_to_float(scale);
-
- error::wrap_c_api(mkldnn_sum_primitive_desc_create(
- &result, &output.data, (int)c_api_inputs.size(),
- &scale_f[0], &c_api_inputs[0]),
- "could not create a sum primitive descriptor");
- reset(result);
- }
-
- /** @deprecated: api backwards compatibility for double scales type */
- MKLDNN_DEPRECATED
- primitive_desc(std::vector<double> scale,
- std::vector<memory::primitive_desc> inputs) {
- mkldnn_primitive_desc_t result;
-
- auto c_api_inputs = cpp_to_c(inputs);
- auto scale_f = scale_to_float(scale);
-
- error::wrap_c_api(mkldnn_sum_primitive_desc_create(
- &result, nullptr, (int)c_api_inputs.size(), &scale_f[0],
- &c_api_inputs[0]),
- "could not create a sum primitive descriptor");
- reset(result);
- }
-
memory::primitive_desc dst_primitive_desc() const {
memory::primitive_desc adesc;
mkldnn_primitive_desc_t cdesc;
"could not create a sum primitive");
reset(result);
}
-
-private:
- static std::vector<float> scale_to_float(const std::vector<double> &vd) {
- std::vector<float> vf(vd.size());
- std::transform(vd.begin(), vd.end(), vf.begin(),
- [=](double x){return (float)x;});
- return vf;
- }
};
/// @}
/// @addtogroup cpp_api_primitive_descriptors Primitive descriptors
/// @{
-/// A base class for all primitive descriptors
+/// A base class for all primitive descriptors.
struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr,
const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
return res;
}
- /// Advances the next implementation for the given op descriptor
+ /// Advances the next implementation for the given op descriptor.
///
/// Returns:
/// - @c true on success
return true;
}
- /// Queries and returns requested memory primitive descriptor
+ /// Queries and returns requested memory primitive descriptor.
memory::primitive_desc query_mpd(query what, int idx = 0) const {
std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
weights_pd, diff_weights_pd, dst_pd, diff_dst_pd, workspace_pd};
}
};
-/// A merged convolution-relu primitive for inference mode only
-///
-/// @deprecated consider using convolution_forward with post_ops
-/// (e.g. post_ops::append_eltwise(1.f, #eltwise_relu, negative_slope, 0.f)
-struct convolution_relu_forward : public primitive {
- struct desc {
- mkldnn_convolution_relu_desc_t data;
-
- desc(const convolution_forward::desc conv_desc,
- const float negative_slope) {
- error::wrap_c_api(mkldnn_convolution_relu_desc_init(&data,
- &conv_desc.data, negative_slope),
- "could not create a convolution_relu_forward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- REG_QUERY_MPD(src, src, 0);
- REG_QUERY_MPD(weights, weights, 0);
- REG_QUERY_MPD(bias, weights, 1);
- REG_QUERY_MPD(dst, dst, 0);
- };
-
- /// @deprecated consider using convolution_forward + post_ops
- MKLDNN_DEPRECATED
- convolution_relu_forward(const primitive_desc &aprimitive_desc,
- const primitive::at &src, const primitive::at &weights,
- const primitive::at &bias, const memory &dst) {
- mkldnn_primitive_t result;
- mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
- bias.data };
- const_mkldnn_primitive_t outputs[] = { dst.get() };
- check_num_parameters(aprimitive_desc.get(), 3, 1,
- "convolution relu forward");
- error::wrap_c_api(mkldnn_primitive_create(&result,
- aprimitive_desc.get(), inputs, outputs),
- "could not create a convolution relu forward primitive");
- reset(result);
- }
-
- /// @deprecated consider using convolution_forward + post_ops
- MKLDNN_DEPRECATED
- convolution_relu_forward(const primitive_desc &aprimitive_desc,
- const primitive::at &src, const primitive::at &weights,
- const memory &dst) {
- mkldnn_primitive_t result;
- mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
- const_mkldnn_primitive_t outputs[] = { dst.get() };
- check_num_parameters(aprimitive_desc.get(), 2, 1,
- "convolution relu forward");
- error::wrap_c_api(mkldnn_primitive_create(&result,
- aprimitive_desc.get(), inputs, outputs),
- "could not create a convolution relu forward primitive");
- reset(result);
- }
-};
-
/// @}
/// @addtogroup cpp_api_deconvolution Deconvolution
/// @}
/// @addtogroup cpp_api_eltwise Eltwise
-/// A primitive to compute element wise operations like parametric rectifier
+/// A primitive to compute element-wise operations like parametric rectifier
/// linear unit (ReLU).
///
/// @sa @ref c_api_eltwise in @ref c_api
static_cast<float>(alpha), static_cast<float>(beta)),
"could not create a eltwise forward descriptor");
}
-
- /** @deprecated: api backward compatibility for relu */
- template <typename T>
- MKLDNN_DEPRECATED
- desc(prop_kind aprop_kind, const memory::desc &src_desc,
- T negative_slope)
- : desc(aprop_kind, eltwise_relu, src_desc, negative_slope) {}
};
struct primitive_desc : public mkldnn::primitive_desc {
}
};
-typedef eltwise_forward relu_forward;
-
struct eltwise_backward : public primitive {
struct desc {
mkldnn_eltwise_desc_t data;
static_cast<float>(beta)),
"could not create a eltwise backward descriptor");
}
-
- /** @deprecated: api backward compatibility for relu */
- template <typename T>
- MKLDNN_DEPRECATED
- desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
- T negative_slope): desc(eltwise_relu, diff_data_desc, data_desc,
- negative_slope) {}
};
struct primitive_desc : public mkldnn::primitive_desc {
}
};
-typedef eltwise_backward relu_backward;
-
/// @}
/// @addtogroup cpp_api_depthwise Depthwise
const memory::desc &bias_desc) {
error::wrap_c_api(mkldnn_depthwise_forward_desc_init(&data,
mkldnn::convert_to_c(aprop_kind),
- mkldnn::convert_to_c(alg_kind),
- &src_desc.data, &dst_desc.data,
+ mkldnn::convert_to_c(alg_kind),
+ &src_desc.data, &dst_desc.data,
&weights_desc.data, &bias_desc.data),
"could not create a depthwise forward descriptor");
}
}
};
- struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
- primitive_desc(const desc &adesc, const engine &aengine) {
- mkldnn_primitive_desc_t result;
- error::wrap_c_api(mkldnn_primitive_desc_create(
- &result, &adesc.data, aengine.get(), nullptr),
- "could not create a depthwise forward primitive descriptor");
- reset(result);
- }
+ struct primitive_desc : public mkldnn::primitive_desc {
+ primitive_desc(const desc &desc, const engine &e)
+ : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
- engine get_engine() { return engine::query(*this); }
+ primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+ : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+ REG_QUERY_MPD(src, src, 0);
+ REG_QUERY_MPD(dst, dst, 0);
};
depthwise_forward(const primitive_desc &aprimitive_desc,
reset(result);
}
- /// @warning batch_normalization_forward has 2 constructors with very
+ /// @warning batch_normalization_forward has two constructors with very
/// similar signatures:
/// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
/// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
- /// The only way to distinguish between those is to explicitly
- /// cast all input parameters to their type, i.e. to
+ /// The only way to distinguish between them is to explicitly
+ /// cast all input parameters to their type; that is, to
/// const primitive:at &.
batch_normalization_forward(const primitive_desc &aprimitive_desc,
const primitive::at &src, const primitive::at &weights,
reset(result);
}
- /// @warning batch_normalization_forward has 2 constructors with very
+ /// @warning batch_normalization_forward has two constructors with very
/// similar signatures:
/// - (pd, src, weights, dst, mean, variance) // 2 in, 3 out
/// - (pd, src, dst, mean, variance, workspace) // 1 in, 4 out
- /// The only way to distinguish between those is to explicitly
- /// cast all input parameters to their type, i.e. to
+ /// The only way to distinguish between them is to explicitly
+ /// cast all input parameters to their type; that is, to
/// const primitive:at &.
- /// @note to make users' experience a little bit better this constructor
- /// checks if whether parameters match corresponding primitive
- /// descriptor, and if they are not -- call the other (proper)
- /// constructor. Yeah, this is still very ugly...
+ /// @note To make users' experience a little better, this constructor
+ /// checks whether parameters match the corresponding primitive
+ /// descriptor, and if not, calls the other (proper) constructor.
batch_normalization_forward(const primitive_desc &aprimitive_desc,
const primitive::at &src, const memory &dst, const memory &mean,
const memory &variance, const memory &workspace) {
};
struct primitive_desc : public mkldnn::primitive_desc {
- MKLDNN_DEPRECATED
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
primitive_desc(const desc &desc, const engine &e,
const rnn_forward::primitive_desc &hint_fwd_pd)
: mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
/// @}
+/// @addtogroup cpp_api_binary_convolution Binary convolution
+/// A primitive to compute binary convolution using different algorithms.
+///
+/// @sa @ref c_api_binary_convolution in @ref c_api
+/// @{
+
+struct binary_convolution_forward: public primitive {
+ struct desc {
+ mkldnn_binary_convolution_desc_t data;
+ desc(prop_kind aprop_kind, algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &weights_desc,
+ const memory::desc &dst_desc,
+ const memory::dims strides,
+ const memory::dims dilates,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const float pad_value) {
+ memory::validate_dims(strides);
+ memory::validate_dims(dilates);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(
+ mkldnn_dilated_binary_convolution_forward_desc_init(&data,
+ mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
+ &src_desc.data, &weights_desc.data, &dst_desc.data,
+ &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
+ pad_value),
+ "could not create a dilated binary convolution forward descriptor");
+ }
+ };
+
+ struct primitive_desc : public mkldnn::primitive_desc {
+ primitive_desc(const desc &desc, const engine &e)
+ : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+ primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+ : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+ REG_QUERY_MPD(src, src, 0);
+ REG_QUERY_MPD(weights, weights, 0);
+ REG_QUERY_MPD(dst, dst, 0);
+ };
+
+ binary_convolution_forward(const primitive_desc &aprimitive_desc,
+ const primitive::at &src, const primitive::at &weights, const memory &dst) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
+ const_mkldnn_primitive_t outputs[] = { dst.get() };
+ check_num_parameters(aprimitive_desc.get(), 2, 1,
+ "binary convolution forward");
+ error::wrap_c_api(mkldnn_primitive_create(&result,
+ aprimitive_desc.get(), inputs, outputs),
+ "could not create a binary convolution forward primitive");
+ reset(result);
+ }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_binarization Binarization
+/// @{
+
+struct binarization_forward : public primitive {
+ struct desc {
+ mkldnn_binarization_desc_t data;
+
+ desc(prop_kind aprop_kind, algorithm alg_kind,
+ const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc) {
+ error::wrap_c_api(mkldnn_binarization_forward_desc_init(&data,
+ mkldnn::convert_to_c(aprop_kind),
+ mkldnn::convert_to_c(alg_kind),
+ &src_desc.data, &dst_desc.data,
+ &weights_desc.data),
+ "could not create a binarization forward descriptor");
+ }
+ };
+
+ struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+ primitive_desc(const desc &adesc, const engine &aengine) {
+ mkldnn_primitive_desc_t result;
+ error::wrap_c_api(mkldnn_primitive_desc_create(
+ &result, &adesc.data, aengine.get(), nullptr),
+ "could not create a binarization forward primitive descriptor");
+ reset(result);
+ }
+
+ engine get_engine() { return engine::query(*this); }
+ };
+
+ binarization_forward(const primitive_desc &aprimitive_desc,
+ const primitive::at &src, const primitive::at &weights, const memory &dst) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
+ const_mkldnn_primitive_t outputs[] = { dst.get() };
+ error::wrap_c_api(mkldnn_primitive_create(&result, aprimitive_desc.get(), inputs, outputs),
+ "could not create a binarization forward primitive");
+ reset(result);
+ }
+};
+
+/// @}
+
/// @} Primitives
/// @addtogroup cpp_api_stream Stream
-/// Execution stream operations
+/// Execution stream operations.
///
/// @sa @ref c_api_stream in @ref c_api
/// @{
/// Waits for all computations submitted to the stream to complete.
///
- /// @param block Specifies whether the operation should wait indefinitely or return
- /// immediately.
+ /// @param block Specifies whether the operation should wait indefinitely or
+ /// return immediately.
/// @returns @c true if all computations completed.
/// @returns @c false if not all computations completed.
bool wait(bool block = true) {