1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CONVOLUTION_PD_HPP
18 #define CONVOLUTION_PD_HPP
22 #include "c_types_map.hpp"
23 #include "primitive_desc.hpp"
24 #include "memory_pd.hpp"
30 status_t conv_desc_init(convolution_desc_t *conv_desc,
31 prop_kind_t prop_kind, alg_kind_t alg_kind,
32 const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
33 const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
34 const dims_t strides, const dims_t dilates,
35 const dims_t padding_l, const dims_t padding_r,
36 padding_kind_t padding_kind);
38 template <bool with_relu>
39 struct _convolution_fwd_pd_t: public primitive_desc_t {
40 typedef _convolution_fwd_pd_t base_class;
41 typedef _convolution_fwd_pd_t hint_class;
42 typedef typename utils::conditional<with_relu,
43 convolution_relu_desc_t, convolution_desc_t>::type base_desc_t;
44 static constexpr auto base_pkind =
45 utils::conditional_v<with_relu, primitive_kind_t,
46 primitive_kind::convolution_relu, primitive_kind::convolution>::value;
48 _convolution_fwd_pd_t(mkldnn::impl::engine_t *engine,
49 const base_desc_t *adesc, const primitive_attr_t *attr,
50 const _convolution_fwd_pd_t *hint_fwd_pd)
51 : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
52 , hint_fwd_pd_(hint_fwd_pd) {}
53 virtual ~_convolution_fwd_pd_t() {}
55 const base_desc_t *desc() const { return &desc_; }
56 inline const convolution_desc_t *cdesc() const { return &cdesc_(); }
57 virtual const op_desc_t *op_desc() const override
58 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
59 virtual void init_info() override { init_info_conv(this, this->info_); }
61 virtual const memory_pd_t *input_pd(int index = 0) const override {
63 case 0: return src_pd();
64 case 1: case 2: return weights_pd(index - 1);
65 default: return nullptr;
68 virtual const memory_pd_t *output_pd(int index = 0) const override
69 { return index == 0 ? dst_pd() : nullptr; }
71 virtual int n_inputs() const override { return 2 + with_bias(); }
72 virtual int n_outputs() const override { return 1; }
74 virtual status_t query(query_t what, int idx, void *result) const override
77 case pkind_traits<base_pkind>::query_d:
78 *(const base_desc_t**)result = desc(); break;
79 default: return primitive_desc_t::query(what, idx, result);
81 return status::success;
84 /* common conv aux functions */
86 inline int MB() const { return input_pd()->desc()->dims[0]; }
88 inline int IC() const { return input_pd()->desc()->dims[1]; }
89 inline int OC() const { return output_pd()->desc()->dims[1]; }
91 { return with_groups() ? cdesc_().weights_desc.dims[0] : 1; }
93 inline int ID() const { return (ndims() == 5)
94 ? input_pd()->desc()->dims[2] : 1; }
95 inline int IH() const { return input_pd()->desc()->dims[ndims()-2]; }
96 inline int IW() const { return input_pd()->desc()->dims[ndims()-1]; }
97 inline int OD() const { return (ndims() == 5)
98 ? output_pd()->desc()->dims[2] : 1; }
99 inline int OH() const { return output_pd()->desc()->dims[ndims()-2]; }
100 inline int OW() const { return output_pd()->desc()->dims[ndims()-1]; }
101 inline int KD() const { return (ndims() == 5)
102 ? cdesc_().weights_desc.dims[2 + with_groups()] : 1; }
103 inline int KH() const
104 { return cdesc_().weights_desc.dims[ndims() - (2 - with_groups())]; }
105 inline int KW() const
106 { return cdesc_().weights_desc.dims[ndims() - (1 - with_groups())]; }
108 inline int KSD() const { return (ndims() == 5) ? cdesc_().strides[0] : 1; }
109 inline int KSH() const { return cdesc_().strides[ndims()-4]; }
110 inline int KSW() const { return cdesc_().strides[ndims()-3]; }
112 inline int KDD() const { return (ndims() == 5) ? cdesc_().dilates[0] : 0; }
113 inline int KDH() const { return cdesc_().dilates[ndims()-4]; }
114 inline int KDW() const { return cdesc_().dilates[ndims()-3]; }
116 inline int padFront() const
117 { return (ndims() == 5) ? cdesc_().padding[0][0] : 0; }
118 inline int padBack() const
119 { return (ndims() == 5) ? cdesc_().padding[1][0] : 0; }
120 inline int padT() const { return cdesc_().padding[0][ndims()-4]; }
121 inline int padB() const { return cdesc_().padding[1][ndims()-4]; }
122 inline int padL() const { return cdesc_().padding[0][ndims()-3]; }
123 inline int padR() const { return cdesc_().padding[1][ndims()-3]; }
125 inline float negative_slope() const;
127 inline bool with_bias() const
128 { return !memory_desc_wrapper(cdesc_().bias_desc).is_zero(); }
129 inline bool with_groups() const
130 { return cdesc_().weights_desc.ndims == cdesc_().src_desc.ndims + 1; }
132 inline int ndims() const { return cdesc_().src_desc.ndims; }
// Forward pd hint captured by the constructor (not owned).
// NOTE(review): may be nullptr depending on the caller -- confirm.
136 const _convolution_fwd_pd_t *hint_fwd_pd_;
// Plain convolution descriptor inside desc_; defined out of class
// (generic template and a convolution_relu_fwd_pd_t specialization).
138 inline const convolution_desc_t &cdesc_() const;
// Implementation-specific initialization supplied by derived classes.
140 virtual status_t init() = 0;
// Forward convolution pd without a fused ReLU (descriptor: convolution_desc_t).
143 using convolution_fwd_pd_t = mkldnn::impl::_convolution_fwd_pd_t<false>;
// Forward convolution pd with a fused ReLU (descriptor: convolution_relu_desc_t).
144 using convolution_relu_fwd_pd_t = mkldnn::impl::_convolution_fwd_pd_t<true>;
146 template<> inline float convolution_fwd_pd_t::negative_slope() const
148 template<> inline float convolution_relu_fwd_pd_t::negative_slope() const
149 { return desc()->negative_slope; }
151 template<bool with_relu> inline const
152 convolution_desc_t &_convolution_fwd_pd_t<with_relu>::cdesc_() const
155 inline const convolution_desc_t &convolution_relu_fwd_pd_t::cdesc_() const
156 { return desc_.convolution_desc; }
158 struct convolution_bwd_data_pd_t: public primitive_desc_t {
159 typedef convolution_bwd_data_pd_t base_class;
160 typedef convolution_fwd_pd_t hint_class;
161 static constexpr auto base_pkind = primitive_kind::convolution;
163 convolution_bwd_data_pd_t(mkldnn::impl::engine_t *engine,
164 const convolution_desc_t *adesc,
165 const primitive_attr_t *attr,
166 const convolution_fwd_pd_t *hint_fwd_pd)
167 : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
168 , hint_fwd_pd_(hint_fwd_pd) {}
169 virtual ~convolution_bwd_data_pd_t() {}
171 const convolution_desc_t *desc() const { return &desc_; }
172 const convolution_desc_t *cdesc() const { return desc(); }
173 virtual const op_desc_t *op_desc() const override
174 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
175 virtual void init_info() override { init_info_conv(this, this->info_); }
177 virtual const memory_pd_t *input_pd(int index = 0) const override {
179 case 0: return diff_dst_pd();
180 case 1: return weights_pd(0);
181 default: return nullptr;
184 virtual const memory_pd_t *output_pd(int index = 0) const override
185 { return index == 0 ? diff_src_pd() : nullptr; }
187 virtual int n_inputs() const override { return 2; }
188 virtual int n_outputs() const override { return 1; }
190 virtual status_t query(query_t what, int idx, void *result) const override
193 case query::convolution_d:
194 *(const convolution_desc_t**)result = desc(); break;
195 default: return primitive_desc_t::query(what, idx, result);
197 return status::success;
200 /* common conv aux functions */
202 inline int MB() const { return output_pd()->desc()->dims[0]; }
203 inline int IC() const { return output_pd()->desc()->dims[1]; }
204 inline int OC() const { return input_pd()->desc()->dims[1]; }
206 { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
208 inline int ID() const { return (ndims() == 5)
209 ? output_pd()->desc()->dims[2] : 1; }
210 inline int IH() const { return output_pd()->desc()->dims[ndims()-2]; }
211 inline int IW() const { return output_pd()->desc()->dims[ndims()-1]; }
212 inline int OD() const { return (ndims() == 5)
213 ? input_pd()->desc()->dims[2] : 1; }
214 inline int OH() const { return input_pd()->desc()->dims[ndims()-2]; }
215 inline int OW() const { return input_pd()->desc()->dims[ndims()-1]; }
216 inline int KD() const { return (ndims() == 5)
217 ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
218 inline int KH() const
219 { return desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
220 inline int KW() const
221 { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
223 inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
224 inline int KSH() const { return desc_.strides[ndims()-4]; }
225 inline int KSW() const { return desc_.strides[ndims()-3]; }
227 inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
228 inline int KDH() const { return desc_.dilates[ndims()-4]; }
229 inline int KDW() const { return desc_.dilates[ndims()-3]; }
231 inline int padFront() const
232 { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
233 inline int padBack() const
234 { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
235 inline int padT() const { return desc_.padding[0][ndims()-4]; }
236 inline int padB() const { return desc_.padding[1][ndims()-4]; }
237 inline int padL() const { return desc_.padding[0][ndims()-3]; }
238 inline int padR() const { return desc_.padding[1][ndims()-3]; }
240 inline bool with_bias() const
241 { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
242 inline bool with_groups() const
243 { return desc_.weights_desc.ndims == desc_.diff_src_desc.ndims + 1; }
245 inline int ndims() const { return desc_.diff_src_desc.ndims; }
// Copy of the convolution descriptor passed to the constructor.
248 convolution_desc_t desc_;
// Forward pd hint captured by the constructor (not owned).
// NOTE(review): may be nullptr depending on the caller -- confirm.
249 const convolution_fwd_pd_t *hint_fwd_pd_;
// Implementation-specific initialization supplied by derived classes.
251 virtual status_t init() = 0;
254 struct convolution_bwd_weights_pd_t: public primitive_desc_t {
255 typedef convolution_bwd_weights_pd_t base_class;
256 typedef convolution_fwd_pd_t hint_class;
257 static constexpr auto base_pkind = primitive_kind::convolution;
259 convolution_bwd_weights_pd_t(mkldnn::impl::engine_t *engine,
260 const convolution_desc_t *adesc,
261 const primitive_attr_t *attr,
262 const convolution_fwd_pd_t *hint_fwd_pd)
263 : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
264 , hint_fwd_pd_(hint_fwd_pd) {}
265 virtual ~convolution_bwd_weights_pd_t() {}
267 const convolution_desc_t *desc() const { return &desc_; }
268 const convolution_desc_t *cdesc() const { return desc(); }
269 virtual const op_desc_t *op_desc() const override
270 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
271 virtual void init_info() override { init_info_conv(this, this->info_); }
273 virtual const memory_pd_t *input_pd(int index = 0) const override {
275 case 0: return src_pd();
276 case 1: return diff_dst_pd();
277 default: return nullptr;
280 virtual const memory_pd_t *output_pd(int index = 0) const override {
282 case 0: return diff_weights_pd(0);
283 case 1: return with_bias() ? diff_weights_pd(1) : nullptr;
284 default: return nullptr;
288 virtual int n_inputs() const override { return 2; }
289 virtual int n_outputs() const override { return 1 + with_bias(); }
291 virtual status_t query(query_t what, int idx, void *result) const override
294 case query::convolution_d:
295 *(const convolution_desc_t**)result = desc(); break;
296 default: return primitive_desc_t::query(what, idx, result);
298 return status::success;
301 /* common conv aux functions */
303 inline int MB() const { return desc_.src_desc.dims[0]; }
305 inline int IC() const { return desc_.src_desc.dims[1]; }
306 inline int OC() const { return desc_.diff_dst_desc.dims[1]; }
308 { return with_groups() ? desc_.diff_weights_desc.dims[0] : 1; }
310 inline int ID() const { return (ndims() == 5)
311 ? desc_.src_desc.dims[2] : 1; }
312 inline int IH() const { return desc_.src_desc.dims[ndims()-2]; }
313 inline int IW() const { return desc_.src_desc.dims[ndims()-1]; }
314 inline int OD() const { return (ndims() == 5)
315 ? desc_.diff_dst_desc.dims[2] : 1; }
316 inline int OH() const { return desc_.diff_dst_desc.dims[ndims()-2]; }
317 inline int OW() const { return desc_.diff_dst_desc.dims[ndims()-1]; }
318 inline int KD() const { return (ndims() == 5)
319 ? desc_.diff_weights_desc.dims[2 + with_groups()] : 1; }
320 inline int KH() const
321 { return desc_.diff_weights_desc.dims[ndims() - (2 - with_groups())]; }
322 inline int KW() const
323 { return desc_.diff_weights_desc.dims[ndims() - (1 - with_groups())]; }
325 inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
326 inline int KSH() const { return desc_.strides[ndims()-4]; }
327 inline int KSW() const { return desc_.strides[ndims()-3]; }
329 inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
330 inline int KDH() const { return desc_.dilates[ndims()-4]; }
331 inline int KDW() const { return desc_.dilates[ndims()-3]; }
333 inline int padFront() const
334 { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
335 inline int padBack() const
336 { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
337 inline int padT() const { return desc_.padding[0][ndims()-4]; }
338 inline int padB() const { return desc_.padding[1][ndims()-4]; }
339 inline int padL() const { return desc_.padding[0][ndims()-3]; }
340 inline int padR() const { return desc_.padding[1][ndims()-3]; }
342 inline bool with_bias() const
343 { return !memory_desc_wrapper(desc_.diff_bias_desc).is_zero(); }
344 inline bool with_groups() const
345 { return desc_.diff_weights_desc.ndims == desc_.diff_dst_desc.ndims + 1; }
347 inline int ndims() const { return desc_.src_desc.ndims; }
// Copy of the convolution descriptor passed to the constructor.
350 convolution_desc_t desc_;
// Forward pd hint captured by the constructor (not owned).
// NOTE(review): may be nullptr depending on the caller -- confirm.
351 const convolution_fwd_pd_t *hint_fwd_pd_;
// Implementation-specific initialization supplied by derived classes.
353 virtual status_t init() = 0;
361 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s