Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / primitive_attr.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef PRIMITIVE_ATTR_HPP
18 #define PRIMITIVE_ATTR_HPP
19
20 #include <mkldnn.hpp>
21 #include "mkldnn.h"
22
23 #include "c_types_map.hpp"
24 #include "nstl.hpp"
25 #include "utils.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29
30 struct rnn_data_qparams_t : public c_compatible {
31     rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
32     bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
33
34     status_t set(float scale, float shift) {
35         scale_ = scale;
36         shift_ = shift;
37         return status::success;
38     }
39
40     float scale_;
41     float shift_;
42 };
43
44 struct scales_t: public c_compatible {
45     scales_t(): count_(1), mask_(0), scales_(scales_buf_)
46     { set(1.); }
47
48     scales_t(const scales_t &rhs): scales_t()
49     { set(rhs.count_, rhs.mask_, rhs.scales_); }
50
51     ~scales_t() { cleanup(); }
52
53     scales_t &operator=(const scales_t &rhs) {
54         if (&rhs == this)
55             return *this;
56         status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
57         assert(status == status::success);
58         (void)status;
59         return *this;
60     }
61
62     bool has_default_values() const {
63         for (int c = 0; c < count_; ++c) {
64             if(scales_[c] != 1.) return false;
65         }
66         return true;
67     }
68
69     status_t set(int count, int mask, const float *scales);
70     status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
71
72     int count_;
73     int mask_;
74     float *scales_;
75
76 private:
77     enum { scales_buf_size = 16 };
78     float scales_buf_[scales_buf_size];
79
80     void cleanup() {
81         if (scales_ != scales_buf_ && scales_ != nullptr)
82             impl::free(scales_);
83
84         count_ = 1;
85         mask_ = 0;
86         scales_ = scales_buf_;
87     }
88 };
89
90 }
91 }
92
93 struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
94     struct entry_t {
95         struct eltwise_t {
96             mkldnn::impl::alg_kind_t alg;
97             float scale, alpha, beta;
98         };
99
100         mkldnn::impl::primitive_kind_t kind;
101         union {
102             struct { float scale; } sum;
103             eltwise_t eltwise;
104             struct {
105                 mkldnn::impl::alg_kind_t alg;
106                 const float* weights_data;
107                 const float* biases_data;
108             } depthwise;
109             struct {
110                 int in_h;
111                 int in_w;
112                 int ker_h;
113                 int ker_w;
114                 int str_h;
115                 int str_w;
116                 const float* weights_data;
117                 const float* biases_data;
118             } dw_conv;
119             struct {
120                 mkldnn::impl::alg_kind_t alg;
121                 const float* weights_data;
122             } binarization;
123         };
124
125         bool is_eltwise(bool require_scale_one = true) const {
126             using namespace mkldnn::impl;
127             return kind == primitive_kind::eltwise
128                 && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
129         }
130
131         bool is_relu(bool require_scale_one = true,
132                 bool require_nslope_zero = true) const {
133             using namespace mkldnn::impl;
134             return is_eltwise(require_scale_one)
135                 && eltwise.alg == alg_kind::eltwise_relu
136                 && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
137         }
138
139         bool is_sum(bool require_scale_one = true) const {
140             using namespace mkldnn::impl;
141             return kind == primitive_kind::sum
142                 && IMPLICATION(require_scale_one, sum.scale == 1.f);
143         }
144
145         bool is_depthwise() const {
146             using namespace mkldnn::impl;
147             return kind == primitive_kind::depthwise;
148         }
149
150         bool is_dw_conv() const {
151             using namespace mkldnn::impl;
152             return kind == primitive_kind::convolution;
153         }
154         bool is_binarization() const {
155             using namespace mkldnn::impl;
156             return kind == primitive_kind::binarization;
157         }
158     };
159
160     mkldnn_post_ops(): len_(0) {}
161
162     mkldnn::impl::status_t append_sum(float scale);
163     mkldnn::impl::status_t append_eltwise(float scale,
164             mkldnn::impl::alg_kind_t alg, float alpha, float beta);
165     mkldnn::impl::status_t append_depthwise(mkldnn::impl::alg_kind_t alg,
166             const float* weights_data, const float* biases_data);
167     mkldnn::impl::status_t append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
168                                           const float* weights_data,
169                                           const float* biases_data);
170     mkldnn::impl::status_t append_binarization(mkldnn::impl::alg_kind_t alg, const float* weights_data);
171
172     int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
173             int stop = -1) const {
174         if (stop == -1) stop = len_;
175         stop = mkldnn::impl::nstl::min(stop, len_);
176         for (int idx = start; idx < stop; ++idx)
177             if (entry_[idx].kind == kind) return idx;
178         return -1;
179     }
180
181     bool has_default_values() const { return len_ == 0; }
182
183     bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
184     { return find(kind, index, index + 1) == index; }
185
186     enum { capacity = 4 };
187
188     int len_;
189     entry_t entry_[capacity];
190 };
191
192 struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
193     mkldnn_primitive_attr()
194         : round_mode_(mkldnn::impl::round_mode::nearest) {}
195
196     mkldnn_primitive_attr *clone() const
197     { return new mkldnn_primitive_attr(*this); }
198
199     bool has_default_values() const {
200        return true
201             && round_mode_ == mkldnn::impl::round_mode::nearest
202             && output_scales_.has_default_values()
203             && post_ops_.has_default_values()
204             && rnn_data_qparams_.has_default_values()
205             && rnn_weights_qparams_.has_default_values();
206     }
207
208     mkldnn::impl::status_t set_round_mode(
209             mkldnn::impl::round_mode_t round_mode);
210     mkldnn::impl::status_t set_post_ops(
211             const mkldnn::impl::post_ops_t &post_ops);
212
213     mkldnn::impl::round_mode_t round_mode_;
214     mkldnn::impl::scales_t output_scales_;
215     mkldnn::impl::post_ops_t post_ops_;
216     mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
217     mkldnn::impl::scales_t rnn_weights_qparams_;
218 };
219
220 #endif