/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #ifndef DECONVOLUTION_PD_HPP
18 #define DECONVOLUTION_PD_HPP
22 #include "c_types_map.hpp"
23 #include "memory_pd.hpp"
24 #include "primitive_desc.hpp"
29 struct deconvolution_fwd_pd_t : public primitive_desc_t {
30 typedef deconvolution_fwd_pd_t base_class;
31 typedef deconvolution_fwd_pd_t hint_class;
32 static constexpr auto base_pkind = primitive_kind::deconvolution;
33 deconvolution_fwd_pd_t(mkldnn::impl::engine_t *engine,
34 const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
35 const deconvolution_fwd_pd_t *hint_fwd_pd)
36 : primitive_desc_t(engine, attr, base_pkind)
38 , hint_fwd_pd_(hint_fwd_pd) {}
39 virtual ~deconvolution_fwd_pd_t() {}
41 const deconvolution_desc_t *desc() const { return &desc_; }
42 virtual const op_desc_t *op_desc() const override {
43 return reinterpret_cast<const op_desc_t *>(this->desc());
45 virtual void init_info() override { init_info_conv(this, this->info_); }
47 virtual const memory_pd_t *input_pd(int index = 0) const override {
49 case 0: return src_pd();
51 case 2: return weights_pd(index - 1);
52 default: return nullptr;
55 virtual const memory_pd_t *output_pd(int index = 0) const override {
56 return index == 0 ? dst_pd() : nullptr;
59 virtual int n_inputs() const override { return 2 + with_bias(); }
60 virtual int n_outputs() const override { return 1; }
61 /* Memory format Query */
62 virtual status_t query(query_t what, int idx, void *result) const override {
64 case pkind_traits<base_pkind>::query_d:
65 *(const deconvolution_desc_t **)result = desc();
67 default: return primitive_desc_t::query(what, idx, result);
69 return status::success;
72 /* common conv aux functions */
73 inline int MB() const { return desc_.src_desc.dims[0]; }
75 inline int IC() const { return desc_.src_desc.dims[1]; }
76 inline int OC() const { return desc_.dst_desc.dims[1]; }
78 { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
80 inline int ID() const { return (ndims() == 5)
81 ? desc_.src_desc.dims[2] : 1; }
82 inline int IH() const { return desc_.src_desc.dims[ndims()-2]; }
83 inline int IW() const { return desc_.src_desc.dims[ndims()-1]; }
84 inline int OD() const { return (ndims() == 5)
85 ? desc_.dst_desc.dims[2] : 1; }
86 inline int OH() const { return desc_.dst_desc.dims[ndims()-2]; }
87 inline int OW() const { return desc_.dst_desc.dims[ndims()-1]; }
88 inline int KD() const { return (ndims() == 5)
89 ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
91 { return desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
93 { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
95 inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
96 inline int KSH() const { return desc_.strides[ndims()-4]; }
97 inline int KSW() const { return desc_.strides[ndims()-3]; }
99 inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
100 inline int KDH() const { return desc_.dilates[ndims()-4]; }
101 inline int KDW() const { return desc_.dilates[ndims()-3]; }
103 inline int padFront() const
104 { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
105 inline int padBack() const
106 { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
107 inline int padT() const { return desc_.padding[0][ndims()-4]; }
108 inline int padB() const { return desc_.padding[1][ndims()-4]; }
109 inline int padL() const { return desc_.padding[0][ndims()-3]; }
110 inline int padR() const { return desc_.padding[1][ndims()-3]; }
112 inline bool with_bias() const {
113 return !memory_desc_wrapper(desc_.bias_desc).is_zero();
115 inline bool with_groups() const {
116 return desc_.weights_desc.ndims == desc_.src_desc.ndims + 1;
118 inline int ndims() const { return desc_.src_desc.ndims; }
120 bool has_zero_dim_memory() const {
122 || memory_desc_wrapper(desc_.src_desc).has_zero_dim()
123 || memory_desc_wrapper(desc_.dst_desc).has_zero_dim();
127 deconvolution_desc_t desc_;
128 const deconvolution_fwd_pd_t *hint_fwd_pd_;
129 virtual status_t init() = 0;
132 struct deconvolution_bwd_data_pd_t : public primitive_desc_t {
133 typedef deconvolution_bwd_data_pd_t base_class;
134 typedef deconvolution_fwd_pd_t hint_class;
135 static constexpr auto base_pkind = primitive_kind::deconvolution;
137 deconvolution_bwd_data_pd_t(mkldnn::impl::engine_t *engine,
138 const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
139 const deconvolution_fwd_pd_t *hint_fwd_pd)
140 : primitive_desc_t(engine, attr, base_pkind)
142 , hint_fwd_pd_(hint_fwd_pd) {}
143 virtual ~deconvolution_bwd_data_pd_t() {}
145 const deconvolution_desc_t *desc() const { return &desc_; }
146 virtual const op_desc_t *op_desc() const override {
147 return reinterpret_cast<const op_desc_t *>(this->desc());
149 virtual void init_info() override { init_info_conv(this, this->info_); }
151 virtual const memory_pd_t *input_pd(int index = 0) const override {
153 case 0: return diff_dst_pd();
154 case 1: return weights_pd(0);
155 default: return nullptr;
158 virtual const memory_pd_t *output_pd(int index = 0) const override {
159 return index == 0 ? diff_src_pd() : nullptr;
162 virtual int n_inputs() const override { return 2; }
163 virtual int n_outputs() const override { return 1; }
165 virtual status_t query(query_t what, int idx, void *result) const override {
167 case query::deconvolution_d:
168 *(const deconvolution_desc_t **)result = desc();
170 default: return primitive_desc_t::query(what, idx, result);
172 return status::success;
175 /* common conv aux functions */
176 inline int MB() const { return desc_.diff_src_desc.dims[0]; }
178 inline int IC() const { return desc_.diff_src_desc.dims[1]; }
179 inline int OC() const { return desc_.diff_dst_desc.dims[1]; }
181 { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
183 inline int ID() const { return (ndims() == 5)
184 ? desc_.diff_src_desc.dims[2] : 1; }
185 inline int IH() const { return desc_.diff_src_desc.dims[ndims()-2]; }
186 inline int IW() const { return desc_.diff_src_desc.dims[ndims()-1]; }
187 inline int OD() const { return (ndims() == 5)
188 ? desc_.diff_dst_desc.dims[2] : 1; }
189 inline int OH() const { return desc_.diff_dst_desc.dims[ndims()-2]; }
190 inline int OW() const { return desc_.diff_dst_desc.dims[ndims()-1]; }
191 inline int KD() const { return (ndims() == 5)
192 ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
193 inline int KH() const
194 { return desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
195 inline int KW() const
196 { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
198 inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
199 inline int KSH() const { return desc_.strides[ndims()-4]; }
200 inline int KSW() const { return desc_.strides[ndims()-3]; }
202 inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
203 inline int KDH() const { return desc_.dilates[ndims()-4]; }
204 inline int KDW() const { return desc_.dilates[ndims()-3]; }
206 inline int padFront() const
207 { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
208 inline int padBack() const
209 { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
210 inline int padT() const { return desc_.padding[0][ndims()-4]; }
211 inline int padB() const { return desc_.padding[1][ndims()-4]; }
212 inline int padL() const { return desc_.padding[0][ndims()-3]; }
213 inline int padR() const { return desc_.padding[1][ndims()-3]; }
215 inline bool with_bias() const {
216 return !memory_desc_wrapper(desc_.bias_desc).is_zero();
218 inline bool with_groups() const {
219 return desc_.weights_desc.ndims == desc_.diff_src_desc.ndims + 1;
221 inline int ndims() const { return desc_.diff_src_desc.ndims; }
224 deconvolution_desc_t desc_;
225 const deconvolution_fwd_pd_t *hint_fwd_pd_;
226 virtual status_t init() = 0;
229 struct deconvolution_bwd_weights_pd_t : public primitive_desc_t {
230 typedef deconvolution_bwd_weights_pd_t base_class;
231 typedef deconvolution_fwd_pd_t hint_class;
232 static constexpr auto base_pkind = primitive_kind::deconvolution;
234 deconvolution_bwd_weights_pd_t(mkldnn::impl::engine_t *engine,
235 const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
236 const deconvolution_fwd_pd_t *hint_fwd_pd)
237 : primitive_desc_t(engine, attr, base_pkind)
239 , hint_fwd_pd_(hint_fwd_pd) {}
240 virtual ~deconvolution_bwd_weights_pd_t() {}
242 const deconvolution_desc_t *desc() const { return &desc_; }
243 virtual const op_desc_t *op_desc() const override {
244 return reinterpret_cast<const op_desc_t *>(this->desc());
246 virtual void init_info() override { init_info_conv(this, this->info_); }
247 virtual const memory_pd_t *input_pd(int index = 0) const override {
249 case 0: return src_pd();
250 case 1: return diff_dst_pd();
251 default: return nullptr;
254 virtual const memory_pd_t *output_pd(int index = 0) const override {
256 case 0: return diff_weights_pd(0);
257 case 1: return with_bias() ? diff_weights_pd(1) : nullptr;
258 default: return nullptr;
262 virtual int n_inputs() const override { return 2; }
263 virtual int n_outputs() const override { return 1 + with_bias(); }
265 virtual status_t query(query_t what, int idx, void *result) const override {
267 case query::deconvolution_d:
268 *(const deconvolution_desc_t **)result = desc();
270 default: return primitive_desc_t::query(what, idx, result);
272 return status::success;
275 /* common conv aux functions */
276 inline int MB() const { return desc_.src_desc.dims[0]; }
278 inline int IC() const { return desc_.src_desc.dims[1]; }
279 inline int OC() const { return desc_.diff_dst_desc.dims[1]; }
281 { return with_groups() ? desc_.diff_weights_desc.dims[0] : 1; }
283 inline int ID() const { return (ndims() == 5)
284 ? desc_.src_desc.dims[2] : 1; }
285 inline int IH() const { return desc_.src_desc.dims[ndims()-2]; }
286 inline int IW() const { return desc_.src_desc.dims[ndims()-1]; }
287 inline int OD() const { return (ndims() == 5)
288 ? desc_.diff_dst_desc.dims[2] : 1; }
289 inline int OH() const { return desc_.diff_dst_desc.dims[ndims()-2]; }
290 inline int OW() const { return desc_.diff_dst_desc.dims[ndims()-1]; }
291 inline int KD() const { return (ndims() == 5)
292 ? desc_.diff_weights_desc.dims[2 + with_groups()] : 1; }
293 inline int KH() const
294 { return desc_.diff_weights_desc.dims[ndims() - (2 - with_groups())]; }
295 inline int KW() const
296 { return desc_.diff_weights_desc.dims[ndims() - (1 - with_groups())]; }
298 inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
299 inline int KSH() const { return desc_.strides[ndims()-4]; }
300 inline int KSW() const { return desc_.strides[ndims()-3]; }
302 inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
303 inline int KDH() const { return desc_.dilates[ndims()-4]; }
304 inline int KDW() const { return desc_.dilates[ndims()-3]; }
306 inline int padFront() const
307 { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
308 inline int padBack() const
309 { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
310 inline int padT() const { return desc_.padding[0][ndims()-4]; }
311 inline int padB() const { return desc_.padding[1][ndims()-4]; }
312 inline int padL() const { return desc_.padding[0][ndims()-3]; }
313 inline int padR() const { return desc_.padding[1][ndims()-3]; }
315 inline bool with_bias() const {
316 return !memory_desc_wrapper(desc_.diff_bias_desc).is_zero();
318 inline bool with_groups() const {
319 return desc_.diff_weights_desc.ndims == desc_.diff_dst_desc.ndims + 1;
321 inline int ndims() const { return desc_.src_desc.ndims; }
324 deconvolution_desc_t desc_;
325 const deconvolution_fwd_pd_t *hint_fwd_pd_;
326 virtual status_t init() = 0;
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s