1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef PRIMITIVE_ATTR_HPP
18 #define PRIMITIVE_ATTR_HPP
23 #include "c_types_map.hpp"
// Scale/shift pair used as RNN data quantization parameters.
// The constructor sets scale_ = 1 and shift_ = 0, which is exactly the state
// that has_default_values() tests for.
// NOTE(review): this text is a numbered listing with gaps, not compilable
// C++ as shown. Original lines 33, 35-36 and 38-43 are absent, so whatever
// set() does before returning, the struct's closing brace, and the
// declarations of scale_/shift_ do not appear here.
30 struct rnn_data_qparams_t : public c_compatible {
31 rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
// True iff both parameters still hold the constructor's values (1 and 0).
32 bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
// Updates the parameters; the only visible statement reports success.
// NOTE(review): presumably the missing lines store scale/shift into
// scale_/shift_ -- confirm against the upstream source.
34 status_t set(float scale, float shift) {
37 return status::success;
// Set of scaling factors: count_ values reachable through scales_, plus a
// mask_ that is copied along with them. Default construction points scales_
// at the inline buffer scales_buf_.
// NOTE(review): this text is a numbered listing with gaps, not compilable
// C++ as shown. Original lines 46-47, 50, 52, 54-55, 58-61, 65-68, 71-76,
// 79-80, 82-85 and everything after 86 are absent, so several bodies and
// closing braces do not appear here.
44 struct scales_t: public c_compatible {
// Default state: count_ = 1, mask_ = 0, values stored in scales_buf_.
45 scales_t(): count_(1), mask_(0), scales_(scales_buf_)
// Copy constructor: start from the default state, then copy rhs via set().
// NOTE(review): the status returned by set() is discarded here, so a
// failure inside set() is not reported to the caller.
48 scales_t(const scales_t &rhs): scales_t()
49 { set(rhs.count_, rhs.mask_, rhs.scales_); }
// Destructor delegates to cleanup().
51 ~scales_t() { cleanup(); }
// Copy assignment via set(); the result is checked only with assert().
// NOTE(review): with NDEBUG the assert is compiled out, so a failing set()
// goes unnoticed and `status` may trigger an unused-variable warning. The
// lines that follow (58-61) are missing here -- confirm how upstream
// handles this.
53 scales_t &operator=(const scales_t &rhs) {
56 status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
57 assert(status == status::success);
// Returns false as soon as one of the first count_ scales differs from 1.
// (Exact float comparison is intentional: 1 is the default sentinel.)
// NOTE(review): the loop's closing brace and the final return are missing.
62 bool has_default_values() const {
63 for (int c = 0; c < count_; ++c) {
64 if(scales_[c] != 1.) return false;
// Replaces count/mask/values; defined out of line.
69 status_t set(int count, int mask, const float *scales);
// Convenience overload: one scale, count 1, mask 0.
70 status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
// Inline storage of scales_buf_size (16) floats.
77 enum { scales_buf_size = 16 };
78 float scales_buf_[scales_buf_size];
// Fragment of the release logic, presumably from cleanup(). Storage other
// than scales_buf_ (and other than nullptr) is handled specially; line 86
// then points scales_ back at the inline buffer.
// NOTE(review): the statements between these two lines (82-85) are missing.
81 if (scales_ != scales_buf_ && scales_ != nullptr)
86 scales_ = scales_buf_;
// Fixed-capacity sequence of post-operation entries.
// - entry_ holds up to `capacity` (4) entries.
// - len_ is used below as the number of valid entries.
// - Each entry carries a `kind` tag plus per-kind parameter fields.
// NOTE(review): this text is a numbered listing with gaps, not compilable
// C++ as shown. Many original lines between 93 and 191 are absent (e.g.
// 94-95, 98-99, 101, 103-104, 108-115, 118-119, 122-124, 157-159, 178-180,
// 187-188, 190-191). These include several declarations, closing braces
// and the tail of find(), so the field lines below cannot be attributed to
// their enclosing declarations with certainty.
93 struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
// Algorithm plus scale/alpha/beta; is_eltwise()/is_relu() read these
// through the `eltwise` member.
96 mkldnn::impl::alg_kind_t alg;
97 float scale, alpha, beta;
// Kind tag tested by the is_*() predicates and by find().
100 mkldnn::impl::primitive_kind_t kind;
// Sum parameters: a single scale, read by is_sum().
102 struct { float scale; } sum;
// These field groups mirror the algorithm/pointer parameters of
// append_depthwise(), append_dw_conv() and append_binarization() below
// (presumably one group per kind -- the enclosing declarations are missing).
// NOTE(review): non-owning-looking raw pointers; ownership and lifetime of
// the weights/biases data are not visible here -- presumably caller-owned.
105 mkldnn::impl::alg_kind_t alg;
106 const float* weights_data;
107 const float* biases_data;
116 const float* weights_data;
117 const float* biases_data;
120 mkldnn::impl::alg_kind_t alg;
121 const float* weights_data;
// True for an eltwise entry. When require_scale_one is set, eltwise.scale
// must also equal 1.
125 bool is_eltwise(bool require_scale_one = true) const {
126 using namespace mkldnn::impl;
127 return kind == primitive_kind::eltwise
128 && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
// True for an eltwise entry (same scale rule) whose algorithm is
// eltwise_relu. When require_nslope_zero is set, eltwise.alpha must be 0.
131 bool is_relu(bool require_scale_one = true,
132 bool require_nslope_zero = true) const {
133 using namespace mkldnn::impl;
134 return is_eltwise(require_scale_one)
135 && eltwise.alg == alg_kind::eltwise_relu
136 && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
// True for a sum entry. When require_scale_one is set, sum.scale must be 1.
139 bool is_sum(bool require_scale_one = true) const {
140 using namespace mkldnn::impl;
141 return kind == primitive_kind::sum
142 && IMPLICATION(require_scale_one, sum.scale == 1.f);
// Plain kind checks. Note that is_dw_conv() tests for
// primitive_kind::convolution.
145 bool is_depthwise() const {
146 using namespace mkldnn::impl;
147 return kind == primitive_kind::depthwise;
150 bool is_dw_conv() const {
151 using namespace mkldnn::impl;
152 return kind == primitive_kind::convolution;
154 bool is_binarization() const {
155 using namespace mkldnn::impl;
156 return kind == primitive_kind::binarization;
// Starts with no entries.
160 mkldnn_post_ops(): len_(0) {}
// Appenders for each entry kind; defined out of line, each returns a status.
162 mkldnn::impl::status_t append_sum(float scale);
163 mkldnn::impl::status_t append_eltwise(float scale,
164 mkldnn::impl::alg_kind_t alg, float alpha, float beta);
165 mkldnn::impl::status_t append_depthwise(mkldnn::impl::alg_kind_t alg,
166 const float* weights_data, const float* biases_data);
167 mkldnn::impl::status_t append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
168 const float* weights_data,
169 const float* biases_data);
170 mkldnn::impl::status_t append_binarization(mkldnn::impl::alg_kind_t alg, const float* weights_data);
// Returns the index of the first entry of the given kind in [start, stop).
// stop == -1 means "up to len_", and stop is always clamped to len_.
// NOTE(review): the not-found return value is on a missing line (178).
// NOTE(review): start is not validated.
// - contain(kind, -1) calls find(kind, -1, 0), which reads entry_[-1]
//   (out of bounds).
// - contain(kind, -2) passes stop == -1, which is then taken as the
//   "use len_" sentinel.
// Consider rejecting start < 0.
172 int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
173 int stop = -1) const {
174 if (stop == -1) stop = len_;
175 stop = mkldnn::impl::nstl::min(stop, len_);
176 for (int idx = start; idx < stop; ++idx)
177 if (entry_[idx].kind == kind) return idx;
// Default means no post-ops have been recorded.
181 bool has_default_values() const { return len_ == 0; }
// True iff entry `index` exists and has the given kind (the search is
// restricted to [index, index + 1)). See the find() note for negative index.
183 bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
184 { return find(kind, index, index + 1) == index; }
// Fixed storage for at most `capacity` entries.
186 enum { capacity = 4 };
189 entry_t entry_[capacity];
// Primitive attributes: rounding mode, output scales, post-ops, RNN data
// quantization parameters and RNN weights scales.
// NOTE(review): this text is a numbered listing with gaps, not compilable
// C++ as shown. Original lines 195, 198, 200, 206-207, 212 and everything
// after 217 (including the closing brace) are absent. In particular, the
// first operand of the has_default_values() return expression (line 200)
// is missing; the visible expression starts with `&&`.
192 struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
// Only round_mode_ is set explicitly (to nearest); the other members are
// default-constructed.
193 mkldnn_primitive_attr()
194 : round_mode_(mkldnn::impl::round_mode::nearest) {}
// Returns a `new`-allocated member-wise copy of *this.
// NOTE(review): presumably the caller is responsible for deleting it.
196 mkldnn_primitive_attr *clone() const
197 { return new mkldnn_primitive_attr(*this); }
// True when every component still has its default. That means nearest
// rounding plus default output scales, post-ops, RNN data parameters and
// RNN weights scales.
199 bool has_default_values() const {
201 && round_mode_ == mkldnn::impl::round_mode::nearest
202 && output_scales_.has_default_values()
203 && post_ops_.has_default_values()
204 && rnn_data_qparams_.has_default_values()
205 && rnn_weights_qparams_.has_default_values();
// Setters; defined out of line, each returns a status.
208 mkldnn::impl::status_t set_round_mode(
209 mkldnn::impl::round_mode_t round_mode);
210 mkldnn::impl::status_t set_post_ops(
211 const mkldnn::impl::post_ops_t &post_ops);
// Attribute state. RNN weights scales reuse the scales_t type.
213 mkldnn::impl::round_mode_t round_mode_;
214 mkldnn::impl::scales_t output_scales_;
215 mkldnn::impl::post_ops_t post_ops_;
216 mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
217 mkldnn::impl::scales_t rnn_weights_qparams_;