Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / convolution_pd.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CONVOLUTION_PD_HPP
18 #define CONVOLUTION_PD_HPP
19
20 #include "mkldnn.h"
21
22 #include "c_types_map.hpp"
23 #include "primitive_desc.hpp"
24 #include "memory_pd.hpp"
25 #include "utils.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29
30 status_t conv_desc_init(convolution_desc_t *conv_desc,
31         prop_kind_t prop_kind, alg_kind_t alg_kind,
32         const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
33         const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
34         const dims_t strides, const dims_t dilates,
35         const dims_t padding_l, const dims_t padding_r,
36         padding_kind_t padding_kind);
37
38 memory_desc_t *conv_prop_agnostic_src_d(convolution_desc_t *desc);
39 memory_desc_t *conv_prop_agnostic_wei_d(convolution_desc_t *desc);
40 memory_desc_t *conv_prop_agnostic_bia_d(convolution_desc_t *desc);
41 memory_desc_t *conv_prop_agnostic_dst_d(convolution_desc_t *desc);
42 const memory_desc_t *conv_prop_agnostic_src_d(const convolution_desc_t *desc);
43 const memory_desc_t *conv_prop_agnostic_wei_d(const convolution_desc_t *desc);
44 const memory_desc_t *conv_prop_agnostic_bia_d(const convolution_desc_t *desc);
45 const memory_desc_t *conv_prop_agnostic_dst_d(const convolution_desc_t *desc);
46
47 struct convolution_fwd_pd_t: public primitive_desc_t {
48     typedef convolution_fwd_pd_t base_class;
49     typedef convolution_fwd_pd_t hint_class;
50     static constexpr auto base_pkind = primitive_kind::convolution;
51
52     convolution_fwd_pd_t(mkldnn::impl::engine_t *engine,
53             const convolution_desc_t *adesc, const primitive_attr_t *attr,
54             const convolution_fwd_pd_t *hint_fwd_pd)
55         : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
56         , hint_fwd_pd_(hint_fwd_pd) {}
57     virtual ~convolution_fwd_pd_t() {}
58
59     const convolution_desc_t *desc() const { return &desc_; }
60     virtual const op_desc_t *op_desc() const override
61     { return reinterpret_cast<const op_desc_t *>(this->desc()); }
62     virtual void init_info() override { init_info_conv(this, this->info_); }
63
64     virtual const memory_pd_t *input_pd(int index = 0) const override {
65         switch (index) {
66         case 0: return src_pd();
67         case 1: case 2: return weights_pd(index - 1);
68         default: return nullptr;
69         }
70     }
71     virtual const memory_pd_t *output_pd(int index = 0) const override
72     { return index == 0 ? dst_pd() : nullptr; }
73
74     virtual int n_inputs() const override { return 2 + with_bias(); }
75     virtual int n_outputs() const override { return 1; }
76
77     virtual status_t query(query_t what, int idx, void *result) const override
78     {
79         switch (what) {
80         case pkind_traits<base_pkind>::query_d:
81             *(const convolution_desc_t**)result = desc(); break;
82         default: return primitive_desc_t::query(what, idx, result);
83         }
84         return status::success;
85     }
86
87     /* common conv aux functions */
88
89     inline int MB() const { return input_pd()->desc()->dims[0]; }
90
91     inline int IC() const { return input_pd()->desc()->dims[1]; }
92     inline int OC() const { return output_pd()->desc()->dims[1]; }
93     inline int G() const
94     { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
95
96     inline int ID() const { return (ndims() == 5) ? input_pd()->desc()->dims[2] : 1; }
97     inline int IH() const { return (ndims() == 3) ? 1 : input_pd()->desc()->dims[ndims()-2]; }
98     inline int IW() const { return input_pd()->desc()->dims[ndims()-1]; }
99     inline int OD() const { return (ndims() == 5) ? output_pd()->desc()->dims[2] : 1; }
100     inline int OH() const { return (ndims() == 3) ? 1 : output_pd()->desc()->dims[ndims()-2]; }
101     inline int OW() const { return output_pd()->desc()->dims[ndims()-1]; }
102     inline int KD() const { return (ndims() == 5)
103         ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
104     inline int KH() const
105     { return (ndims() == 3)
106         ? 1 : desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
107     inline int KW() const
108     { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
109
110     inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
111     inline int KSH() const { return (ndims() == 3)
112         ? 1 : desc_.strides[ndims()-4]; }
113     inline int KSW() const { return desc_.strides[ndims()-3]; }
114
115     inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
116     inline int KDH() const { return (ndims() == 3)
117         ? 0 : desc_.dilates[ndims()-4]; }
118     inline int KDW() const { return desc_.dilates[ndims()-3]; }
119
120     inline int padFront() const
121         { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
122     inline int padBack() const
123         { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
124     inline int padT() const { return (ndims() == 3)
125         ? 0 : desc_.padding[0][ndims()-4]; }
126     inline int padB() const { return (ndims() == 3)
127         ? 0 : desc_.padding[1][ndims()-4]; }
128     inline int padL() const { return desc_.padding[0][ndims()-3]; }
129     inline int padR() const { return desc_.padding[1][ndims()-3]; }
130
131     inline bool with_bias() const
132     { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
133     inline bool with_groups() const
134     { return desc_.weights_desc.ndims == desc_.src_desc.ndims + 1; }
135
136     inline int ndims() const { return desc_.src_desc.ndims; }
137
138     virtual status_t set_alg_kind(alg_kind_t alg) {
139         if (alg == alg_kind::undef) return status::invalid_arguments;
140         desc_.alg_kind = alg;
141         return status::success;
142     }
143
144     bool has_zero_dim_memory() const {
145         return false
146             || memory_desc_wrapper(desc_.src_desc).has_zero_dim()
147             || memory_desc_wrapper(desc_.dst_desc).has_zero_dim();
148     }
149
150
151 protected:
152     convolution_desc_t desc_;
153     const convolution_fwd_pd_t *hint_fwd_pd_;
154
155     virtual status_t init() = 0;
156 };
157
158 struct convolution_bwd_data_pd_t: public primitive_desc_t {
159     typedef convolution_bwd_data_pd_t base_class;
160     typedef convolution_fwd_pd_t hint_class;
161     static constexpr auto base_pkind = primitive_kind::convolution;
162
163     convolution_bwd_data_pd_t(mkldnn::impl::engine_t *engine,
164             const convolution_desc_t *adesc,
165             const primitive_attr_t *attr,
166             const convolution_fwd_pd_t *hint_fwd_pd)
167         : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
168         , hint_fwd_pd_(hint_fwd_pd) {}
169     virtual ~convolution_bwd_data_pd_t() {}
170
171     const convolution_desc_t *desc() const { return &desc_; }
172     virtual const op_desc_t *op_desc() const override
173     { return reinterpret_cast<const op_desc_t *>(this->desc()); }
174     virtual void init_info() override { init_info_conv(this, this->info_); }
175
176     virtual const memory_pd_t *input_pd(int index = 0) const override {
177         switch (index) {
178         case 0: return diff_dst_pd();
179         case 1: return weights_pd(0);
180         default: return nullptr;
181         }
182     }
183     virtual const memory_pd_t *output_pd(int index = 0) const override
184     { return index == 0 ? diff_src_pd() : nullptr; }
185
186     virtual int n_inputs() const override { return 2 + with_bias(); }
187     virtual int n_outputs() const override { return 1; }
188
189     virtual status_t query(query_t what, int idx, void *result) const override
190     {
191         switch (what) {
192         case query::convolution_d:
193             *(const convolution_desc_t**)result = desc(); break;
194         default: return primitive_desc_t::query(what, idx, result);
195         }
196         return status::success;
197     }
198
199     /* common conv aux functions */
200
201     inline int MB() const { return output_pd()->desc()->dims[0]; }
202     inline int IC() const { return output_pd()->desc()->dims[1]; }
203     inline int OC() const { return input_pd()->desc()->dims[1]; }
204     inline int G() const
205     { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
206
207     inline int ID() const { return (ndims() == 5) ? output_pd()->desc()->dims[2] : 1; }
208     inline int IH() const { return (ndims() == 3) ? 1 : output_pd()->desc()->dims[ndims()-2]; }
209     inline int IW() const { return output_pd()->desc()->dims[ndims()-1]; }
210     inline int OD() const { return (ndims() == 5) ? input_pd()->desc()->dims[2] : 1; }
211     inline int OH() const { return (ndims() == 3) ? 1 : input_pd()->desc()->dims[ndims()-2]; }
212     inline int OW() const { return input_pd()->desc()->dims[ndims()-1]; }
213     inline int KD() const { return (ndims() == 5)
214         ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
215     inline int KH() const
216     { return (ndims() == 3)
217         ? 1 : desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
218     inline int KW() const
219     { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
220
221     inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
222     inline int KSH() const { return (ndims() == 3)
223         ? 1 : desc_.strides[ndims()-4]; }
224     inline int KSW() const { return desc_.strides[ndims()-3]; }
225
226     inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
227     inline int KDH() const { return (ndims() == 3)
228         ? 0 : desc_.dilates[ndims()-4]; }
229     inline int KDW() const { return desc_.dilates[ndims()-3]; }
230
231     inline int padFront() const
232         { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
233     inline int padBack() const
234         { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
235     inline int padT() const { return (ndims() == 3)
236         ? 0 : desc_.padding[0][ndims()-4]; }
237     inline int padB() const { return (ndims() == 3)
238         ? 0 : desc_.padding[1][ndims()-4]; }
239     inline int padL() const { return desc_.padding[0][ndims()-3]; }
240     inline int padR() const { return desc_.padding[1][ndims()-3]; }
241
242     inline bool with_bias() const
243     { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
244     inline bool with_groups() const
245     { return desc_.weights_desc.ndims == desc_.diff_src_desc.ndims + 1; }
246
247     inline int ndims() const { return desc_.diff_src_desc.ndims; }
248     virtual bool support_bias() const { return false; }
249
250     virtual status_t set_alg_kind(alg_kind_t alg) {
251         if (alg == alg_kind::undef) return status::invalid_arguments;
252         desc_.alg_kind = alg;
253         return status::success;
254     }
255
256     bool has_zero_dim_memory() const {
257         return false
258             || memory_desc_wrapper(desc_.diff_src_desc).has_zero_dim()
259             || memory_desc_wrapper(desc_.diff_dst_desc).has_zero_dim();
260     }
261
262 protected:
263     convolution_desc_t desc_;
264     const convolution_fwd_pd_t *hint_fwd_pd_;
265
266     virtual status_t init() = 0;
267 };
268
269 struct convolution_bwd_weights_pd_t: public primitive_desc_t {
270     typedef convolution_bwd_weights_pd_t base_class;
271     typedef convolution_fwd_pd_t hint_class;
272     static constexpr auto base_pkind = primitive_kind::convolution;
273
274     convolution_bwd_weights_pd_t(mkldnn::impl::engine_t *engine,
275             const convolution_desc_t *adesc,
276             const primitive_attr_t *attr,
277             const convolution_fwd_pd_t *hint_fwd_pd)
278         : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
279         , hint_fwd_pd_(hint_fwd_pd) {}
280     virtual ~convolution_bwd_weights_pd_t() {}
281
282     const convolution_desc_t *desc() const { return &desc_; }
283     virtual const op_desc_t *op_desc() const override
284     { return reinterpret_cast<const op_desc_t *>(this->desc()); }
285     virtual void init_info() override { init_info_conv(this, this->info_); }
286
287     virtual const memory_pd_t *input_pd(int index = 0) const override {
288         switch (index) {
289         case 0: return src_pd();
290         case 1: return diff_dst_pd();
291         default: return nullptr;
292         }
293     }
294     virtual const memory_pd_t *output_pd(int index = 0) const override {
295         switch (index) {
296         case 0: return diff_weights_pd(0);
297         case 1: return with_bias() ? diff_weights_pd(1) : nullptr;
298         default: return nullptr;
299         }
300     }
301
302     virtual int n_inputs() const override { return 2; }
303     virtual int n_outputs() const override { return 1 + with_bias(); }
304
305     virtual status_t query(query_t what, int idx, void *result) const override
306     {
307         switch (what) {
308         case query::convolution_d:
309             *(const convolution_desc_t**)result = desc(); break;
310         default: return primitive_desc_t::query(what, idx, result);
311         }
312         return status::success;
313     }
314
315     /* common conv aux functions */
316
317     inline int MB() const { return desc_.src_desc.dims[0]; }
318
319     inline int IC() const { return desc_.src_desc.dims[1]; }
320     inline int OC() const { return desc_.diff_dst_desc.dims[1]; }
321     inline int G() const
322     { return with_groups() ? desc_.diff_weights_desc.dims[0] : 1; }
323
324     inline int ID() const { return (ndims() == 5)
325         ? desc_.src_desc.dims[2] : 1; }
326     inline int IH() const { return (ndims() == 3)
327         ? 1 : desc_.src_desc.dims[ndims()-2]; }
328     inline int IW() const { return desc_.src_desc.dims[ndims()-1]; }
329     inline int OD() const { return (ndims() == 5)
330         ? desc_.diff_dst_desc.dims[2] : 1; }
331     inline int OH() const { return (ndims() == 3)
332         ? 1 : desc_.diff_dst_desc.dims[ndims()-2]; }
333     inline int OW() const { return desc_.diff_dst_desc.dims[ndims()-1]; }
334     inline int KD() const { return (ndims() == 5)
335         ? desc_.diff_weights_desc.dims[2 + with_groups()] : 1; }
336     inline int KH() const
337     { return (ndims() == 3)
338         ? 1 : desc_.diff_weights_desc.dims[ndims() - (2 - with_groups())]; }
339     inline int KW() const
340     { return desc_.diff_weights_desc.dims[ndims() - (1 - with_groups())]; }
341
342     inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
343     inline int KSH() const { return (ndims() == 3)
344         ? 1 : desc_.strides[ndims()-4]; }
345     inline int KSW() const { return desc_.strides[ndims()-3]; }
346
347     inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
348     inline int KDH() const { return (ndims() == 3)
349         ? 0 : desc_.dilates[ndims()-4]; }
350     inline int KDW() const { return desc_.dilates[ndims()-3]; }
351
352     inline int padFront() const
353         { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
354     inline int padBack() const
355         { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
356     inline int padT() const { return (ndims() == 3)
357         ? 0 : desc_.padding[0][ndims()-4]; }
358     inline int padB() const { return (ndims() == 3)
359         ? 0 : desc_.padding[1][ndims()-4]; }
360     inline int padL() const { return desc_.padding[0][ndims()-3]; }
361     inline int padR() const { return desc_.padding[1][ndims()-3]; }
362
363     inline bool with_bias() const
364     { return !memory_desc_wrapper(desc_.diff_bias_desc).is_zero(); }
365     inline bool with_groups() const
366     { return desc_.diff_weights_desc.ndims == desc_.diff_dst_desc.ndims + 1; }
367
368     inline int ndims() const { return desc_.src_desc.ndims; }
369
370     virtual status_t set_alg_kind(alg_kind_t alg) {
371         if (alg == alg_kind::undef) return status::invalid_arguments;
372         desc_.alg_kind = alg;
373         return status::success;
374     }
375
376     bool has_zero_dim_memory() const {
377         return false
378             || memory_desc_wrapper(desc_.src_desc).has_zero_dim()
379             || memory_desc_wrapper(desc_.diff_dst_desc).has_zero_dim();
380     }
381
382 protected:
383     convolution_desc_t desc_;
384     const convolution_fwd_pd_t *hint_fwd_pd_;
385
386     virtual status_t init() = 0;
387 };
388
389 }
390 }
391
392 #endif
393
394 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s