namespace mkldnn {
namespace impl {
+struct rnn_data_qparams_t : public c_compatible {
+ rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
+ bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
+
+ status_t set(float scale, float shift) {
+ scale_ = scale;
+ shift_ = shift;
+ return status::success;
+ }
+
+ float scale_;
+ float shift_;
+};
+
struct scales_t: public c_compatible {
scales_t(): count_(1), mask_(0), scales_(scales_buf_)
{ set(1.); }
status_t set(int count, int mask, const float *scales);
status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
- status_t scale(float factor);
int count_;
int mask_;
struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
struct entry_t {
+ struct eltwise_t {
+ mkldnn::impl::alg_kind_t alg;
+ float scale, alpha, beta;
+ };
+
mkldnn::impl::primitive_kind_t kind;
union {
struct { float scale; } sum;
- struct {
- mkldnn::impl::alg_kind_t alg;
- float scale, alpha, beta;
- } eltwise;
+ eltwise_t eltwise;
struct {
mkldnn::impl::alg_kind_t alg;
const float* weights_data;
const float* weights_data;
const float* biases_data;
} dw_conv;
+ struct {
+ mkldnn::impl::alg_kind_t alg;
+ const float* weights_data;
+ } binarization;
};
+ bool is_eltwise(bool require_scale_one = true) const {
+ using namespace mkldnn::impl;
+ return kind == primitive_kind::eltwise
+ && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
+ }
+
bool is_relu(bool require_scale_one = true,
bool require_nslope_zero = true) const {
using namespace mkldnn::impl;
- return kind == primitive_kind::eltwise
- && IMPLICATION(require_scale_one, eltwise.scale == 1.f)
+ return is_eltwise(require_scale_one)
&& eltwise.alg == alg_kind::eltwise_relu
&& IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
}
+
bool is_sum(bool require_scale_one = true) const {
using namespace mkldnn::impl;
return kind == primitive_kind::sum
&& IMPLICATION(require_scale_one, sum.scale == 1.f);
}
- bool is_eltwise(bool require_scale_one = true) const {
- using namespace mkldnn::impl;
- return kind == primitive_kind::eltwise
- && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
- }
+
bool is_depthwise() const {
using namespace mkldnn::impl;
return kind == primitive_kind::depthwise;
}
+
bool is_dw_conv() const {
using namespace mkldnn::impl;
return kind == primitive_kind::convolution;
}
+ bool is_binarization() const {
+ using namespace mkldnn::impl;
+ return kind == primitive_kind::binarization;
+ }
};
mkldnn_post_ops(): len_(0) {}
mkldnn::impl::status_t append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
const float* weights_data,
const float* biases_data);
+ mkldnn::impl::status_t append_binarization(mkldnn::impl::alg_kind_t alg, const float* weights_data);
int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
int stop = -1) const {
return true
&& round_mode_ == mkldnn::impl::round_mode::nearest
&& output_scales_.has_default_values()
- && post_ops_.has_default_values() ;
+ && post_ops_.has_default_values()
+ && rnn_data_qparams_.has_default_values()
+ && rnn_weights_qparams_.has_default_values();
}
mkldnn::impl::status_t set_round_mode(
mkldnn::impl::round_mode_t round_mode_;
mkldnn::impl::scales_t output_scales_;
mkldnn::impl::post_ops_t post_ops_;
+ mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
+ mkldnn::impl::scales_t rnn_weights_qparams_;
};
#endif