inference-engine/thirdparty/mkl-dnn/src/cpu/ref_deconvolution.hpp
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_REF_DECONVOLUTION_HPP
#define CPU_REF_DECONVOLUTION_HPP

#include <assert.h>
#include <string.h>

#include "c_types_map.hpp"
#include "cpu_convolution_pd.hpp"
#include "cpu_deconvolution_pd.hpp"
#include "cpu_engine.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
#include "primitive_iterator.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

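/* Reference deconvolution is implemented on top of the existing convolution
 * primitives: the forward pass runs convolution backward data, backward data
 * runs convolution forward, and backward weights runs convolution backward
 * weights with the source and destination roles swapped. The helpers below
 * build the matching convolution descriptor and weights layout. */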
static status_t compute_blocked_format(bool with_groups,
    const memory_desc_t *oi_md, memory_desc_t *io_md)
{
    /* Computes blocking for *i*o* format from *o*i* format */
    if (oi_md->ndims != io_md->ndims) return status::invalid_arguments;
    blocking_desc_t oi_blk = oi_md->layout_desc.blocking,
        &io_blk = io_md->layout_desc.blocking;
    io_blk = oi_blk;
    nstl::swap(io_blk.strides[0][0+with_groups], io_blk.strides[0][1+with_groups]);
    nstl::swap(io_blk.strides[1][0+with_groups], io_blk.strides[1][1+with_groups]);
    nstl::swap(io_blk.padding_dims[0+with_groups], io_blk.padding_dims[1+with_groups]);
    nstl::swap(io_blk.offset_padding_to_data[0+with_groups],
         io_blk.offset_padding_to_data[1+with_groups]);
    nstl::swap(io_blk.block_dims[0+with_groups], io_blk.block_dims[1+with_groups]);
    io_md->format = memory_format::blocked;
    return status::success;
}

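/* Builds the convolution descriptor that mirrors a deconvolution descriptor:
 *   deconv forward          -> conv backward data
 *   deconv backward data    -> conv forward
 *   deconv backward weights -> conv backward weights (src and diff_dst swapped)
 * The convolution weights descriptor is the deconvolution one with the
 * output/input channel dimensions swapped; double-blocked weights formats
 * are rejected as unimplemented. */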
static status_t conv_descr_create(const deconvolution_desc_t *dd,
        convolution_desc_t *cd)
{
    using namespace prop_kind;
    using namespace memory_format;
    alg_kind_t alg_kind = (dd->alg_kind == alg_kind::deconvolution_direct
        ? alg_kind::convolution_direct : alg_kind::convolution_winograd);
    prop_kind_t prop_kind;
    const memory_desc_t *src_md, *dst_md;
    memory_desc_t c_weights_d, d_weights_d;
    bool with_groups;
    if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) {
        prop_kind = backward_data;
        src_md = &dd->dst_desc;
        dst_md = &dd->src_desc;
        d_weights_d = dd->weights_desc;
    } else if (dd->prop_kind == backward_data) {
        prop_kind = forward_training;
        src_md = &dd->diff_dst_desc;
        dst_md = &dd->diff_src_desc;
        d_weights_d = dd->weights_desc;
    } else {
        prop_kind = dd->prop_kind;
        src_md = &dd->diff_dst_desc;
        dst_md = &dd->src_desc;
        d_weights_d = dd->diff_weights_desc;
    }
    with_groups = d_weights_d.ndims == src_md->ndims + 1;

    /* create weights desc for convolution */
    c_weights_d = d_weights_d;
    nstl::swap(c_weights_d.dims[with_groups + 0], c_weights_d.dims[with_groups + 1]);
    if (c_weights_d.format != any)
    {
        if (utils::one_of(c_weights_d.format, gOIhw8i16o2i, OIhw8i16o2i,
            gOIhw8o16i2o, OIhw8o16i2o, gOIhw4i16o4i, OIhw4i16o4i))
            return unimplemented;
        CHECK(compute_blocked_format(with_groups, &d_weights_d, &c_weights_d));
    }
    return conv_desc_init(cd, prop_kind, alg_kind, src_md, &(c_weights_d),
            (prop_kind != backward_weights ? &(dd->bias_desc) : nullptr),
            dst_md, dd->strides, dd->dilates,
            dd->padding[0], dd->padding[1], dd->padding_kind);
}

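/* Forward deconvolution. The computation is delegated to a convolution
 * backward-data primitive; if that primitive cannot apply the bias itself,
 * the bias is added here after the convolution completes (see execute()). */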
struct ref_deconvolution_fwd_t: public cpu_primitive_t {
    struct pd_t: public cpu_deconvolution_fwd_pd_t {
        pd_t(engine_t *engine,
                const deconvolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
            , conv_pd_(nullptr)
        {}

        pd_t(const pd_t &other)
            : cpu_deconvolution_fwd_pd_t(other)
            , conv_pd_(other.conv_pd_->clone())
            , conv_supports_bias_(other.conv_supports_bias_)
        {}

        ~pd_t() { delete conv_pd_; }

        DECLARE_DECONVOLUTION_PD_T(ref_deconvolution_fwd_t);

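        /* Iterate over the available convolution backward-data
         * implementations and pick the first one whose weights use a regular
         * (non-double-blocked) blocked format and that either supports bias
         * directly or produces an f32 result the reference bias code can
         * post-process. */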
        status_t init_convolution() {
            using namespace memory_format;
            using namespace types;
            convolution_desc_t cd;
            status_t status;

            status = conv_descr_create(this->desc(), &cd);
            if (status != status::success) return status;

            mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
                &(this->attr_), nullptr);
            while (++it != it.end()) {
                conv_pd_ = *it;
                conv_supports_bias_ = static_cast<cpu_convolution_bwd_data_pd_t *>
                    (conv_pd_)->support_bias();
                bool output_f32 = utils::everyone_is(data_type::f32,
                        desc()->accum_data_type,
                        desc()->dst_desc.data_type);
                auto wei_fmt =
                    format_normalize(conv_pd_->weights_pd()->desc()->format);

                bool ok = true
                    /* only weights in non-double-blocked format are supported */
                    && (wei_fmt == blocked && !is_format_double_blocked(wei_fmt))
                    /* deconv reference code can process only f32 bias */
                    && IMPLICATION(with_bias(),
                            conv_supports_bias_ || output_f32);
                if (ok)
                    return success;
                delete conv_pd_;
            }
            return unimplemented;
        }
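
        /* Once a convolution is chosen, resolve memory descriptors left as
         * `any`: weights take the convolution layout with the channel
         * dimensions swapped back, src/dst follow the convolution's
         * diff_dst/diff_src formats, and the bias defaults to plain x. */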
        virtual status_t init() override {
            using namespace prop_kind;
            assert(this->engine()->kind() == engine_kind::cpu);
            bool ok = true
                && utils::one_of(this->desc()->prop_kind, forward_training,
                        forward_inference)
                && utils::one_of(this->desc()->alg_kind,
                        alg_kind::deconvolution_direct,
                        alg_kind::deconvolution_winograd)
                && attr()->post_ops_.has_default_values();

            if (ok) {
                CHECK(init_convolution());
                if (weights_pd_.desc()->format == memory_format::any)
                {
                    CHECK(compute_blocked_format(with_groups(),
                        conv_pd_->weights_pd()->desc(),
                        &desc_.weights_desc));
                    cpu_memory_pd_t weights(engine_, &desc_.weights_desc);
                    weights_pd_ = weights;
                }
                if (src_pd_.desc()->format == memory_format::any)
                    CHECK(src_pd_.set_format(conv_pd_->diff_dst_pd()->desc()->format));
                if (dst_pd_.desc()->format == memory_format::any)
                    CHECK(dst_pd_.set_format(conv_pd_->diff_src_pd()->desc()->format));
                if (bias_pd_.desc()->format == memory_format::any)
                    CHECK(bias_pd_.set_format(memory_format::x));
                return status::success;
            }
            else return status::unimplemented;
        }

        primitive_desc_t *conv_pd_;
        bool conv_supports_bias_;
    };

    ref_deconvolution_fwd_t(const pd_t *apd, const input_vector &inputs,
            const output_vector &outputs)
        : cpu_primitive_t(apd, inputs, outputs), conv_p_(nullptr) {}

    ~ref_deconvolution_fwd_t() { delete this->conv_p_; }

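    /* Run the underlying convolution; if it could not fuse the bias, add the
     * bias here with a kernel specialized for the destination format. */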
    virtual void execute(event_t *e) const {
        switch (pd()->desc()->prop_kind) {
        case prop_kind::forward_training:
        case prop_kind::forward_inference:
            conv_p_->execute(e);
            if (pd()->with_bias() && !pd()->conv_supports_bias_) {
                switch (pd()->dst_pd()->desc()->format) {
                    case memory_format::nchw:
                    case memory_format::ncdhw:
                        compute_fwd_bias_ncdhw();
                        break;
                    case memory_format::nChw8c:
                    case memory_format::nCdhw8c:
                        compute_fwd_bias_nCdhwXc<8>();
                        break;
                    case memory_format::nChw16c:
                    case memory_format::nCdhw16c:
                        compute_fwd_bias_nCdhwXc<16>();
                        break;
                    default:
                        compute_fwd_bias();
                        break;
                }
            }
            break;
        default:
            assert(!"invalid prop_kind");
        }
        e->set_state(event_t::ready);
    }

private:
    void compute_fwd_bias() const;
    void compute_fwd_bias_ncdhw() const;
    template <int blksize> void compute_fwd_bias_nCdhwXc() const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
    primitive_t *conv_p_;
};

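/* Backward-data deconvolution. The gradient w.r.t. the source is obtained by
 * running a forward convolution on diff_dst with the channel-swapped
 * weights. */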
struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
    struct pd_t: public cpu_deconvolution_bwd_data_pd_t {
        pd_t(engine_t *engine,
                const deconvolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
            , conv_pd_(nullptr)
        {}

        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_data_pd_t(other)
            , conv_pd_(other.conv_pd_->clone()) {}

        ~pd_t() { delete conv_pd_; }

        DECLARE_DECONVOLUTION_PD_T(ref_deconvolution_bwd_data_t);

        status_t init_convolution() {
            using namespace memory_format;
            using namespace types;
            convolution_desc_t cd;
            status_t status;

            status = conv_descr_create(this->desc(), &cd);
            if (status != status::success) return status;

            mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
                &(this->attr_), nullptr);
            while (++it != it.end()) {
                conv_pd_ = *it;
                auto wei_fmt =
                    format_normalize(conv_pd_->weights_pd()->desc()->format);
                /* only weights in non-double-blocked format are supported */
                if (wei_fmt == blocked && !is_format_double_blocked(wei_fmt))
                    return success;
                delete conv_pd_;
            }
            return unimplemented;
        }

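        /* Resolve `any` memory formats from the chosen forward convolution:
         * deconv diff_src follows the convolution's dst format and deconv
         * diff_dst follows its src format. */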
        virtual status_t init() override {
            using namespace prop_kind;
            using namespace data_type;
            assert(this->engine()->kind() == engine_kind::cpu);
            bool ok = true
                && this->desc()->prop_kind == backward_data
                && utils::everyone_is(data_type::f32,
                        this->desc()->diff_src_desc.data_type,
                        this->desc()->weights_desc.data_type,
                        this->desc()->diff_dst_desc.data_type)
                && utils::one_of(this->desc()->alg_kind,
                        alg_kind::deconvolution_direct,
                        alg_kind::deconvolution_winograd);

            if (ok) {
                CHECK(init_convolution());
                if (weights_pd_.desc()->format == memory_format::any)
                {
                    CHECK(compute_blocked_format(with_groups(),
                        conv_pd_->weights_pd()->desc(),
                        &desc_.weights_desc));
                    cpu_memory_pd_t weights(engine_, &desc_.weights_desc);
                    weights_pd_ = weights;
                }
                if (diff_src_pd_.desc()->format == memory_format::any)
                    CHECK(diff_src_pd_.set_format(conv_pd_->dst_pd()->desc()->format));
                if (diff_dst_pd_.desc()->format == memory_format::any)
                    CHECK(diff_dst_pd_.set_format(conv_pd_->src_pd()->desc()->format));
                return status::success;
            }
            else return status::unimplemented;
        }

        primitive_desc_t *conv_pd_;
    };

    ref_deconvolution_bwd_data_t(const pd_t *apd, const input_vector &inputs,
            const output_vector &outputs)
        : cpu_primitive_t(apd, inputs, outputs), conv_p_(nullptr) {}

    ~ref_deconvolution_bwd_data_t() { delete this->conv_p_; }

    virtual void execute(event_t *e) const {
        switch (pd()->desc()->prop_kind) {
        case prop_kind::backward_data:
            conv_p_->execute(e);
            break;
        default:
            assert(!"invalid prop_kind");
        }
        e->set_state(event_t::ready);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
    primitive_t *conv_p_;
};

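/* Backward-weights deconvolution. Reuses a convolution backward-weights
 * primitive with the roles of src and diff_dst swapped; the bias gradient is
 * computed separately in execute(). */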
struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
    struct pd_t: public cpu_deconvolution_bwd_weights_pd_t {
        pd_t(engine_t *engine,
                const deconvolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
            , conv_pd_(nullptr)
        {}

        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_weights_pd_t(other)
            , conv_pd_(other.conv_pd_->clone()) {}

        ~pd_t() { delete conv_pd_; }

        DECLARE_DECONVOLUTION_PD_T(ref_deconvolution_bwd_weights_t);

        status_t init_convolution() {
            using namespace memory_format;
            using namespace types;
            convolution_desc_t cd;
            status_t status;

            status = conv_descr_create(this->desc(), &cd);
            if (status != status::success) return status;

            mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
                &(this->attr_), nullptr);
            while (++it != it.end()) {
                conv_pd_ = *it;
                auto wei_fmt = format_normalize(
                        conv_pd_->diff_weights_pd()->desc()->format);
                /* only weights in non-double-blocked format are supported */
                if (wei_fmt == blocked && !is_format_double_blocked(wei_fmt))
                    return success;
                delete conv_pd_;
            }
            return unimplemented;
        }

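        /* Resolve `any` memory formats from the chosen convolution:
         * diff_weights take the convolution layout with the channel
         * dimensions swapped back, src follows the convolution's diff_dst
         * format, diff_dst follows its src format, and diff_bias defaults to
         * plain x. */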
        virtual status_t init() override {
            using namespace prop_kind;
            assert(this->engine()->kind() == engine_kind::cpu);
            bool ok = true
                && this->desc()->prop_kind == backward_weights
                && utils::everyone_is(data_type::f32,
                        this->desc()->src_desc.data_type,
                        this->desc()->diff_weights_desc.data_type,
                        this->desc()->diff_dst_desc.data_type)
                && utils::one_of(this->desc()->alg_kind,
                        alg_kind::deconvolution_direct,
                        alg_kind::deconvolution_winograd)
                && this->attr()->has_default_values();

            if (ok) {
                CHECK(init_convolution());
                if (diff_weights_pd_.desc()->format == memory_format::any)
                {
                    CHECK(compute_blocked_format(with_groups(),
                        conv_pd_->diff_weights_pd()->desc(),
                        &desc_.diff_weights_desc));
                    cpu_memory_pd_t weights(engine_, &desc_.diff_weights_desc);
                    diff_weights_pd_ = weights;
                }
                if (src_pd_.desc()->format == memory_format::any)
                    CHECK(src_pd_.set_format(conv_pd_->diff_dst_pd()->desc()->format));
                if (diff_dst_pd_.desc()->format == memory_format::any)
                    CHECK(diff_dst_pd_.set_format(conv_pd_->src_pd()->desc()->format));
                if (diff_bias_pd_.desc()->format == memory_format::any)
                    CHECK(diff_bias_pd_.set_format(memory_format::x));
                return status::success;
            }
            else return status::unimplemented;
        }

        primitive_desc_t *conv_pd_;
    };

    ref_deconvolution_bwd_weights_t(const pd_t *apd, const input_vector &inputs,
            const output_vector &outputs)
        : cpu_primitive_t(apd, inputs, outputs), conv_p_(nullptr) {}

    ~ref_deconvolution_bwd_weights_t() { delete this->conv_p_; }

    typedef typename prec_traits<data_type::f32>::type data_t;

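    /* Run the underlying convolution to produce diff_weights, then compute
     * diff_bias from diff_dst with a kernel specialized for the diff_dst
     * format. */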
    virtual void execute(event_t *e) const {
        switch (pd()->desc()->prop_kind) {
        case prop_kind::backward_weights:
            conv_p_->execute(e);
            if (pd()->with_bias()) {
                switch (pd()->diff_dst_pd()->desc()->format) {
                    case memory_format::nchw:
                    case memory_format::ncdhw:
                        compute_bwd_bias_ncdhw();
                        break;
                    case memory_format::nChw8c:
                        compute_bwd_bias_nCdhwXc<8>();
                        break;
                    case memory_format::nChw16c:
                    case memory_format::nCdhw16c:
                        compute_bwd_bias_nCdhwXc<16>();
                        break;
                    default:
                        compute_bwd_bias();
                        break;
                }
            }
            break;
        default:
            assert(!"invalid prop_kind");
        }
        e->set_state(event_t::ready);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
    primitive_t *conv_p_;
    void compute_bwd_bias() const;
    void compute_bwd_bias_ncdhw() const;
    template <int blksize> void compute_bwd_bias_nCdhwXc() const;
};

}
}
}
#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s