Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / convolution_pd.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CONVOLUTION_PD_HPP
18 #define CONVOLUTION_PD_HPP
19
20 #include "mkldnn.h"
21
22 #include "c_types_map.hpp"
23 #include "primitive_desc.hpp"
24 #include "memory_pd.hpp"
25 #include "utils.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29
30 status_t conv_desc_init(convolution_desc_t *conv_desc,
31         prop_kind_t prop_kind, alg_kind_t alg_kind,
32         const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
33         const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
34         const dims_t strides, const dims_t dilates,
35         const dims_t padding_l, const dims_t padding_r,
36         padding_kind_t padding_kind);
37
38 template <bool with_relu>
39 struct _convolution_fwd_pd_t: public primitive_desc_t {
40     typedef _convolution_fwd_pd_t base_class;
41     typedef _convolution_fwd_pd_t hint_class;
42     typedef typename utils::conditional<with_relu,
43             convolution_relu_desc_t, convolution_desc_t>::type base_desc_t;
44     static constexpr auto base_pkind =
45         utils::conditional_v<with_relu, primitive_kind_t,
46         primitive_kind::convolution_relu, primitive_kind::convolution>::value;
47
48     _convolution_fwd_pd_t(mkldnn::impl::engine_t *engine,
49             const base_desc_t *adesc, const primitive_attr_t *attr,
50             const _convolution_fwd_pd_t *hint_fwd_pd)
51         : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
52         , hint_fwd_pd_(hint_fwd_pd) {}
53     virtual ~_convolution_fwd_pd_t() {}
54
55     const base_desc_t *desc() const { return &desc_; }
56     inline const convolution_desc_t *cdesc() const { return &cdesc_(); }
57     virtual const op_desc_t *op_desc() const override
58     { return reinterpret_cast<const op_desc_t *>(this->desc()); }
59     virtual void init_info() override { init_info_conv(this, this->info_); }
60
61     virtual const memory_pd_t *input_pd(int index = 0) const override {
62         switch (index) {
63         case 0: return src_pd();
64         case 1: case 2: return weights_pd(index - 1);
65         default: return nullptr;
66         }
67     }
68     virtual const memory_pd_t *output_pd(int index = 0) const override
69     { return index == 0 ? dst_pd() : nullptr; }
70
71     virtual int n_inputs() const override { return 2 + with_bias(); }
72     virtual int n_outputs() const override { return 1; }
73
74     virtual status_t query(query_t what, int idx, void *result) const override
75     {
76         switch (what) {
77         case pkind_traits<base_pkind>::query_d:
78             *(const base_desc_t**)result = desc(); break;
79         default: return primitive_desc_t::query(what, idx, result);
80         }
81         return status::success;
82     }
83
84     /* common conv aux functions */
85
86     inline int MB() const { return input_pd()->desc()->dims[0]; }
87
88     inline int IC() const { return input_pd()->desc()->dims[1]; }
89     inline int OC() const { return output_pd()->desc()->dims[1]; }
90     inline int G() const
91     { return with_groups() ? cdesc_().weights_desc.dims[0] : 1; }
92
93     inline int ID() const { return (ndims() == 5)
94         ? input_pd()->desc()->dims[2] : 1; }
95     inline int IH() const { return input_pd()->desc()->dims[ndims()-2]; }
96     inline int IW() const { return input_pd()->desc()->dims[ndims()-1]; }
97     inline int OD() const { return (ndims() == 5)
98         ? output_pd()->desc()->dims[2] : 1; }
99     inline int OH() const { return output_pd()->desc()->dims[ndims()-2]; }
100     inline int OW() const { return output_pd()->desc()->dims[ndims()-1]; }
101     inline int KD() const { return (ndims() == 5)
102         ? cdesc_().weights_desc.dims[2 + with_groups()] : 1; }
103     inline int KH() const
104     { return cdesc_().weights_desc.dims[ndims() - (2 - with_groups())]; }
105     inline int KW() const
106     { return cdesc_().weights_desc.dims[ndims() - (1 - with_groups())]; }
107
108     inline int KSD() const { return (ndims() == 5) ? cdesc_().strides[0] : 1; }
109     inline int KSH() const { return cdesc_().strides[ndims()-4]; }
110     inline int KSW() const { return cdesc_().strides[ndims()-3]; }
111
112     inline int KDD() const { return (ndims() == 5) ? cdesc_().dilates[0] : 0; }
113     inline int KDH() const { return cdesc_().dilates[ndims()-4]; }
114     inline int KDW() const { return cdesc_().dilates[ndims()-3]; }
115
116     inline int padFront() const
117         { return (ndims() == 5) ? cdesc_().padding[0][0] : 0; }
118     inline int padBack() const
119         { return (ndims() == 5) ? cdesc_().padding[1][0] : 0; }
120     inline int padT() const { return cdesc_().padding[0][ndims()-4]; }
121     inline int padB() const { return cdesc_().padding[1][ndims()-4]; }
122     inline int padL() const { return cdesc_().padding[0][ndims()-3]; }
123     inline int padR() const { return cdesc_().padding[1][ndims()-3]; }
124
125     inline float negative_slope() const;
126
127     inline bool with_bias() const
128     { return !memory_desc_wrapper(cdesc_().bias_desc).is_zero(); }
129     inline bool with_groups() const
130     { return cdesc_().weights_desc.ndims == cdesc_().src_desc.ndims + 1; }
131
132     inline int ndims() const { return cdesc_().src_desc.ndims; }
133
134 protected:
135     base_desc_t desc_;
136     const _convolution_fwd_pd_t *hint_fwd_pd_;
137
138     inline const convolution_desc_t &cdesc_() const;
139
140     virtual status_t init() = 0;
141 };
142
143 using convolution_fwd_pd_t = mkldnn::impl::_convolution_fwd_pd_t<false>;
144 using convolution_relu_fwd_pd_t = mkldnn::impl::_convolution_fwd_pd_t<true>;
145
146 template<> inline float convolution_fwd_pd_t::negative_slope() const
147 { return 0.; }
148 template<> inline float convolution_relu_fwd_pd_t::negative_slope() const
149 { return desc()->negative_slope; }
150
151 template<bool with_relu> inline const
152 convolution_desc_t &_convolution_fwd_pd_t<with_relu>::cdesc_() const
153 { return desc_; }
154 template<>
155 inline const convolution_desc_t &convolution_relu_fwd_pd_t::cdesc_() const
156 { return desc_.convolution_desc; }
157
158 struct convolution_bwd_data_pd_t: public primitive_desc_t {
159     typedef convolution_bwd_data_pd_t base_class;
160     typedef convolution_fwd_pd_t hint_class;
161     static constexpr auto base_pkind = primitive_kind::convolution;
162
163     convolution_bwd_data_pd_t(mkldnn::impl::engine_t *engine,
164             const convolution_desc_t *adesc,
165             const primitive_attr_t *attr,
166             const convolution_fwd_pd_t *hint_fwd_pd)
167         : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
168         , hint_fwd_pd_(hint_fwd_pd) {}
169     virtual ~convolution_bwd_data_pd_t() {}
170
171     const convolution_desc_t *desc() const { return &desc_; }
172     const convolution_desc_t *cdesc() const { return desc(); }
173     virtual const op_desc_t *op_desc() const override
174     { return reinterpret_cast<const op_desc_t *>(this->desc()); }
175     virtual void init_info() override { init_info_conv(this, this->info_); }
176
177     virtual const memory_pd_t *input_pd(int index = 0) const override {
178         switch (index) {
179         case 0: return diff_dst_pd();
180         case 1: return weights_pd(0);
181         default: return nullptr;
182         }
183     }
184     virtual const memory_pd_t *output_pd(int index = 0) const override
185     { return index == 0 ? diff_src_pd() : nullptr; }
186
187     virtual int n_inputs() const override { return 2; }
188     virtual int n_outputs() const override { return 1; }
189
190     virtual status_t query(query_t what, int idx, void *result) const override
191     {
192         switch (what) {
193         case query::convolution_d:
194             *(const convolution_desc_t**)result = desc(); break;
195         default: return primitive_desc_t::query(what, idx, result);
196         }
197         return status::success;
198     }
199
200     /* common conv aux functions */
201
202     inline int MB() const { return output_pd()->desc()->dims[0]; }
203     inline int IC() const { return output_pd()->desc()->dims[1]; }
204     inline int OC() const { return input_pd()->desc()->dims[1]; }
205     inline int G() const
206     { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
207
208     inline int ID() const { return (ndims() == 5)
209         ? output_pd()->desc()->dims[2] : 1; }
210     inline int IH() const { return output_pd()->desc()->dims[ndims()-2]; }
211     inline int IW() const { return output_pd()->desc()->dims[ndims()-1]; }
212     inline int OD() const { return (ndims() == 5)
213         ? input_pd()->desc()->dims[2] : 1; }
214     inline int OH() const { return input_pd()->desc()->dims[ndims()-2]; }
215     inline int OW() const { return input_pd()->desc()->dims[ndims()-1]; }
216     inline int KD() const { return (ndims() == 5)
217         ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
218     inline int KH() const
219     { return desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
220     inline int KW() const
221     { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
222
223     inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
224     inline int KSH() const { return desc_.strides[ndims()-4]; }
225     inline int KSW() const { return desc_.strides[ndims()-3]; }
226
227     inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
228     inline int KDH() const { return desc_.dilates[ndims()-4]; }
229     inline int KDW() const { return desc_.dilates[ndims()-3]; }
230
231     inline int padFront() const
232         { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
233     inline int padBack() const
234         { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
235     inline int padT() const { return desc_.padding[0][ndims()-4]; }
236     inline int padB() const { return desc_.padding[1][ndims()-4]; }
237     inline int padL() const { return desc_.padding[0][ndims()-3]; }
238     inline int padR() const { return desc_.padding[1][ndims()-3]; }
239
240     inline bool with_bias() const
241     { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
242     inline bool with_groups() const
243     { return desc_.weights_desc.ndims == desc_.diff_src_desc.ndims + 1; }
244
245     inline int ndims() const { return desc_.diff_src_desc.ndims; }
246
247 protected:
248     convolution_desc_t desc_;
249     const convolution_fwd_pd_t *hint_fwd_pd_;
250
251     virtual status_t init() = 0;
252 };
253
254 struct convolution_bwd_weights_pd_t: public primitive_desc_t {
255     typedef convolution_bwd_weights_pd_t base_class;
256     typedef convolution_fwd_pd_t hint_class;
257     static constexpr auto base_pkind = primitive_kind::convolution;
258
259     convolution_bwd_weights_pd_t(mkldnn::impl::engine_t *engine,
260             const convolution_desc_t *adesc,
261             const primitive_attr_t *attr,
262             const convolution_fwd_pd_t *hint_fwd_pd)
263         : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
264         , hint_fwd_pd_(hint_fwd_pd) {}
265     virtual ~convolution_bwd_weights_pd_t() {}
266
267     const convolution_desc_t *desc() const { return &desc_; }
268     const convolution_desc_t *cdesc() const { return desc(); }
269     virtual const op_desc_t *op_desc() const override
270     { return reinterpret_cast<const op_desc_t *>(this->desc()); }
271     virtual void init_info() override { init_info_conv(this, this->info_); }
272
273     virtual const memory_pd_t *input_pd(int index = 0) const override {
274         switch (index) {
275         case 0: return src_pd();
276         case 1: return diff_dst_pd();
277         default: return nullptr;
278         }
279     }
280     virtual const memory_pd_t *output_pd(int index = 0) const override {
281         switch (index) {
282         case 0: return diff_weights_pd(0);
283         case 1: return with_bias() ? diff_weights_pd(1) : nullptr;
284         default: return nullptr;
285         }
286     }
287
288     virtual int n_inputs() const override { return 2; }
289     virtual int n_outputs() const override { return 1 + with_bias(); }
290
291     virtual status_t query(query_t what, int idx, void *result) const override
292     {
293         switch (what) {
294         case query::convolution_d:
295             *(const convolution_desc_t**)result = desc(); break;
296         default: return primitive_desc_t::query(what, idx, result);
297         }
298         return status::success;
299     }
300
301     /* common conv aux functions */
302
303     inline int MB() const { return desc_.src_desc.dims[0]; }
304
305     inline int IC() const { return desc_.src_desc.dims[1]; }
306     inline int OC() const { return desc_.diff_dst_desc.dims[1]; }
307     inline int G() const
308     { return with_groups() ? desc_.diff_weights_desc.dims[0] : 1; }
309
310     inline int ID() const { return (ndims() == 5)
311         ? desc_.src_desc.dims[2] : 1; }
312     inline int IH() const { return desc_.src_desc.dims[ndims()-2]; }
313     inline int IW() const { return desc_.src_desc.dims[ndims()-1]; }
314     inline int OD() const { return (ndims() == 5)
315         ? desc_.diff_dst_desc.dims[2] : 1; }
316     inline int OH() const { return desc_.diff_dst_desc.dims[ndims()-2]; }
317     inline int OW() const { return desc_.diff_dst_desc.dims[ndims()-1]; }
318     inline int KD() const { return (ndims() == 5)
319         ? desc_.diff_weights_desc.dims[2 + with_groups()] : 1; }
320     inline int KH() const
321     { return desc_.diff_weights_desc.dims[ndims() - (2 - with_groups())]; }
322     inline int KW() const
323     { return desc_.diff_weights_desc.dims[ndims() - (1 - with_groups())]; }
324
325     inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
326     inline int KSH() const { return desc_.strides[ndims()-4]; }
327     inline int KSW() const { return desc_.strides[ndims()-3]; }
328
329     inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
330     inline int KDH() const { return desc_.dilates[ndims()-4]; }
331     inline int KDW() const { return desc_.dilates[ndims()-3]; }
332
333     inline int padFront() const
334         { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
335     inline int padBack() const
336         { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
337     inline int padT() const { return desc_.padding[0][ndims()-4]; }
338     inline int padB() const { return desc_.padding[1][ndims()-4]; }
339     inline int padL() const { return desc_.padding[0][ndims()-3]; }
340     inline int padR() const { return desc_.padding[1][ndims()-3]; }
341
342     inline bool with_bias() const
343     { return !memory_desc_wrapper(desc_.diff_bias_desc).is_zero(); }
344     inline bool with_groups() const
345     { return desc_.diff_weights_desc.ndims == desc_.diff_dst_desc.ndims + 1; }
346
347     inline int ndims() const { return desc_.src_desc.ndims; }
348
349 protected:
350     convolution_desc_t desc_;
351     const convolution_fwd_pd_t *hint_fwd_pd_;
352
353     virtual status_t init() = 0;
354 };
355
356 }
357 }
358
359 #endif
360
361 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s