Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / deconvolution_pd.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef DECONVOLUTION_PD_HPP
18 #define DECONVOLUTION_PD_HPP
19
20 #include "mkldnn.h"
21
22 #include "c_types_map.hpp"
23 #include "memory_pd.hpp"
24 #include "primitive_desc.hpp"
25 #include "utils.hpp"
26 namespace mkldnn {
27 namespace impl {
28
29 struct deconvolution_fwd_pd_t : public primitive_desc_t {
30     typedef deconvolution_fwd_pd_t base_class;
31     typedef deconvolution_fwd_pd_t hint_class;
32     static constexpr auto base_pkind = primitive_kind::deconvolution;
33     deconvolution_fwd_pd_t(mkldnn::impl::engine_t *engine,
34             const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
35             const deconvolution_fwd_pd_t *hint_fwd_pd)
36         : primitive_desc_t(engine, attr, base_pkind)
37         , desc_(*adesc)
38         , hint_fwd_pd_(hint_fwd_pd) {}
39     virtual ~deconvolution_fwd_pd_t() {}
40
41     const deconvolution_desc_t *desc() const { return &desc_; }
42     virtual const op_desc_t *op_desc() const override {
43         return reinterpret_cast<const op_desc_t *>(this->desc());
44     }
45     virtual void init_info() override { init_info_conv(this, this->info_); }
46
47     virtual const memory_pd_t *input_pd(int index = 0) const override {
48         switch (index) {
49         case 0: return src_pd();
50         case 1:
51         case 2: return weights_pd(index - 1);
52         default: return nullptr;
53         }
54     }
55     virtual const memory_pd_t *output_pd(int index = 0) const override {
56         return index == 0 ? dst_pd() : nullptr;
57     }
58
59     virtual int n_inputs() const override { return 2 + with_bias(); }
60     virtual int n_outputs() const override { return 1; }
61     /* Memory format Query */
62     virtual status_t query(query_t what, int idx, void *result) const override {
63         switch (what) {
64         case pkind_traits<base_pkind>::query_d:
65             *(const deconvolution_desc_t **)result = desc();
66             break;
67         default: return primitive_desc_t::query(what, idx, result);
68         }
69         return status::success;
70     }
71
72     /* common conv aux functions */
73     inline int MB() const { return desc_.src_desc.dims[0]; }
74
75     inline int IC() const { return desc_.src_desc.dims[1]; }
76     inline int OC() const { return desc_.dst_desc.dims[1]; }
77     inline int G() const
78     { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
79
80     inline int ID() const { return (ndims() == 5)
81         ? desc_.src_desc.dims[2] : 1; }
82     inline int IH() const { return desc_.src_desc.dims[ndims()-2]; }
83     inline int IW() const { return desc_.src_desc.dims[ndims()-1]; }
84     inline int OD() const { return (ndims() == 5)
85         ? desc_.dst_desc.dims[2] : 1; }
86     inline int OH() const { return desc_.dst_desc.dims[ndims()-2]; }
87     inline int OW() const { return desc_.dst_desc.dims[ndims()-1]; }
88     inline int KD() const { return (ndims() == 5)
89         ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
90     inline int KH() const
91     { return desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
92     inline int KW() const
93     { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
94
95     inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
96     inline int KSH() const { return desc_.strides[ndims()-4]; }
97     inline int KSW() const { return desc_.strides[ndims()-3]; }
98
99     inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
100     inline int KDH() const { return desc_.dilates[ndims()-4]; }
101     inline int KDW() const { return desc_.dilates[ndims()-3]; }
102
103     inline int padFront() const
104         { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
105     inline int padBack() const
106         { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
107     inline int padT() const { return desc_.padding[0][ndims()-4]; }
108     inline int padB() const { return desc_.padding[1][ndims()-4]; }
109     inline int padL() const { return desc_.padding[0][ndims()-3]; }
110     inline int padR() const { return desc_.padding[1][ndims()-3]; }
111
112     inline bool with_bias() const {
113         return !memory_desc_wrapper(desc_.bias_desc).is_zero();
114     }
115     inline bool with_groups() const {
116         return desc_.weights_desc.ndims == desc_.src_desc.ndims + 1;
117     }
118     inline int ndims() const { return desc_.src_desc.ndims; }
119
120     bool has_zero_dim_memory() const {
121         return false
122             || memory_desc_wrapper(desc_.src_desc).has_zero_dim()
123             || memory_desc_wrapper(desc_.dst_desc).has_zero_dim();
124     }
125
126 protected:
127     deconvolution_desc_t desc_;
128     const deconvolution_fwd_pd_t *hint_fwd_pd_;
129     virtual status_t init() = 0;
130 };
131
132 struct deconvolution_bwd_data_pd_t : public primitive_desc_t {
133     typedef deconvolution_bwd_data_pd_t base_class;
134     typedef deconvolution_fwd_pd_t hint_class;
135     static constexpr auto base_pkind = primitive_kind::deconvolution;
136
137     deconvolution_bwd_data_pd_t(mkldnn::impl::engine_t *engine,
138             const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
139             const deconvolution_fwd_pd_t *hint_fwd_pd)
140         : primitive_desc_t(engine, attr, base_pkind)
141         , desc_(*adesc)
142         , hint_fwd_pd_(hint_fwd_pd) {}
143     virtual ~deconvolution_bwd_data_pd_t() {}
144
145     const deconvolution_desc_t *desc() const { return &desc_; }
146     virtual const op_desc_t *op_desc() const override {
147         return reinterpret_cast<const op_desc_t *>(this->desc());
148     }
149     virtual void init_info() override { init_info_conv(this, this->info_); }
150
151     virtual const memory_pd_t *input_pd(int index = 0) const override {
152         switch (index) {
153         case 0: return diff_dst_pd();
154         case 1: return weights_pd(0);
155         default: return nullptr;
156         }
157     }
158     virtual const memory_pd_t *output_pd(int index = 0) const override {
159         return index == 0 ? diff_src_pd() : nullptr;
160     }
161
162     virtual int n_inputs() const override { return 2; }
163     virtual int n_outputs() const override { return 1; }
164
165     virtual status_t query(query_t what, int idx, void *result) const override {
166         switch (what) {
167         case query::deconvolution_d:
168             *(const deconvolution_desc_t **)result = desc();
169             break;
170         default: return primitive_desc_t::query(what, idx, result);
171         }
172         return status::success;
173     }
174
175     /* common conv aux functions */
176     inline int MB() const { return desc_.diff_src_desc.dims[0]; }
177
178     inline int IC() const { return desc_.diff_src_desc.dims[1]; }
179     inline int OC() const { return desc_.diff_dst_desc.dims[1]; }
180     inline int G() const
181     { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
182
183     inline int ID() const { return (ndims() == 5)
184         ? desc_.diff_src_desc.dims[2] : 1; }
185     inline int IH() const { return desc_.diff_src_desc.dims[ndims()-2]; }
186     inline int IW() const { return desc_.diff_src_desc.dims[ndims()-1]; }
187     inline int OD() const { return (ndims() == 5)
188         ? desc_.diff_dst_desc.dims[2] : 1; }
189     inline int OH() const { return desc_.diff_dst_desc.dims[ndims()-2]; }
190     inline int OW() const { return desc_.diff_dst_desc.dims[ndims()-1]; }
191     inline int KD() const { return (ndims() == 5)
192         ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
193     inline int KH() const
194     { return desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
195     inline int KW() const
196     { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
197
198     inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
199     inline int KSH() const { return desc_.strides[ndims()-4]; }
200     inline int KSW() const { return desc_.strides[ndims()-3]; }
201
202     inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
203     inline int KDH() const { return desc_.dilates[ndims()-4]; }
204     inline int KDW() const { return desc_.dilates[ndims()-3]; }
205
206     inline int padFront() const
207         { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
208     inline int padBack() const
209         { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
210     inline int padT() const { return desc_.padding[0][ndims()-4]; }
211     inline int padB() const { return desc_.padding[1][ndims()-4]; }
212     inline int padL() const { return desc_.padding[0][ndims()-3]; }
213     inline int padR() const { return desc_.padding[1][ndims()-3]; }
214
215     inline bool with_bias() const {
216         return !memory_desc_wrapper(desc_.bias_desc).is_zero();
217     }
218     inline bool with_groups() const {
219         return desc_.weights_desc.ndims == desc_.diff_src_desc.ndims + 1;
220     }
221     inline int ndims() const { return desc_.diff_src_desc.ndims; }
222
223 protected:
224     deconvolution_desc_t desc_;
225     const deconvolution_fwd_pd_t *hint_fwd_pd_;
226     virtual status_t init() = 0;
227 };
228
229 struct deconvolution_bwd_weights_pd_t : public primitive_desc_t {
230     typedef deconvolution_bwd_weights_pd_t base_class;
231     typedef deconvolution_fwd_pd_t hint_class;
232     static constexpr auto base_pkind = primitive_kind::deconvolution;
233
234     deconvolution_bwd_weights_pd_t(mkldnn::impl::engine_t *engine,
235             const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
236             const deconvolution_fwd_pd_t *hint_fwd_pd)
237         : primitive_desc_t(engine, attr, base_pkind)
238         , desc_(*adesc)
239         , hint_fwd_pd_(hint_fwd_pd) {}
240     virtual ~deconvolution_bwd_weights_pd_t() {}
241
242     const deconvolution_desc_t *desc() const { return &desc_; }
243     virtual const op_desc_t *op_desc() const override {
244         return reinterpret_cast<const op_desc_t *>(this->desc());
245     }
246     virtual void init_info() override { init_info_conv(this, this->info_); }
247     virtual const memory_pd_t *input_pd(int index = 0) const override {
248         switch (index) {
249         case 0: return src_pd();
250         case 1: return diff_dst_pd();
251         default: return nullptr;
252         }
253     }
254     virtual const memory_pd_t *output_pd(int index = 0) const override {
255         switch (index) {
256         case 0: return diff_weights_pd(0);
257         case 1: return with_bias() ? diff_weights_pd(1) : nullptr;
258         default: return nullptr;
259         }
260     }
261
262     virtual int n_inputs() const override { return 2; }
263     virtual int n_outputs() const override { return 1 + with_bias(); }
264
265     virtual status_t query(query_t what, int idx, void *result) const override {
266         switch (what) {
267         case query::deconvolution_d:
268             *(const deconvolution_desc_t **)result = desc();
269             break;
270         default: return primitive_desc_t::query(what, idx, result);
271         }
272         return status::success;
273     }
274
275     /* common conv aux functions */
276     inline int MB() const { return desc_.src_desc.dims[0]; }
277
278     inline int IC() const { return desc_.src_desc.dims[1]; }
279     inline int OC() const { return desc_.diff_dst_desc.dims[1]; }
280     inline int G() const
281     { return with_groups() ? desc_.diff_weights_desc.dims[0] : 1; }
282
283     inline int ID() const { return (ndims() == 5)
284         ? desc_.src_desc.dims[2] : 1; }
285     inline int IH() const { return desc_.src_desc.dims[ndims()-2]; }
286     inline int IW() const { return desc_.src_desc.dims[ndims()-1]; }
287     inline int OD() const { return (ndims() == 5)
288         ? desc_.diff_dst_desc.dims[2] : 1; }
289     inline int OH() const { return desc_.diff_dst_desc.dims[ndims()-2]; }
290     inline int OW() const { return desc_.diff_dst_desc.dims[ndims()-1]; }
291     inline int KD() const { return (ndims() == 5)
292         ? desc_.diff_weights_desc.dims[2 + with_groups()] : 1; }
293     inline int KH() const
294     { return desc_.diff_weights_desc.dims[ndims() - (2 - with_groups())]; }
295     inline int KW() const
296     { return desc_.diff_weights_desc.dims[ndims() - (1 - with_groups())]; }
297
298     inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
299     inline int KSH() const { return desc_.strides[ndims()-4]; }
300     inline int KSW() const { return desc_.strides[ndims()-3]; }
301
302     inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
303     inline int KDH() const { return desc_.dilates[ndims()-4]; }
304     inline int KDW() const { return desc_.dilates[ndims()-3]; }
305
306     inline int padFront() const
307         { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
308     inline int padBack() const
309         { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
310     inline int padT() const { return desc_.padding[0][ndims()-4]; }
311     inline int padB() const { return desc_.padding[1][ndims()-4]; }
312     inline int padL() const { return desc_.padding[0][ndims()-3]; }
313     inline int padR() const { return desc_.padding[1][ndims()-3]; }
314
315     inline bool with_bias() const {
316         return !memory_desc_wrapper(desc_.diff_bias_desc).is_zero();
317     }
318     inline bool with_groups() const {
319         return desc_.diff_weights_desc.ndims == desc_.diff_dst_desc.ndims + 1;
320     }
321     inline int ndims() const { return desc_.src_desc.ndims; }
322
323 protected:
324     deconvolution_desc_t desc_;
325     const deconvolution_fwd_pd_t *hint_fwd_pd_;
326     virtual status_t init() = 0;
327 };
328 }
329 }
330
331 #endif
332
333 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s