1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CONVOLUTION_PD_HPP
18 #define CONVOLUTION_PD_HPP
22 #include "c_types_map.hpp"
23 #include "primitive_desc.hpp"
24 #include "memory_pd.hpp"
30 status_t conv_desc_init(convolution_desc_t *conv_desc,
31 prop_kind_t prop_kind, alg_kind_t alg_kind,
32 const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
33 const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
34 const dims_t strides, const dims_t dilates,
35 const dims_t padding_l, const dims_t padding_r,
36 padding_kind_t padding_kind);
38 memory_desc_t *conv_prop_agnostic_src_d(convolution_desc_t *desc);
39 memory_desc_t *conv_prop_agnostic_wei_d(convolution_desc_t *desc);
40 memory_desc_t *conv_prop_agnostic_bia_d(convolution_desc_t *desc);
41 memory_desc_t *conv_prop_agnostic_dst_d(convolution_desc_t *desc);
42 const memory_desc_t *conv_prop_agnostic_src_d(const convolution_desc_t *desc);
43 const memory_desc_t *conv_prop_agnostic_wei_d(const convolution_desc_t *desc);
44 const memory_desc_t *conv_prop_agnostic_bia_d(const convolution_desc_t *desc);
45 const memory_desc_t *conv_prop_agnostic_dst_d(const convolution_desc_t *desc);
/* Base primitive descriptor for forward convolution.
 * Keeps a by-value copy of the convolution descriptor (desc_) and the hint
 * pointer given to the constructor. init() is pure virtual here.
 * src_pd()/weights_pd()/dst_pd() are called but not defined in this struct
 * (presumably virtuals implemented by derived classes; confirm).
 * NOTE(review): this listing appears to have dropped lines of the original
 * header. Missing are the `switch` headers in input_pd()/query(), closing
 * braces, the name line of the groups accessor, the first operand of
 * has_zero_dim_memory(), and the struct's closing `};`. The comments below
 * describe only the visible lines; restore the rest from upstream. */
47 struct convolution_fwd_pd_t: public primitive_desc_t {
48 typedef convolution_fwd_pd_t base_class;
49 typedef convolution_fwd_pd_t hint_class;
50 static constexpr auto base_pkind = primitive_kind::convolution;
/* Copies *adesc into desc_ and stores hint_fwd_pd as given (the destructor
 * below does not release it). */
52 convolution_fwd_pd_t(mkldnn::impl::engine_t *engine,
53 const convolution_desc_t *adesc, const primitive_attr_t *attr,
54 const convolution_fwd_pd_t *hint_fwd_pd)
55 : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
56 , hint_fwd_pd_(hint_fwd_pd) {}
57 virtual ~convolution_fwd_pd_t() {}
59 const convolution_desc_t *desc() const { return &desc_; }
/* Exposes desc_ through the generic op_desc_t interface. */
60 virtual const op_desc_t *op_desc() const override
61 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
62 virtual void init_info() override { init_info_conv(this, this->info_); }
/* Inputs: 0 -> src, 1 -> weights_pd(0), 2 -> weights_pd(1), any other index
 * -> nullptr. Index 2 is presumably the bias; n_inputs() counts it only
 * when with_bias(). */
64 virtual const memory_pd_t *input_pd(int index = 0) const override {
66 case 0: return src_pd();
67 case 1: case 2: return weights_pd(index - 1);
68 default: return nullptr;
/* Single output: index 0 -> dst, anything else -> nullptr. */
71 virtual const memory_pd_t *output_pd(int index = 0) const override
72 { return index == 0 ? dst_pd() : nullptr; }
74 virtual int n_inputs() const override { return 2 + with_bias(); }
75 virtual int n_outputs() const override { return 1; }
/* For the convolution-descriptor query, writes a pointer to desc_ into
 * *result. Every other query is delegated to primitive_desc_t::query. */
77 virtual status_t query(query_t what, int idx, void *result) const override
80 case pkind_traits<base_pkind>::query_d:
81 *(const convolution_desc_t**)result = desc(); break;
82 default: return primitive_desc_t::query(what, idx, result);
84 return status::success;
87 /* common conv aux functions */
/* Shape helpers. src/dst dims are read as [MB, C, (D,) (H,) W]:
 * ndims() == 5 adds the depth dim, ndims() == 3 drops the height dim, and
 * an absent spatial extent reports 1. Kernel extents are read from the
 * weights dims starting at index 2, shifted by one when with_groups()
 * (leading groups dim). */
89 inline int MB() const { return input_pd()->desc()->dims[0]; }
91 inline int IC() const { return input_pd()->desc()->dims[1]; }
92 inline int OC() const { return output_pd()->desc()->dims[1]; }
/* Number of groups: weights dim 0 when with_groups(), otherwise 1.
 * NOTE(review): the name line of this accessor is missing from the listing
 * (presumably `inline int G() const`); confirm against upstream. */
94 { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
96 inline int ID() const { return (ndims() == 5) ? input_pd()->desc()->dims[2] : 1; }
97 inline int IH() const { return (ndims() == 3) ? 1 : input_pd()->desc()->dims[ndims()-2]; }
98 inline int IW() const { return input_pd()->desc()->dims[ndims()-1]; }
99 inline int OD() const { return (ndims() == 5) ? output_pd()->desc()->dims[2] : 1; }
100 inline int OH() const { return (ndims() == 3) ? 1 : output_pd()->desc()->dims[ndims()-2]; }
101 inline int OW() const { return output_pd()->desc()->dims[ndims()-1]; }
102 inline int KD() const { return (ndims() == 5)
103 ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
104 inline int KH() const
105 { return (ndims() == 3)
106 ? 1 : desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
107 inline int KW() const
108 { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
/* Strides, dilations and paddings are indexed per spatial dim, outermost
 * first (index ndims()-3 is always the width). An absent dim reports
 * stride 1, dilation 0 and padding 0.
 * NOTE(review): a dilation of 0 presumably means "no dilation" in this
 * library's convention; confirm. padding[0] holds the begin
 * (front/top/left) side and padding[1] the end (back/bottom/right) side. */
110 inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
111 inline int KSH() const { return (ndims() == 3)
112 ? 1 : desc_.strides[ndims()-4]; }
113 inline int KSW() const { return desc_.strides[ndims()-3]; }
115 inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
116 inline int KDH() const { return (ndims() == 3)
117 ? 0 : desc_.dilates[ndims()-4]; }
118 inline int KDW() const { return desc_.dilates[ndims()-3]; }
120 inline int padFront() const
121 { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
122 inline int padBack() const
123 { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
124 inline int padT() const { return (ndims() == 3)
125 ? 0 : desc_.padding[0][ndims()-4]; }
126 inline int padB() const { return (ndims() == 3)
127 ? 0 : desc_.padding[1][ndims()-4]; }
128 inline int padL() const { return desc_.padding[0][ndims()-3]; }
129 inline int padR() const { return desc_.padding[1][ndims()-3]; }
/* Bias is present unless memory_desc_wrapper reports bias_desc as zero.
 * Groups are present when weights have one more dim than src. */
131 inline bool with_bias() const
132 { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
133 inline bool with_groups() const
134 { return desc_.weights_desc.ndims == desc_.src_desc.ndims + 1; }
136 inline int ndims() const { return desc_.src_desc.ndims; }
/* Overwrites desc_.alg_kind. alg_kind::undef is rejected with
 * invalid_arguments and leaves desc_ unchanged. */
138 virtual status_t set_alg_kind(alg_kind_t alg) {
139 if (alg == alg_kind::undef) return status::invalid_arguments;
140 desc_.alg_kind = alg;
141 return status::success;
/* True if has_zero_dim() reports a zero dim for src or dst.
 * NOTE(review): the line before the first `||` is missing from this
 * listing (presumably `return false`); confirm against upstream. */
144 bool has_zero_dim_memory() const {
146 || memory_desc_wrapper(desc_.src_desc).has_zero_dim()
147 || memory_desc_wrapper(desc_.dst_desc).has_zero_dim();
152 convolution_desc_t desc_;
153 const convolution_fwd_pd_t *hint_fwd_pd_;
155 virtual status_t init() = 0;
/* Base primitive descriptor for convolution backward-data: computes
 * diff_src from diff_dst and weights. Keeps a by-value copy of the
 * convolution descriptor (desc_) and the forward hint pointer.
 * init() is pure virtual here.
 * NOTE(review): this listing appears to have dropped lines of the original
 * header. Missing are the `switch` headers in input_pd()/query(), closing
 * braces, the name line of the groups accessor, the first operand of
 * has_zero_dim_memory(), and the struct's closing `};`. The comments below
 * describe only the visible lines. */
158 struct convolution_bwd_data_pd_t: public primitive_desc_t {
159 typedef convolution_bwd_data_pd_t base_class;
160 typedef convolution_fwd_pd_t hint_class;
161 static constexpr auto base_pkind = primitive_kind::convolution;
/* Copies *adesc into desc_ and stores hint_fwd_pd as given (the destructor
 * below does not release it). */
163 convolution_bwd_data_pd_t(mkldnn::impl::engine_t *engine,
164 const convolution_desc_t *adesc,
165 const primitive_attr_t *attr,
166 const convolution_fwd_pd_t *hint_fwd_pd)
167 : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
168 , hint_fwd_pd_(hint_fwd_pd) {}
169 virtual ~convolution_bwd_data_pd_t() {}
171 const convolution_desc_t *desc() const { return &desc_; }
/* Exposes desc_ through the generic op_desc_t interface. */
172 virtual const op_desc_t *op_desc() const override
173 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
174 virtual void init_info() override { init_info_conv(this, this->info_); }
/* Inputs: 0 -> diff_dst, 1 -> weights_pd(0), any other index -> nullptr.
 * NOTE(review): n_inputs() below reports 2 + with_bias(), but index 2 falls
 * through to `default` and yields nullptr. A bias input counted by
 * n_inputs() therefore cannot be retrieved here. If with_bias() can be true
 * for this pd (e.g. an implementation whose support_bias() returns true),
 * this likely should be `case 1: case 2: return weights_pd(index - 1);`, as
 * in convolution_fwd_pd_t. Confirm and fix. */
176 virtual const memory_pd_t *input_pd(int index = 0) const override {
178 case 0: return diff_dst_pd();
179 case 1: return weights_pd(0);
180 default: return nullptr;
/* Single output: index 0 -> diff_src, anything else -> nullptr. */
183 virtual const memory_pd_t *output_pd(int index = 0) const override
184 { return index == 0 ? diff_src_pd() : nullptr; }
186 virtual int n_inputs() const override { return 2 + with_bias(); }
187 virtual int n_outputs() const override { return 1; }
/* For query::convolution_d, writes a pointer to desc_ into *result. Every
 * other query is delegated to primitive_desc_t::query. */
189 virtual status_t query(query_t what, int idx, void *result) const override
192 case query::convolution_d:
193 *(const convolution_desc_t**)result = desc(); break;
194 default: return primitive_desc_t::query(what, idx, result);
196 return status::success;
199 /* common conv aux functions */
/* Shape helpers. The data roles are swapped relative to forward:
 * MB/IC/ID/IH/IW read output_pd() (diff_src), and OC/OD/OH/OW read
 * input_pd() (index 0, diff_dst). Dims are [MB, C, (D,) (H,) W]:
 * ndims() == 5 adds depth, ndims() == 3 drops height, and an absent
 * extent reports 1. Kernel extents come from the weights dims starting at
 * index 2, shifted by one when with_groups(). */
201 inline int MB() const { return output_pd()->desc()->dims[0]; }
202 inline int IC() const { return output_pd()->desc()->dims[1]; }
203 inline int OC() const { return input_pd()->desc()->dims[1]; }
/* Number of groups: weights dim 0 when with_groups(), otherwise 1.
 * NOTE(review): the name line of this accessor is missing from the listing
 * (presumably `inline int G() const`); confirm against upstream. */
205 { return with_groups() ? desc_.weights_desc.dims[0] : 1; }
207 inline int ID() const { return (ndims() == 5) ? output_pd()->desc()->dims[2] : 1; }
208 inline int IH() const { return (ndims() == 3) ? 1 : output_pd()->desc()->dims[ndims()-2]; }
209 inline int IW() const { return output_pd()->desc()->dims[ndims()-1]; }
210 inline int OD() const { return (ndims() == 5) ? input_pd()->desc()->dims[2] : 1; }
211 inline int OH() const { return (ndims() == 3) ? 1 : input_pd()->desc()->dims[ndims()-2]; }
212 inline int OW() const { return input_pd()->desc()->dims[ndims()-1]; }
213 inline int KD() const { return (ndims() == 5)
214 ? desc_.weights_desc.dims[2 + with_groups()] : 1; }
215 inline int KH() const
216 { return (ndims() == 3)
217 ? 1 : desc_.weights_desc.dims[ndims() - (2 - with_groups())]; }
218 inline int KW() const
219 { return desc_.weights_desc.dims[ndims() - (1 - with_groups())]; }
/* Strides, dilations and paddings are indexed per spatial dim, outermost
 * first (index ndims()-3 is always the width). An absent dim reports
 * stride 1, dilation 0 and padding 0.
 * NOTE(review): a dilation of 0 presumably means "no dilation"; confirm.
 * padding[0] holds the begin side and padding[1] the end side. */
221 inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
222 inline int KSH() const { return (ndims() == 3)
223 ? 1 : desc_.strides[ndims()-4]; }
224 inline int KSW() const { return desc_.strides[ndims()-3]; }
226 inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
227 inline int KDH() const { return (ndims() == 3)
228 ? 0 : desc_.dilates[ndims()-4]; }
229 inline int KDW() const { return desc_.dilates[ndims()-3]; }
231 inline int padFront() const
232 { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
233 inline int padBack() const
234 { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
235 inline int padT() const { return (ndims() == 3)
236 ? 0 : desc_.padding[0][ndims()-4]; }
237 inline int padB() const { return (ndims() == 3)
238 ? 0 : desc_.padding[1][ndims()-4]; }
239 inline int padL() const { return desc_.padding[0][ndims()-3]; }
240 inline int padR() const { return desc_.padding[1][ndims()-3]; }
/* Bias is present unless memory_desc_wrapper reports bias_desc as zero.
 * Groups are present when weights have one more dim than diff_src.
 * ndims() is taken from diff_src_desc. */
242 inline bool with_bias() const
243 { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
244 inline bool with_groups() const
245 { return desc_.weights_desc.ndims == desc_.diff_src_desc.ndims + 1; }
247 inline int ndims() const { return desc_.diff_src_desc.ndims; }
/* Hook for implementations: this base reports no bias support. */
248 virtual bool support_bias() const { return false; }
/* Overwrites desc_.alg_kind. alg_kind::undef is rejected with
 * invalid_arguments and leaves desc_ unchanged. */
250 virtual status_t set_alg_kind(alg_kind_t alg) {
251 if (alg == alg_kind::undef) return status::invalid_arguments;
252 desc_.alg_kind = alg;
253 return status::success;
/* True if has_zero_dim() reports a zero dim for diff_src or diff_dst.
 * NOTE(review): the line before the first `||` is missing from this
 * listing (presumably `return false`); confirm against upstream. */
256 bool has_zero_dim_memory() const {
258 || memory_desc_wrapper(desc_.diff_src_desc).has_zero_dim()
259 || memory_desc_wrapper(desc_.diff_dst_desc).has_zero_dim();
263 convolution_desc_t desc_;
264 const convolution_fwd_pd_t *hint_fwd_pd_;
266 virtual status_t init() = 0;
/* Base primitive descriptor for convolution backward-weights: computes
 * diff_weights (and diff_bias when present) from src and diff_dst. Keeps a
 * by-value copy of the convolution descriptor (desc_) and the forward hint
 * pointer. init() is pure virtual here. Unlike the forward and
 * backward-data pds, the shape helpers read desc_ directly rather than
 * going through input_pd()/output_pd().
 * NOTE(review): this listing appears to have dropped lines of the original
 * header. Missing are the `switch` headers, closing braces, the name line
 * of the groups accessor, the first operand of has_zero_dim_memory(), and
 * the struct's closing `};`. The comments below describe only the visible
 * lines. */
269 struct convolution_bwd_weights_pd_t: public primitive_desc_t {
270 typedef convolution_bwd_weights_pd_t base_class;
271 typedef convolution_fwd_pd_t hint_class;
272 static constexpr auto base_pkind = primitive_kind::convolution;
/* Copies *adesc into desc_ and stores hint_fwd_pd as given (the destructor
 * below does not release it). */
274 convolution_bwd_weights_pd_t(mkldnn::impl::engine_t *engine,
275 const convolution_desc_t *adesc,
276 const primitive_attr_t *attr,
277 const convolution_fwd_pd_t *hint_fwd_pd)
278 : primitive_desc_t(engine, attr, base_pkind), desc_(*adesc)
279 , hint_fwd_pd_(hint_fwd_pd) {}
280 virtual ~convolution_bwd_weights_pd_t() {}
282 const convolution_desc_t *desc() const { return &desc_; }
/* Exposes desc_ through the generic op_desc_t interface. */
283 virtual const op_desc_t *op_desc() const override
284 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
285 virtual void init_info() override { init_info_conv(this, this->info_); }
/* Inputs: 0 -> src, 1 -> diff_dst, any other index -> nullptr. */
287 virtual const memory_pd_t *input_pd(int index = 0) const override {
289 case 0: return src_pd();
290 case 1: return diff_dst_pd();
291 default: return nullptr;
/* Outputs: 0 -> diff_weights_pd(0); 1 -> diff_weights_pd(1) (presumably
 * diff_bias) only when with_bias(), else nullptr; other indices -> nullptr.
 * This matches n_outputs() == 1 + with_bias(). */
294 virtual const memory_pd_t *output_pd(int index = 0) const override {
296 case 0: return diff_weights_pd(0);
297 case 1: return with_bias() ? diff_weights_pd(1) : nullptr;
298 default: return nullptr;
302 virtual int n_inputs() const override { return 2; }
303 virtual int n_outputs() const override { return 1 + with_bias(); }
/* For query::convolution_d, writes a pointer to desc_ into *result. Every
 * other query is delegated to primitive_desc_t::query. */
305 virtual status_t query(query_t what, int idx, void *result) const override
308 case query::convolution_d:
309 *(const convolution_desc_t**)result = desc(); break;
310 default: return primitive_desc_t::query(what, idx, result);
312 return status::success;
315 /* common conv aux functions */
/* Shape helpers. The input side is read from desc_.src_desc, the output
 * side from desc_.diff_dst_desc, and the kernel from desc_.diff_weights_desc.
 * Dims are [MB, C, (D,) (H,) W]: ndims() == 5 adds depth, ndims() == 3
 * drops height, and an absent extent reports 1. Kernel extents start at
 * diff_weights index 2, shifted by one when with_groups(). */
317 inline int MB() const { return desc_.src_desc.dims[0]; }
319 inline int IC() const { return desc_.src_desc.dims[1]; }
320 inline int OC() const { return desc_.diff_dst_desc.dims[1]; }
/* Number of groups: diff_weights dim 0 when with_groups(), otherwise 1.
 * NOTE(review): the name line of this accessor is missing from the listing
 * (presumably `inline int G() const`); confirm against upstream. */
322 { return with_groups() ? desc_.diff_weights_desc.dims[0] : 1; }
324 inline int ID() const { return (ndims() == 5)
325 ? desc_.src_desc.dims[2] : 1; }
326 inline int IH() const { return (ndims() == 3)
327 ? 1 : desc_.src_desc.dims[ndims()-2]; }
328 inline int IW() const { return desc_.src_desc.dims[ndims()-1]; }
329 inline int OD() const { return (ndims() == 5)
330 ? desc_.diff_dst_desc.dims[2] : 1; }
331 inline int OH() const { return (ndims() == 3)
332 ? 1 : desc_.diff_dst_desc.dims[ndims()-2]; }
333 inline int OW() const { return desc_.diff_dst_desc.dims[ndims()-1]; }
334 inline int KD() const { return (ndims() == 5)
335 ? desc_.diff_weights_desc.dims[2 + with_groups()] : 1; }
336 inline int KH() const
337 { return (ndims() == 3)
338 ? 1 : desc_.diff_weights_desc.dims[ndims() - (2 - with_groups())]; }
339 inline int KW() const
340 { return desc_.diff_weights_desc.dims[ndims() - (1 - with_groups())]; }
/* Strides, dilations and paddings are indexed per spatial dim, outermost
 * first (index ndims()-3 is always the width). An absent dim reports
 * stride 1, dilation 0 and padding 0.
 * NOTE(review): a dilation of 0 presumably means "no dilation"; confirm.
 * padding[0] holds the begin side and padding[1] the end side. */
342 inline int KSD() const { return (ndims() == 5) ? desc_.strides[0] : 1; }
343 inline int KSH() const { return (ndims() == 3)
344 ? 1 : desc_.strides[ndims()-4]; }
345 inline int KSW() const { return desc_.strides[ndims()-3]; }
347 inline int KDD() const { return (ndims() == 5) ? desc_.dilates[0] : 0; }
348 inline int KDH() const { return (ndims() == 3)
349 ? 0 : desc_.dilates[ndims()-4]; }
350 inline int KDW() const { return desc_.dilates[ndims()-3]; }
352 inline int padFront() const
353 { return (ndims() == 5) ? desc_.padding[0][0] : 0; }
354 inline int padBack() const
355 { return (ndims() == 5) ? desc_.padding[1][0] : 0; }
356 inline int padT() const { return (ndims() == 3)
357 ? 0 : desc_.padding[0][ndims()-4]; }
358 inline int padB() const { return (ndims() == 3)
359 ? 0 : desc_.padding[1][ndims()-4]; }
360 inline int padL() const { return desc_.padding[0][ndims()-3]; }
361 inline int padR() const { return desc_.padding[1][ndims()-3]; }
/* Bias is present unless memory_desc_wrapper reports diff_bias_desc as
 * zero. Groups are present when diff_weights have one more dim than
 * diff_dst.
 * NOTE(review): with_groups() compares against diff_dst_desc.ndims while
 * ndims() reads src_desc.ndims. This is consistent only if src and diff_dst
 * always have the same rank (presumably validated at descriptor init;
 * confirm). */
363 inline bool with_bias() const
364 { return !memory_desc_wrapper(desc_.diff_bias_desc).is_zero(); }
365 inline bool with_groups() const
366 { return desc_.diff_weights_desc.ndims == desc_.diff_dst_desc.ndims + 1; }
368 inline int ndims() const { return desc_.src_desc.ndims; }
/* Overwrites desc_.alg_kind. alg_kind::undef is rejected with
 * invalid_arguments and leaves desc_ unchanged. */
370 virtual status_t set_alg_kind(alg_kind_t alg) {
371 if (alg == alg_kind::undef) return status::invalid_arguments;
372 desc_.alg_kind = alg;
373 return status::success;
/* True if has_zero_dim() reports a zero dim for src or diff_dst.
 * NOTE(review): the line before the first `||` is missing from this
 * listing (presumably `return false`); confirm against upstream. */
376 bool has_zero_dim_memory() const {
378 || memory_desc_wrapper(desc_.src_desc).has_zero_dim()
379 || memory_desc_wrapper(desc_.diff_dst_desc).has_zero_dim();
383 convolution_desc_t desc_;
384 const convolution_fwd_pd_t *hint_fwd_pd_;
386 virtual status_t init() = 0;
394 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s