Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / primitive_attr.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn.h"
18
19 #include "c_types_map.hpp"
20 #include "primitive_attr.hpp"
21 #include "type_helpers.hpp"
22 #include "utils.hpp"
23
24 using namespace mkldnn::impl;
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::utils;
27
28 namespace mkldnn {
29 namespace impl {
30
31 status_t scales_t::set(int count, int mask, const float *scales) {
32     cleanup();
33
34     count_ = count;
35     mask_ = mask;
36
37     if (count_ == 1) {
38         scales_ = scales_buf_;
39         utils::array_set(scales_, scales[0], scales_buf_size);
40     } else {
41         scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
42         if (scales_ == nullptr)
43             return status::out_of_memory;
44
45         for (int c = 0; c < count_; ++c)
46             scales_[c] = scales[c];
47     }
48
49     return status::success;
50 }
51
52 }
53 }
54
55 status_t post_ops_t::append_sum(float scale) {
56     if (len_ == capacity)
57         return out_of_memory;
58
59     entry_[len_].kind = primitive_kind::sum;
60     entry_[len_].sum.scale = scale;
61
62     len_++;
63
64     return success;
65 }
66
67 status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
68         float beta) {
69     using namespace mkldnn::impl::alg_kind;
70     bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
71             eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
72             eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
73             eltwise_clamp, eltwise_exp, eltwise_not);
74     if (!known_alg)
75         return invalid_arguments;
76
77     if (len_ == capacity)
78         return out_of_memory;
79
80     entry_[len_].kind = primitive_kind::eltwise;
81     entry_[len_].eltwise.scale = scale;
82     entry_[len_].eltwise.alg = alg;
83     entry_[len_].eltwise.alpha = alpha;
84     entry_[len_].eltwise.beta = beta;
85
86     len_++;
87
88     return success;
89 }
90
91 status_t post_ops_t::append_depthwise(alg_kind_t alg,
92         const float* weights_data, const float* biases_data) {
93     using namespace mkldnn::impl::alg_kind;
94     bool known_alg = one_of(alg, depthwise_scale_shift, depthwise_prelu);
95     if (!known_alg)
96         return invalid_arguments;
97
98     if (len_ == capacity)
99         return out_of_memory;
100
101     entry_[len_].kind = primitive_kind::depthwise;
102     entry_[len_].depthwise.alg = alg;
103     entry_[len_].depthwise.weights_data = weights_data;
104     entry_[len_].depthwise.biases_data = biases_data;
105
106     len_++;
107
108     return success;
109 }
110
111 status_t post_ops_t::append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
112                                     const float* weights_data,
113                                     const float* biases_data) {
114     if (len_ == capacity)
115         return out_of_memory;
116
117     entry_[len_].kind = primitive_kind::convolution;
118     entry_[len_].dw_conv.in_h = in_h;
119     entry_[len_].dw_conv.in_w = in_w;
120     entry_[len_].dw_conv.ker_h = ker_h;
121     entry_[len_].dw_conv.ker_w = ker_w;
122     entry_[len_].dw_conv.str_h = str_h;
123     entry_[len_].dw_conv.str_w = str_w;
124     entry_[len_].dw_conv.weights_data = weights_data;
125     entry_[len_].dw_conv.biases_data = biases_data;
126
127     len_++;
128
129     return success;
130 }
131
132 status_t post_ops_t::append_binarization(alg_kind_t alg, const float* weights_data) {
133     using namespace mkldnn::impl::alg_kind;
134     bool known_alg = one_of(alg, binarization_depthwise);
135     if (!known_alg)
136         return invalid_arguments;
137
138     if (len_ == capacity)
139         return out_of_memory;
140
141     entry_[len_].kind = primitive_kind::binarization;
142     entry_[len_].binarization.alg = alg;
143     entry_[len_].binarization.weights_data = weights_data;
144
145     len_++;
146
147     return success;
148 }
149
150 status_t primitive_attr_t::set_round_mode(round_mode_t round_mode) {
151     using namespace mkldnn::impl::round_mode;
152
153     const bool ok = one_of(round_mode, nearest, down);
154     if (!ok)
155         return invalid_arguments;
156
157     round_mode_ = round_mode;
158     return success;
159 }
160
161 status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
162     this->post_ops_ = post_ops;
163     return success;
164 }
165
166 /* Public C API */
167
168 status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
169     if (attr == nullptr)
170         return invalid_arguments;
171
172     return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
173             new mkldnn_primitive_attr);
174 }
175
176 status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr,
177         const primitive_attr_t *existing_attr) {
178     if (any_null(attr, existing_attr))
179         return invalid_arguments;
180
181     return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
182             existing_attr->clone());
183 }
184
185 status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {
186     if (attr)
187         delete attr;
188
189     return success;
190 }
191
192 status_t mkldnn_primitive_attr_get_int_output_round_mode(
193         const primitive_attr_t *attr, round_mode_t *round_mode) {
194     if (any_null(attr, round_mode))
195         return invalid_arguments;
196
197     *round_mode = attr->round_mode_;
198
199     return success;
200 }
201
202 status_t mkldnn_primitive_attr_set_int_output_round_mode(
203         primitive_attr_t *attr, round_mode_t round_mode) {
204     if (any_null(attr))
205         return invalid_arguments;
206
207     return attr->set_round_mode(round_mode);
208 }
209
210 status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
211         int *count, int *mask, const float **scales) {
212     if (any_null(attr, count, mask, scales))
213         return invalid_arguments;
214
215     *count = attr->output_scales_.count_;
216     *mask = attr->output_scales_.mask_;
217     *scales = attr->output_scales_.scales_;
218
219     return success;
220 }
221
222 status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
223         int count, int mask, const float *scales) {
224     bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
225     if (!ok)
226         return invalid_arguments;
227
228     return attr->output_scales_.set(count, mask, scales);
229 }
230
231 status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr,
232         const post_ops_t **post_ops) {
233     if (any_null(attr, post_ops))
234         return invalid_arguments;
235
236     *post_ops = &attr->post_ops_;
237     return success;
238 }
239
240 status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr,
241         const post_ops_t *post_ops) {
242     if (any_null(attr, post_ops))
243         return invalid_arguments;
244
245     return attr->set_post_ops(*post_ops);
246 }
247
248 status_t mkldnn_post_ops_create(post_ops_t **post_ops) {
249     if (post_ops == nullptr)
250         return invalid_arguments;
251
252     return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops);
253 }
254
255 status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) {
256     if (post_ops)
257         delete post_ops;
258
259     return success;
260 }
261
262 int mkldnn_post_ops_len(const post_ops_t *post_ops) {
263     if (post_ops)
264         return post_ops->len_;
265
266     return 0;
267 }
268
269 primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops,
270         int index) {
271     bool ok = post_ops && 0 <= index && index < post_ops->len_;
272     if (!ok)
273         return primitive_kind::undefined;
274
275     return post_ops->entry_[index].kind;
276 }
277
278 status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) {
279     if (post_ops == nullptr)
280         return invalid_arguments;
281
282     return post_ops->append_sum(scale);
283 }
284
285 namespace {
286 bool simple_get_params_check(const post_ops_t *post_ops, int index,
287         primitive_kind_t kind) {
288     bool ok = true
289         && post_ops != nullptr
290         && 0 <= index
291         && index < post_ops->len_
292         && post_ops->entry_[index].kind == kind;
293    return ok;
294 }
295 }
296
297 status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
298         float *scale) {
299     bool ok = true
300         && simple_get_params_check(post_ops, index, primitive_kind::sum)
301         && !any_null(scale);
302     if (!ok)
303         return invalid_arguments;
304
305     *scale = post_ops->entry_[index].sum.scale;
306     return success;
307 }
308
309 status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale,
310         alg_kind_t kind, float alpha, float beta) {
311     if (post_ops == nullptr)
312         return invalid_arguments;
313
314     return post_ops->append_eltwise(scale, kind, alpha, beta);
315 }
316
317 status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops,
318         int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) {
319     bool ok = true
320         && simple_get_params_check(post_ops, index, primitive_kind::eltwise)
321         && !any_null(scale, alpha, beta);
322     if (!ok)
323         return invalid_arguments;
324
325     const auto &e = post_ops->entry_[index].eltwise;
326     *scale = e.scale;
327     *alg = e.alg;
328     *alpha = e.alpha;
329     *beta = e.beta;
330
331     return success;
332 }
333
334 status_t mkldnn_primitive_attr_set_rnn_data_qparams(
335         primitive_attr_t *attr, const float scale, const float shift) {
336     if (attr == nullptr)
337         return invalid_arguments;
338
339     return attr->rnn_data_qparams_.set(scale, shift);
340 }
341
342 status_t mkldnn_primitive_attr_set_rnn_weights_qparams(
343         primitive_attr_t *attr, int count, int mask, const float *scales) {
344     bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
345     if (!ok)
346         return invalid_arguments;
347
348     return attr->rnn_weights_qparams_.set(count, mask, scales);
349 }
350
351 status_t mkldnn_post_ops_append_depthwise(post_ops_t *post_ops,
352         alg_kind_t kind, const float* weights_data, const float* biases_data) {
353     if (post_ops == nullptr)
354         return invalid_arguments;
355
356     return post_ops->append_depthwise(kind, weights_data, biases_data);
357 }
358
359 status_t mkldnn_post_ops_get_params_depthwise(const post_ops_t *post_ops,
360         int index, alg_kind_t *alg, const float** weights_data, const float** biases_data) {
361     bool ok = true
362         && simple_get_params_check(post_ops, index, primitive_kind::depthwise)
363         && !any_null(alg, weights_data, biases_data);
364     if (!ok)
365         return invalid_arguments;
366
367     const auto &e = post_ops->entry_[index].depthwise;
368     *alg = e.alg;
369     *weights_data = e.weights_data;
370     *biases_data = e.biases_data;
371
372     return success;
373 }
374
375 status_t mkldnn_post_ops_append_dw_conv(post_ops_t *post_ops,
376                                         int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
377                                         const float* weights_data,
378                                         const float* biases_data) {
379     if (post_ops == nullptr)
380         return invalid_arguments;
381
382     return post_ops->append_dw_conv(in_h, in_w, ker_h, ker_w, str_h, str_w, weights_data, biases_data);
383 }
384
385 status_t mkldnn_post_ops_get_params_dw_conv(const post_ops_t *post_ops,
386                                             int index, int *in_h, int *in_w, int *ker_h, int *ker_w, int *str_h, int *str_w,
387                                             const float** weights_data,
388                                             const float** biases_data) {
389     bool ok = true
390               && simple_get_params_check(post_ops, index, primitive_kind::convolution)
391               && !any_null(in_h, in_w, ker_h, ker_w, str_h, str_w, weights_data, biases_data);
392     if (!ok)
393         return invalid_arguments;
394
395     const auto &e = post_ops->entry_[index].dw_conv;
396     *in_h = e.in_h;
397     *in_w = e.in_w;
398     *ker_h = e.ker_h;
399     *ker_w = e.ker_w;
400     *str_h = e.str_h;
401     *str_w = e.str_w;
402     *weights_data = e.weights_data;
403     *biases_data = e.biases_data;
404
405     return success;
406 }
407
408 status_t mkldnn_post_ops_append_binarization(post_ops_t *post_ops, alg_kind_t kind, const float* weights_data) {
409     if (post_ops == nullptr)
410         return invalid_arguments;
411
412     return post_ops->append_binarization(kind, weights_data);
413 }
414
415 status_t mkldnn_post_ops_get_params_binarization(const post_ops_t *post_ops, int index, alg_kind_t *alg,
416         const float** weights_data) {
417     bool ok = true
418         && simple_get_params_check(post_ops, index, primitive_kind::binarization)
419         && !any_null(alg, weights_data);
420     if (!ok)
421         return invalid_arguments;
422
423     const auto &e = post_ops->entry_[index].binarization;
424     *alg = e.alg;
425     *weights_data = e.weights_data;
426
427     return success;
428 }