1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "c_types_map.hpp"
20 #include "primitive_attr.hpp"
21 #include "type_helpers.hpp"
24 using namespace mkldnn::impl;
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::utils;
// Stores `count` scale factors (with the given `mask`) into this scales_t.
// Two storage paths are visible:
//   * inline: scales_ points at scales_buf_ and all scales_buf_size slots are
//     filled with scales[0];
//   * heap: count_ floats are allocated with impl::malloc and scales[0..count_)
//     is copied in.
// Returns status::out_of_memory if the heap allocation fails, otherwise
// status::success.
// NOTE(review): this excerpt is missing lines of this function. The condition
// that selects between the two paths and the closing braces are not present.
// Confirm against the full file before relying on it.
31 status_t scales_t::set(int count, int mask, const float *scales) {
// Inline path: broadcast the single value over the whole fixed-size buffer.
38 scales_ = scales_buf_;
39 utils::array_set(scales_, scales[0], scales_buf_size);
// Heap path: one float per scale. The second argument (64) is presumably an
// alignment in bytes; TODO confirm against impl::malloc.
41 scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
42 if (scales_ == nullptr)
43 return status::out_of_memory;
45 for (int c = 0; c < count_; ++c)
46 scales_[c] = scales[c];
49 return status::success;
// Appends a `sum` post-op entry with the given `scale` at slot len_.
// NOTE(review): no `len_ == capacity` check is visible before entry_[len_] is
// written, and the len_ update and return statement are not present in this
// excerpt. append_dw_conv and append_binarization in this file do check
// capacity first. Confirm the missing lines include that check.
55 status_t post_ops_t::append_sum(float scale) {
59 entry_[len_].kind = primitive_kind::sum;
60 entry_[len_].sum.scale = scale;
// Appends an `eltwise` post-op entry holding scale, alg, alpha and beta.
// Only the eltwise algorithms listed in `known_alg` are meant to be accepted.
// The visible early return reports invalid_arguments, presumably when
// !known_alg, but the guarding `if` is not present in this excerpt.
// NOTE(review): several lines of this function are absent from this excerpt:
//   * the rest of the parameter list (`beta` is used below but is not declared
//     in the visible lines);
//   * any capacity check before entry_[len_] is written;
//   * the function tail.
// Confirm against the full file.
67 status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
69 using namespace mkldnn::impl::alg_kind;
70 bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
71 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
72 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
73 eltwise_clamp, eltwise_exp, eltwise_not);
75 return invalid_arguments;
80 entry_[len_].kind = primitive_kind::eltwise;
81 entry_[len_].eltwise.scale = scale;
82 entry_[len_].eltwise.alg = alg;
83 entry_[len_].eltwise.alpha = alpha;
84 entry_[len_].eltwise.beta = beta;
// Appends a `depthwise` post-op entry. `alg` must be depthwise_scale_shift or
// depthwise_prelu; the visible early return reports invalid_arguments (its
// guarding `if` is not present in this excerpt).
// The weights/biases pointers are stored as-is (not copied), so the caller
// presumably must keep that data alive while the post-ops are in use. TODO
// confirm.
// NOTE(review): no capacity check is visible before entry_[len_] is written,
// and the function tail is not present in this excerpt.
91 status_t post_ops_t::append_depthwise(alg_kind_t alg,
92 const float* weights_data, const float* biases_data) {
93 using namespace mkldnn::impl::alg_kind;
94 bool known_alg = one_of(alg, depthwise_scale_shift, depthwise_prelu);
96 return invalid_arguments;
101 entry_[len_].kind = primitive_kind::depthwise;
102 entry_[len_].depthwise.alg = alg;
103 entry_[len_].depthwise.weights_data = weights_data;
104 entry_[len_].depthwise.biases_data = biases_data;
// Appends a fused depthwise-convolution post-op. The entry is tagged
// primitive_kind::convolution and stores input size (in_h, in_w), kernel size
// (ker_h, ker_w), strides (str_h, str_w) and the weights/biases pointers.
// Returns out_of_memory when the fixed-capacity entry_ array is already full.
// The sizes and pointers are stored without validation in the visible lines,
// and the pointers are not copied.
// NOTE(review): the len_ update and return statement are not present in this
// excerpt.
111 status_t post_ops_t::append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
112 const float* weights_data,
113 const float* biases_data) {
114 if (len_ == capacity)
115 return out_of_memory;
117 entry_[len_].kind = primitive_kind::convolution;
118 entry_[len_].dw_conv.in_h = in_h;
119 entry_[len_].dw_conv.in_w = in_w;
120 entry_[len_].dw_conv.ker_h = ker_h;
121 entry_[len_].dw_conv.ker_w = ker_w;
122 entry_[len_].dw_conv.str_h = str_h;
123 entry_[len_].dw_conv.str_w = str_w;
124 entry_[len_].dw_conv.weights_data = weights_data;
125 entry_[len_].dw_conv.biases_data = biases_data;
// Appends a `binarization` post-op entry. Only binarization_depthwise is an
// accepted `alg`; the visible early return reports invalid_arguments (its
// guarding `if` is not present in this excerpt).
// Returns out_of_memory when entry_ is already at capacity. `weights_data` is
// stored as a pointer, not copied.
// NOTE(review): the len_ update and return statement are not present in this
// excerpt.
132 status_t post_ops_t::append_binarization(alg_kind_t alg, const float* weights_data) {
133 using namespace mkldnn::impl::alg_kind;
134 bool known_alg = one_of(alg, binarization_depthwise);
136 return invalid_arguments;
138 if (len_ == capacity)
139 return out_of_memory;
141 entry_[len_].kind = primitive_kind::binarization;
142 entry_[len_].binarization.alg = alg;
143 entry_[len_].binarization.weights_data = weights_data;
// Sets the integer-output rounding mode. Only round_mode::nearest and
// round_mode::down are accepted; the visible early return reports
// invalid_arguments (its guarding `if` is not present in this excerpt).
150 status_t primitive_attr_t::set_round_mode(round_mode_t round_mode) {
151 using namespace mkldnn::impl::round_mode;
153 const bool ok = one_of(round_mode, nearest, down);
155 return invalid_arguments;
157 round_mode_ = round_mode;
// Replaces this attribute's post-ops with a copy of `post_ops` (assignment by
// value). The return statement is not present in this excerpt.
161 status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
162 this->post_ops_ = post_ops;
// C API: allocates a new, default-constructed primitive attribute and hands
// it to the caller through `*attr` via safe_ptr_assign.
// The visible early return reports invalid_arguments. Its guarding `if` is not
// present in this excerpt; presumably it is a null check on `attr`.
168 status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
170 return invalid_arguments;
172 return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
173 new mkldnn_primitive_attr);
// C API: stores a clone of `existing_attr` (existing_attr->clone()) into
// `*attr` via safe_ptr_assign.
// Returns invalid_arguments if either pointer is null.
176 status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr,
177 const primitive_attr_t *existing_attr) {
178 if (any_null(attr, existing_attr))
179 return invalid_arguments;
181 return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
182 existing_attr->clone());
// C API: destroys a primitive attribute.
// NOTE(review): the body of this function is not present in this excerpt.
// Presumably it releases an attr created by mkldnn_primitive_attr_create/clone;
// confirm against the full file.
185 status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {
// C API: writes the attribute's integer-output rounding mode to `*round_mode`.
// Returns invalid_arguments if `attr` or `round_mode` is null.
192 status_t mkldnn_primitive_attr_get_int_output_round_mode(
193 const primitive_attr_t *attr, round_mode_t *round_mode) {
194 if (any_null(attr, round_mode))
195 return invalid_arguments;
197 *round_mode = attr->round_mode_;
// C API: sets the integer-output rounding mode by forwarding to
// primitive_attr_t::set_round_mode, which validates the value.
// The visible early return reports invalid_arguments. Its guarding `if` is not
// present in this excerpt; presumably it is a null check on `attr`.
202 status_t mkldnn_primitive_attr_set_int_output_round_mode(
203 primitive_attr_t *attr, round_mode_t round_mode) {
205 return invalid_arguments;
207 return attr->set_round_mode(round_mode);
// C API: reports the attribute's output scales (count, mask, values).
// Returns invalid_arguments if any pointer argument is null.
// `*scales` is set to the attribute's own storage (no copy is made), so it is
// presumably only valid while `attr` and its scales are unchanged. TODO
// confirm.
210 status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
211 int *count, int *mask, const float **scales) {
212 if (any_null(attr, count, mask, scales))
213 return invalid_arguments;
215 *count = attr->output_scales_.count_;
216 *mask = attr->output_scales_.mask_;
217 *scales = attr->output_scales_.scales_;
// C API: sets the attribute's output scales via scales_t::set.
// Requires non-null `attr` and `scales`, `count > 0` and `mask >= 0`;
// otherwise invalid_arguments is returned (the guarding `if` is not present in
// this excerpt).
222 status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
223 int count, int mask, const float *scales) {
224 bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
226 return invalid_arguments;
228 return attr->output_scales_.set(count, mask, scales);
// C API: returns, through `*post_ops`, a pointer to the attribute's own
// post-ops object (not a copy).
// Returns invalid_arguments if either pointer is null.
231 status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr,
232 const post_ops_t **post_ops) {
233 if (any_null(attr, post_ops))
234 return invalid_arguments;
236 *post_ops = &attr->post_ops_;
// C API: copies `*post_ops` into the attribute via
// primitive_attr_t::set_post_ops.
// Returns invalid_arguments if either pointer is null.
240 status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr,
241 const post_ops_t *post_ops) {
242 if (any_null(attr, post_ops))
243 return invalid_arguments;
245 return attr->set_post_ops(*post_ops);
// C API: allocates a new, default-constructed post-ops object and hands it to
// the caller through `*post_ops` via safe_ptr_assign.
// Returns invalid_arguments if `post_ops` is null.
248 status_t mkldnn_post_ops_create(post_ops_t **post_ops) {
249 if (post_ops == nullptr)
250 return invalid_arguments;
252 return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops);
// C API: destroys a post-ops object.
// NOTE(review): the body of this function is not present in this excerpt.
// Presumably it releases an object created by mkldnn_post_ops_create; confirm
// against the full file.
255 status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) {
// C API: returns the number of entries currently stored in `post_ops`.
// NOTE(review): no null check on `post_ops` is visible before the dereference
// below, and the rest of the function is not present in this excerpt. Confirm
// that the full file guards against a null argument.
262 int mkldnn_post_ops_len(const post_ops_t *post_ops) {
264 return post_ops->len_;
// C API: returns the primitive kind of the post-op entry at `index`.
// It yields primitive_kind::undefined for a null `post_ops` or an index
// outside [0, len_). The guarding `if` and the rest of the parameter list are
// not present in this excerpt.
269 primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops,
271 bool ok = post_ops && 0 <= index && index < post_ops->len_;
273 return primitive_kind::undefined;
275 return post_ops->entry_[index].kind;
// C API: appends a `sum` post-op with the given `scale`, forwarding to
// post_ops_t::append_sum.
// Returns invalid_arguments if `post_ops` is null.
278 status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) {
279 if (post_ops == nullptr)
280 return invalid_arguments;
282 return post_ops->append_sum(scale);
// Shared validation for the mkldnn_post_ops_get_params_* functions.
// The visible conditions require:
//   * a non-null `post_ops`;
//   * `index < len_`;
//   * the entry at `index` to have the expected `kind`.
// NOTE(review): the start of the condition chain and the return statement are
// not present in this excerpt. No `0 <= index` lower-bound check is visible
// here, although mkldnn_post_ops_get_kind does check it. If that check is truly
// absent, a negative index would read entry_ out of bounds. Confirm against the
// full file.
286 bool simple_get_params_check(const post_ops_t *post_ops, int index,
287 primitive_kind_t kind) {
289 && post_ops != nullptr
291 && index < post_ops->len_
292 && post_ops->entry_[index].kind == kind;
// C API: reads back the scale of the `sum` post-op at `index`.
// Returns invalid_arguments unless simple_get_params_check() passes for
// primitive_kind::sum.
// NOTE(review): the rest of the parameter list (the `scale` out-pointer used
// below), any null check on `scale`, and the guarding `if` are not present in
// this excerpt. Confirm against the full file.
297 status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
300 && simple_get_params_check(post_ops, index, primitive_kind::sum)
303 return invalid_arguments;
305 *scale = post_ops->entry_[index].sum.scale;
// C API: appends an eltwise post-op, forwarding to
// post_ops_t::append_eltwise, which validates `kind`.
// Returns invalid_arguments if `post_ops` is null.
309 status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale,
310 alg_kind_t kind, float alpha, float beta) {
311 if (post_ops == nullptr)
312 return invalid_arguments;
314 return post_ops->append_eltwise(scale, kind, alpha, beta);
// C API: reads back the parameters of the eltwise post-op at `index`.
// Returns invalid_arguments (the guarding `if` is not present in this excerpt)
// unless simple_get_params_check() passes for primitive_kind::eltwise and
// `scale`, `alpha` and `beta` are non-null.
// NOTE(review): `alg` is NOT included in the any_null() check below, so a null
// `alg` is not rejected here. If the remainder of this function (not present
// in this excerpt) writes through `alg`, that is a null-pointer dereference.
// Consider any_null(scale, alg, alpha, beta), matching the depthwise and
// binarization getters, which null-check `alg`.
317 status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops,
318 int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) {
320 && simple_get_params_check(post_ops, index, primitive_kind::eltwise)
321 && !any_null(scale, alpha, beta);
323 return invalid_arguments;
325 const auto &e = post_ops->entry_[index].eltwise;
// C API: sets the RNN data quantization parameters (scale, shift) by
// forwarding to rnn_data_qparams_.set.
// The visible early return reports invalid_arguments. Its guarding `if` is not
// present in this excerpt; presumably it is a null check on `attr`.
334 status_t mkldnn_primitive_attr_set_rnn_data_qparams(
335 primitive_attr_t *attr, const float scale, const float shift) {
337 return invalid_arguments;
339 return attr->rnn_data_qparams_.set(scale, shift);
// C API: sets the RNN weights quantization scales by forwarding to
// rnn_weights_qparams_.set.
// Requires non-null `attr` and `scales`, `count > 0` and `mask >= 0`;
// otherwise invalid_arguments is returned (the guarding `if` is not present in
// this excerpt).
342 status_t mkldnn_primitive_attr_set_rnn_weights_qparams(
343 primitive_attr_t *attr, int count, int mask, const float *scales) {
344 bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
346 return invalid_arguments;
348 return attr->rnn_weights_qparams_.set(count, mask, scales);
// C API: appends a depthwise post-op, forwarding to
// post_ops_t::append_depthwise, which validates `kind`.
// The weights/biases pointers are passed through as-is, not copied.
// Returns invalid_arguments if `post_ops` is null.
351 status_t mkldnn_post_ops_append_depthwise(post_ops_t *post_ops,
352 alg_kind_t kind, const float* weights_data, const float* biases_data) {
353 if (post_ops == nullptr)
354 return invalid_arguments;
356 return post_ops->append_depthwise(kind, weights_data, biases_data);
// C API: reads back the depthwise post-op at `index`.
// Returns invalid_arguments unless simple_get_params_check() passes for
// primitive_kind::depthwise and all out-pointers are non-null.
// The returned weights/biases pointers are the stored pointers, not copies.
// NOTE(review): the start of the `ok` expression, the guarding `if`, the
// assignment to `*alg` and the tail are not present in this excerpt.
359 status_t mkldnn_post_ops_get_params_depthwise(const post_ops_t *post_ops,
360 int index, alg_kind_t *alg, const float** weights_data, const float** biases_data) {
362 && simple_get_params_check(post_ops, index, primitive_kind::depthwise)
363 && !any_null(alg, weights_data, biases_data);
365 return invalid_arguments;
367 const auto &e = post_ops->entry_[index].depthwise;
369 *weights_data = e.weights_data;
370 *biases_data = e.biases_data;
// C API: appends a fused depthwise-convolution post-op (input size, kernel
// size, strides, weights/biases pointers), forwarding to
// post_ops_t::append_dw_conv.
// Returns invalid_arguments if `post_ops` is null. No other argument
// validation happens in the visible lines.
375 status_t mkldnn_post_ops_append_dw_conv(post_ops_t *post_ops,
376 int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
377 const float* weights_data,
378 const float* biases_data) {
379 if (post_ops == nullptr)
380 return invalid_arguments;
382 return post_ops->append_dw_conv(in_h, in_w, ker_h, ker_w, str_h, str_w, weights_data, biases_data);
// C API: reads back the fused depthwise-convolution post-op at `index`.
// The entry is looked up as primitive_kind::convolution.
// Returns invalid_arguments unless simple_get_params_check() passes and every
// out-pointer is non-null.
// NOTE(review): the start of the `ok` expression, the guarding `if`, the
// assignments of the integer out-parameters and the tail are not present in
// this excerpt.
385 status_t mkldnn_post_ops_get_params_dw_conv(const post_ops_t *post_ops,
386 int index, int *in_h, int *in_w, int *ker_h, int *ker_w, int *str_h, int *str_w,
387 const float** weights_data,
388 const float** biases_data) {
390 && simple_get_params_check(post_ops, index, primitive_kind::convolution)
391 && !any_null(in_h, in_w, ker_h, ker_w, str_h, str_w, weights_data, biases_data);
393 return invalid_arguments;
395 const auto &e = post_ops->entry_[index].dw_conv;
402 *weights_data = e.weights_data;
403 *biases_data = e.biases_data;
// C API: appends a binarization post-op, forwarding to
// post_ops_t::append_binarization, which validates `kind`.
// Returns invalid_arguments if `post_ops` is null.
408 status_t mkldnn_post_ops_append_binarization(post_ops_t *post_ops, alg_kind_t kind, const float* weights_data) {
409 if (post_ops == nullptr)
410 return invalid_arguments;
412 return post_ops->append_binarization(kind, weights_data);
// C API: reads back the binarization post-op at `index`.
// Returns invalid_arguments unless simple_get_params_check() passes for
// primitive_kind::binarization and `alg` and `weights_data` are non-null.
// NOTE(review): the start of the `ok` expression, the guarding `if`, the
// assignment to `*alg` and the remainder of the function are not present in
// this excerpt.
415 status_t mkldnn_post_ops_get_params_binarization(const post_ops_t *post_ops, int index, alg_kind_t *alg,
416 const float** weights_data) {
418 && simple_get_params_check(post_ops, index, primitive_kind::binarization)
419 && !any_null(alg, weights_data);
421 return invalid_arguments;
423 const auto &e = post_ops->entry_[index].binarization;
425 *weights_data = e.weights_data;