1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_REF_DECONVOLUTION_HPP
18 #define CPU_REF_DECONVOLUTION_HPP
23 #include "c_types_map.hpp"
24 #include "cpu_deconvolution_pd.hpp"
25 #include "cpu_engine.hpp"
26 #include "type_helpers.hpp"
28 #include "primitive_iterator.hpp"
/* DECLARE_DECONVOLUTION_PD_t(impl_name, ...): boilerplate injected into a
 * deconvolution pd_t. The variadic argument names the primitive type to build.
 * The macro defines:
 *   - clone(): returns a new pd_t copy-constructed from this object;
 *   - create_primitive(): allocates the primitive named by __VA_ARGS__ with
 *     the given inputs/outputs, asks conv_pd_ to create the underlying
 *     convolution primitive and stores it in the new primitive's conv_p_
 *     member. For prop_kind backward and backward_weights the first two
 *     inputs are handed to the convolution in swapped order. When
 *     mkldnn_verbose()->level >= 2 a "mkldnn_verbose,create,..." line with
 *     the get_msec() difference is printed;
 *   - name(): returns impl_name.
 * NOTE(review): in the lines visible here the status returned by
 * conv_pd_->create_primitive() is discarded, conv_primitive has no
 * initializer, and the new primitive is cast and dereferenced before `ret`
 * is inspected -- confirm whether a failed allocation or a failed
 * convolution creation needs to be guarded. */
30 #define DECLARE_DECONVOLUTION_PD_t(impl_name, ...) \
31 virtual pd_t *clone() const override { return new pd_t(*this); } \
32 virtual status_t create_primitive(primitive_t **primitive, \
33 const primitive_at_t *inputs, \
34 const primitive_t **outputs) const override { \
35 double ms = get_msec(); \
36 using namespace prop_kind;\
37 primitive_t::input_vector ins(inputs, inputs + this->n_inputs()); \
38 primitive_t::output_vector outs(outputs, outputs + this->n_outputs()); \
39 auto ret = safe_ptr_assign<primitive_t>(*primitive, \
40 new (__VA_ARGS__)(this, ins, outs)); \
41 primitive_t *conv_primitive; \
42 if (utils::one_of(this->desc()->prop_kind, backward, backward_weights)) {\
43 primitive_at_t conv_inputs[2];\
44 conv_inputs[0] = inputs[1];\
45 conv_inputs[1] = inputs[0];\
46 conv_pd_->create_primitive((&conv_primitive), conv_inputs, outputs);\
48 else conv_pd_->create_primitive((&conv_primitive), inputs, outputs);\
49 ((__VA_ARGS__ *)(*primitive))->conv_p_ = conv_primitive;\
50 ms = get_msec() - ms; \
51 if (mkldnn_verbose()->level >= 2) { \
52 printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
57 virtual const char *name() const override { return impl_name; }
/* Upper-case spelling used by the pd_t definitions in this file; it forwards
 * impl_name and the primitive type unchanged to DECLARE_DECONVOLUTION_PD_t. */
59 #define DECLARE_DECONVOLUTION_PD_T(impl_name, ...) \
60 DECLARE_DECONVOLUTION_PD_t(impl_name, __VA_ARGS__)
/* Derives the blocking description of a weights tensor whose first two
 * non-group dimensions are exchanged relative to oi_md, and stores it in
 * io_md.
 *
 * with_groups  true when the weights carry a leading groups dimension; it is
 *              used as an index offset (0 or 1) so the swapped pair is
 *              [0,1] without groups and [1,2] with groups.
 * oi_md        source memory descriptor (read only).
 * io_md        destination memory descriptor; its blocking fields are
 *              modified in place and its format is set to `blocked`.
 *
 * Returns status::invalid_arguments when the two descriptors differ in
 * ndims, status::success otherwise. */
67 static status_t compute_blocked_format(bool with_groups,
68 const memory_desc_t *oi_md, memory_desc_t *io_md)
70 /* Computes blocking for *i*o* format from *o*i* format */
71 if (oi_md->ndims != io_md->ndims) return status::invalid_arguments;
/* oi_blk is a by-value copy of the source blocking; io_blk aliases the
 * destination blocking, so the swaps below write straight into io_md.
 * NOTE(review): no visible line reads oi_blk -- presumably it is copied into
 * io_blk before the swaps; confirm against the complete file. */
72 blocking_desc_t oi_blk = oi_md->layout_desc.blocking,
73 &io_blk = io_md->layout_desc.blocking;
/* Exchange every per-dimension blocking attribute of the two channel dims:
 * both stride levels, padded dims, padding offsets and block sizes. */
75 nstl::swap(io_blk.strides[0][0+with_groups], io_blk.strides[0][1+with_groups]);
76 nstl::swap(io_blk.strides[1][0+with_groups], io_blk.strides[1][1+with_groups]);
77 nstl::swap(io_blk.padding_dims[0+with_groups], io_blk.padding_dims[1+with_groups]);
78 nstl::swap(io_blk.offset_padding_to_data[0+with_groups],
79 io_blk.offset_padding_to_data[1+with_groups]);
80 nstl::swap(io_blk.block_dims[0+with_groups], io_blk.block_dims[1+with_groups]);
/* The result is tagged with the generic `blocked` format rather than a
 * named one. */
81 io_md->format = memory_format::blocked;
82 return status::success;
/* Fills the convolution descriptor `cd` that corresponds to the
 * deconvolution descriptor `dd`.
 *
 * dd  deconvolution descriptor to translate (read only).
 * cd  convolution descriptor to initialise through conv_desc_init().
 *
 * Returns `unimplemented` for the double-blocked weights formats listed
 * below, any failure propagated through CHECK from compute_blocked_format(),
 * and otherwise the result of conv_desc_init(). */
85 static status_t conv_descr_create(const deconvolution_desc_t *dd,
86 convolution_desc_t *cd)
88 using namespace prop_kind;
89 using namespace memory_format;
/* deconvolution_direct maps to convolution_direct; every other algorithm
 * value is mapped to convolution_winograd. */
90 alg_kind_t alg_kind = ( dd->alg_kind == alg_kind::deconvolution_direct
91 ? alg_kind::convolution_direct : alg_kind::convolution_winograd );
92 prop_kind_t prop_kind;
93 const memory_desc_t *src_md, *dst_md;
94 memory_desc_t c_weights_d, d_weights_d;
/* Forward deconvolution is expressed as convolution backward_data: the
 * convolution "source" is the deconvolution dst and vice versa. */
96 if ( utils::one_of(dd->prop_kind, forward_training, forward_inference) ) {
97 prop_kind = backward_data;
98 src_md = &dd->dst_desc;
99 dst_md = &dd->src_desc;
100 d_weights_d = dd->weights_desc;
/* Backward-by-data deconvolution is expressed as a forward_training
 * convolution from diff_dst to diff_src. */
102 else if( utils::one_of(dd->prop_kind, backward, backward_data) ) {
103 prop_kind = forward_training;
104 src_md = &dd->diff_dst_desc;
105 dst_md = &dd->diff_src_desc;
106 d_weights_d = dd->weights_desc;
/* Remaining case (the branch header is not visible here): prop_kind is kept
 * as is, with diff_dst as the convolution source, src as its destination and
 * diff_weights as the weights. */
109 prop_kind = dd->prop_kind;
110 src_md = &dd->diff_dst_desc;
111 dst_md = &dd->src_desc;
112 d_weights_d = dd->diff_weights_desc;
/* Grouped weights have exactly one more dimension than the activations. */
114 with_groups = d_weights_d.ndims == src_md->ndims + 1;
116 /* create weights desc for convolution */
117 c_weights_d = d_weights_d;
/* Exchange the two channel dims of the weights (skipping the groups dim). */
118 nstl::swap(c_weights_d.dims[with_groups + 0], c_weights_d.dims[with_groups + 1]);
/* A concrete (non-`any`) weights format also needs its blocking swapped;
 * the double-blocked formats listed below are rejected. */
119 if (c_weights_d.format != any)
121 if (utils::one_of(c_weights_d.format, gOIhw8i16o2i, OIhw8i16o2i,
122 gOIhw8o16i2o, OIhw8o16i2o, gOIhw4i16o4i, OIhw4i16o4i))
123 return unimplemented;
124 CHECK( compute_blocked_format(with_groups, &d_weights_d, &c_weights_d));
/* The bias descriptor is forwarded only when dd->prop_kind is backward or
 * backward_data; all other kinds pass nullptr. The nullptr after
 * dd->strides is passed in place of a further optional argument of
 * conv_desc_init (presumably dilation -- confirm against its declaration). */
126 return conv_desc_init(cd, prop_kind, alg_kind,
127 src_md, &(c_weights_d),
128 ( (utils::one_of(dd->prop_kind, backward, backward_data))?
129 (&(dd->bias_desc)): nullptr ), dst_md, dd->strides, nullptr,
130 dd->padding[0], dd->padding[1], dd->padding_kind);
/* Reference forward deconvolution primitive. execute() runs the convolution
 * primitive held in conv_p_ and then, when the descriptor reports a bias,
 * calls one of the compute_fwd_bias_* helpers chosen by the dst format. */
133 struct ref_deconvolution_fwd_t: public cpu_primitive_t {
/* Primitive descriptor: validates the deconvolution descriptor in init() and
 * records the convolution primitive descriptor used to run it in conv_pd_. */
134 struct pd_t: public cpu_deconvolution_fwd_pd_t {
135 pd_t(engine_t *engine,
136 const deconvolution_desc_t *adesc,
137 const primitive_attr_t *attr,
138 const deconvolution_fwd_pd_t *hint_fwd_pd)
139 : cpu_deconvolution_fwd_pd_t(engine, adesc, attr,
/* Injects clone(), create_primitive() and name() (which returns "ref:any"). */
143 DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_fwd_t);
/* Builds the equivalent convolution descriptor with conv_descr_create() and
 * walks the candidates offered by mkldnn_primitive_desc_iterator.
 * Returns the conv_descr_create() status when it fails and `unimplemented`
 * once the loop ends without an earlier return.
 * NOTE(review): the statements that assign conv_pd_ from the iterator and
 * that leave the loop are not visible here -- presumably the first candidate
 * whose weights format is not double-blocked is kept; confirm. */
145 status_t init_convolution(){
146 using namespace memory_format;
147 convolution_desc_t cd;
150 status = conv_descr_create(this->cdesc(), &cd);
151 if (status != status::success) return status;
153 mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
154 &(this->attr_), nullptr);
155 while (++it != it.end()) {
157 const memory_desc_t *md = conv_pd_->weights_pd()->desc();
158 /* double blocked format is not supported */
159 if (!utils::one_of(md->format, gOIhw8i16o2i, OIhw8i16o2i,
160 gOIhw8o16i2o, OIhw8o16i2o, gOIhw4i16o4i, OIhw4i16o4i))
163 return unimplemented;
/* Validates the descriptor and resolves every `any` memory format.
 * The visible conditions require: prop_kind in the one_of list that starts
 * with forward_training, f32 src/weights/dst, a direct or winograd
 * deconvolution algorithm, and default primitive attributes.
 * Returns status::success, a failure propagated through CHECK, or
 * status::unimplemented when the conditions do not hold. */
166 virtual status_t init() override {
167 using namespace prop_kind;
168 using namespace data_type;
169 assert(this->engine()->kind() == engine_kind::cpu);
171 && utils::one_of(this->desc()->prop_kind, forward_training,
173 && utils::everyone_is(data_type::f32,
174 this->desc()->src_desc.data_type,
175 this->desc()->weights_desc.data_type,
176 this->desc()->dst_desc.data_type)
177 && utils::one_of(this->desc()->alg_kind,
178 alg_kind::deconvolution_direct,
179 alg_kind::deconvolution_winograd)
180 && this->attr()->has_default_values();
182 CHECK(init_convolution());
/* Weights left as `any`: derive the layout from the chosen convolution's
 * weights descriptor and rebuild weights_pd_ from the updated descriptor. */
183 if (weights_pd_.desc()->format == memory_format::any)
185 CHECK(compute_blocked_format(with_groups(),
186 conv_pd_->weights_pd()->desc(),
187 &desc_.weights_desc));
188 cpu_memory_pd_t weights(engine_, &desc_.weights_desc);
189 weights_pd_ = weights;
/* src/dst take the convolution's diff_dst/diff_src formats; this mirrors
 * conv_descr_create(), which maps forward kinds to convolution
 * backward_data with src and dst exchanged. */
191 if (src_pd_.desc()->format == memory_format::any)
192 CHECK(src_pd_.set_format(conv_pd_->diff_dst_pd()->desc()->format));
193 if (dst_pd_.desc()->format == memory_format::any)
194 CHECK(dst_pd_.set_format(conv_pd_->diff_src_pd()->desc()->format));
/* Bias left as `any` is given the plain `x` format. */
195 if (bias_pd_.desc()->format == memory_format::any)
196 CHECK(bias_pd_.set_format(memory_format::x));
197 return status::success;
199 else return status::unimplemented;
/* Convolution primitive descriptor used by the DECLARE_DECONVOLUTION_PD_T
 * code to create the convolution primitive.
 * NOTE(review): raw pointer; no destructor or copy constructor is visible in
 * this struct while clone() copy-constructs pd_t -- confirm who owns it. */
201 primitive_desc_t *conv_pd_;
/* Stores a copy of the primitive descriptor in conf_; conv_p_ is not set
 * here but afterwards by pd_t::create_primitive(). */
204 ref_deconvolution_fwd_t(const pd_t *pd, const input_vector &inputs,
205 const output_vector &outputs)
206 : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
208 typedef typename prec_traits<data_type::f32>::type data_t;
/* Runs the convolution primitive for forward_training/forward_inference and
 * then dispatches the bias helper by dst memory format; finally marks the
 * event ready. Any other prop_kind hits the assert. */
210 virtual void execute(event_t *e) {
211 switch (conf_.desc()->prop_kind) {
212 case prop_kind::forward_training:
213 case prop_kind::forward_inference:
214 (conv_p_)->execute(e);
215 if (conf_.with_bias()) {
216 switch (conf_.dst_pd()->desc()->format) {
/* Plain layouts share the ncdhw helper. */
217 case memory_format::nchw :
218 case memory_format::ncdhw :
219 compute_fwd_bias_ncdhw();
/* Channel-blocked layouts use the helper templated on the block size.
 * NOTE(review): no nCdhw8c label is visible next to nChw8c, unlike the
 * 16c pair below -- confirm this is intended. */
221 case memory_format::nChw8c :
222 compute_fwd_bias_nCdhwXc<8>();
224 case memory_format::nChw16c :
225 case memory_format::nCdhw16c :
226 compute_fwd_bias_nCdhwXc<16>();
235 assert(!"invalid prop_kind");
237 e->set_state(event_t::ready);
/* Bias helpers; only declared in this header. */
241 void compute_fwd_bias();
242 void compute_fwd_bias_ncdhw();
243 template <int blksize> void compute_fwd_bias_nCdhwXc();
/* Convolution primitive that performs the computation; assigned by
 * pd_t::create_primitive(). */
245 primitive_t *conv_p_;
/* Reference backward-by-data deconvolution primitive. execute() only runs
 * the convolution primitive held in conv_p_. */
248 struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
/* Primitive descriptor: validates the deconvolution descriptor in init() and
 * records the convolution primitive descriptor used to run it in conv_pd_. */
249 struct pd_t: public cpu_deconvolution_bwd_data_pd_t {
250 pd_t(engine_t *engine,
251 const deconvolution_desc_t *adesc,
252 const primitive_attr_t *attr,
253 const deconvolution_fwd_pd_t *hint_fwd_pd)
254 : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
/* Injects clone(), create_primitive() and name() (which returns "ref:any"). */
257 DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_bwd_data_t);
/* Builds the equivalent convolution descriptor with conv_descr_create() and
 * walks the candidates offered by mkldnn_primitive_desc_iterator.
 * Returns the conv_descr_create() status when it fails and `unimplemented`
 * once the loop ends without an earlier return.
 * NOTE(review): the statements that assign conv_pd_ from the iterator and
 * that leave the loop are not visible here -- presumably the first candidate
 * whose weights format is not double-blocked is kept; confirm. */
259 status_t init_convolution(){
260 using namespace memory_format;
261 convolution_desc_t cd;
264 status = conv_descr_create(this->cdesc(), &cd);
265 if (status != status::success) return status;
267 mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
268 &(this->attr_), nullptr);
269 while (++it != it.end()) {
271 const memory_desc_t *md = conv_pd_->weights_pd()->desc();
272 /* double blocked format is not supported */
273 if (!utils::one_of(md->format, gOIhw8i16o2i, OIhw8i16o2i,
274 gOIhw8o16i2o, OIhw8o16i2o, gOIhw4i16o4i, OIhw4i16o4i))
277 return unimplemented;
/* Validates the descriptor and resolves every `any` memory format.
 * The visible conditions require: prop_kind in the one_of list that starts
 * with backward, f32 diff_src/weights/diff_dst, and a direct or winograd
 * deconvolution algorithm.
 * NOTE(review): unlike the forward and backward-weights descriptors in this
 * file, no attr()->has_default_values() check is visible -- confirm whether
 * the omission is intentional.
 * Returns status::success, a failure propagated through CHECK, or
 * status::unimplemented when the conditions do not hold. */
280 virtual status_t init() override {
281 using namespace prop_kind;
282 using namespace data_type;
283 assert(this->engine()->kind() == engine_kind::cpu);
285 && utils::one_of(this->desc()->prop_kind, backward,
287 && utils::everyone_is(data_type::f32,
288 this->desc()->diff_src_desc.data_type,
289 this->desc()->weights_desc.data_type,
290 this->desc()->diff_dst_desc.data_type)
291 && utils::one_of(this->desc()->alg_kind,
292 alg_kind::deconvolution_direct,
293 alg_kind::deconvolution_winograd);
295 CHECK(init_convolution());
/* Weights left as `any`: derive the layout from the chosen convolution's
 * weights descriptor and rebuild weights_pd_ from the updated descriptor. */
296 if (weights_pd_.desc()->format == memory_format::any)
298 CHECK(compute_blocked_format(with_groups(),
299 conv_pd_->weights_pd()->desc(),
300 &desc_.weights_desc));
301 cpu_memory_pd_t weights(engine_, &desc_.weights_desc);
302 weights_pd_ = weights;
/* diff_src/diff_dst take the convolution's dst/src formats; this mirrors
 * conv_descr_create(), which maps backward/backward_data to a
 * forward_training convolution from diff_dst to diff_src. */
304 if (diff_src_pd_.desc()->format == memory_format::any)
305 CHECK(diff_src_pd_.set_format(conv_pd_->dst_pd()->desc()->format));
306 if (diff_dst_pd_.desc()->format == memory_format::any)
307 CHECK(diff_dst_pd_.set_format(conv_pd_->src_pd()->desc()->format));
308 return status::success;
310 else return status::unimplemented;
/* Convolution primitive descriptor used by the DECLARE_DECONVOLUTION_PD_T
 * code to create the convolution primitive.
 * NOTE(review): raw pointer; no destructor or copy constructor is visible in
 * this struct while clone() copy-constructs pd_t -- confirm who owns it. */
312 primitive_desc_t *conv_pd_;
/* Stores a copy of the primitive descriptor in conf_; conv_p_ is not set
 * here but afterwards by pd_t::create_primitive(). */
314 ref_deconvolution_bwd_data_t(const pd_t *pd, const input_vector &inputs,
315 const output_vector &outputs)
316 : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
/* Runs the convolution primitive for backward/backward_data and marks the
 * event ready. Any other prop_kind hits the assert. */
318 virtual void execute(event_t *e) {
319 switch (conf_.desc()->prop_kind) {
320 case prop_kind::backward:
321 case prop_kind::backward_data:
322 (conv_p_)->execute(e);
325 assert(!"invalid prop_kind");
327 e->set_state(event_t::ready);
/* Convolution primitive that performs the computation; assigned by
 * pd_t::create_primitive(). */
332 primitive_t *conv_p_;
/* Reference backward-by-weights deconvolution primitive. execute() runs the
 * convolution primitive held in conv_p_ and then, when the descriptor
 * reports a bias, calls one of the compute_bwd_bias_* helpers chosen by the
 * diff_dst format. */
335 struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
/* Primitive descriptor: validates the deconvolution descriptor in init() and
 * records the convolution primitive descriptor used to run it in conv_pd_. */
336 struct pd_t: public cpu_deconvolution_bwd_weights_pd_t {
337 pd_t(engine_t *engine,
338 const deconvolution_desc_t *adesc,
339 const primitive_attr_t *attr,
340 const deconvolution_fwd_pd_t *hint_fwd_pd)
341 : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
/* Injects clone(), create_primitive() and name() (which returns "ref:any"). */
344 DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_bwd_weights_t);
/* Builds the equivalent convolution descriptor with conv_descr_create() and
 * walks the candidates offered by mkldnn_primitive_desc_iterator, looking at
 * each candidate's diff_weights format.
 * Returns the conv_descr_create() status when it fails and `unimplemented`
 * once the loop ends without an earlier return.
 * NOTE(review): the statements that assign conv_pd_ from the iterator and
 * that leave the loop are not visible here -- presumably the first candidate
 * whose diff_weights format is not double-blocked is kept; confirm. */
346 status_t init_convolution(){
347 using namespace memory_format;
348 convolution_desc_t cd;
351 status = conv_descr_create(this->cdesc(), &cd);
352 if (status != status::success) return status;
354 mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
355 &(this->attr_), nullptr);
356 while (++it != it.end()) {
358 const memory_desc_t *md = conv_pd_->diff_weights_pd()->desc();
359 /* double blocked format is not supported */
360 if (!utils::one_of(md->format, gOIhw8i16o2i, OIhw8i16o2i,
361 gOIhw8o16i2o, OIhw8o16i2o, gOIhw4i16o4i, OIhw4i16o4i))
364 return unimplemented;
/* Validates the descriptor and resolves every `any` memory format.
 * The visible conditions require: prop_kind in the one_of list that starts
 * with backward, f32 src/diff_weights/diff_dst, a direct or winograd
 * deconvolution algorithm, and default primitive attributes.
 * Returns status::success, a failure propagated through CHECK, or
 * status::unimplemented when the conditions do not hold. */
367 virtual status_t init() override {
368 using namespace prop_kind;
369 assert(this->engine()->kind() == engine_kind::cpu);
371 && utils::one_of(this->desc()->prop_kind, backward,
373 && utils::everyone_is(data_type::f32,
374 this->desc()->src_desc.data_type,
375 this->desc()->diff_weights_desc.data_type,
376 this->desc()->diff_dst_desc.data_type)
377 && utils::one_of(this->desc()->alg_kind,
378 alg_kind::deconvolution_direct,
379 alg_kind::deconvolution_winograd)
380 && this->attr()->has_default_values();
382 CHECK(init_convolution());
/* diff_weights left as `any`: derive the layout from the chosen
 * convolution's diff_weights descriptor and rebuild diff_weights_pd_. */
383 if (diff_weights_pd_.desc()->format == memory_format::any)
385 CHECK(compute_blocked_format(with_groups(),
386 conv_pd_->diff_weights_pd()->desc(),
387 &desc_.diff_weights_desc));
388 cpu_memory_pd_t weights(engine_, &desc_.diff_weights_desc);
389 diff_weights_pd_ = weights;
/* src takes the convolution's diff_dst format and diff_dst takes its src
 * format; this mirrors conv_descr_create(), which for this prop_kind uses
 * the deconvolution diff_dst as convolution source and the deconvolution
 * src as convolution destination. */
391 if (src_pd_.desc()->format == memory_format::any)
392 CHECK(src_pd_.set_format(conv_pd_->diff_dst_pd()->desc()->format));
393 if (diff_dst_pd_.desc()->format == memory_format::any)
394 CHECK(diff_dst_pd_.set_format(conv_pd_->src_pd()->desc()->format));
/* diff_bias left as `any` is given the plain `x` format. */
395 if (diff_bias_pd_.desc()->format == memory_format::any)
396 CHECK(diff_bias_pd_.set_format(memory_format::x));
397 return status::success;
399 else return status::unimplemented;
/* Convolution primitive descriptor used by the DECLARE_DECONVOLUTION_PD_T
 * code to create the convolution primitive.
 * NOTE(review): raw pointer; no destructor or copy constructor is visible in
 * this struct while clone() copy-constructs pd_t -- confirm who owns it. */
401 primitive_desc_t *conv_pd_;
/* Stores a copy of the primitive descriptor in conf_; conv_p_ is not set
 * here but afterwards by pd_t::create_primitive(). */
404 ref_deconvolution_bwd_weights_t(const pd_t *pd, const input_vector &inputs,
405 const output_vector &outputs)
406 : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
408 typedef typename prec_traits<data_type::f32>::type data_t;
/* Runs the convolution primitive for backward/backward_weights and then
 * dispatches the bias helper by diff_dst memory format; finally marks the
 * event ready. Any other prop_kind hits the assert. */
410 virtual void execute(event_t *e) {
411 switch (conf_.desc()->prop_kind) {
412 case prop_kind::backward:
413 case prop_kind::backward_weights:
414 (conv_p_)->execute(e);
415 if (conf_.with_bias()) {
416 switch (conf_.diff_dst_pd()->desc()->format) {
/* Plain layouts share the ncdhw helper. */
417 case memory_format::nchw :
418 case memory_format::ncdhw :
419 compute_bwd_bias_ncdhw();
/* Channel-blocked layouts use the helper templated on the block size.
 * NOTE(review): no nCdhw8c label is visible next to nChw8c, unlike the
 * 16c pair below -- confirm this is intended. */
421 case memory_format::nChw8c :
422 compute_bwd_bias_nCdhwXc<8>();
424 case memory_format::nChw16c :
425 case memory_format::nCdhw16c :
426 compute_bwd_bias_nCdhwXc<16>();
435 assert(!"invalid prop_kind");
437 e->set_state(event_t::ready);
/* Convolution primitive that performs the computation; assigned by
 * pd_t::create_primitive(). */
442 primitive_t *conv_p_;
/* Bias helpers; only declared in this header. */
443 void compute_bwd_bias();
444 void compute_bwd_bias_ncdhw();
445 template <int blksize> void compute_bwd_bias_nCdhwXc();
453 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s