Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_deconvolution.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_REF_DECONVOLUTION_HPP
18 #define CPU_REF_DECONVOLUTION_HPP
19
20 #include <assert.h>
21 #include <string.h>
22
23 #include "c_types_map.hpp"
24 #include "cpu_deconvolution_pd.hpp"
25 #include "cpu_engine.hpp"
26 #include "type_helpers.hpp"
27 #include "utils.hpp"
28 #include "primitive_iterator.hpp"
29
/* Declares the members every deconvolution primitive descriptor in this file
 * shares: clone(), create_primitive() and name().
 *
 *   impl_name    string returned by name()
 *   __VA_ARGS__  type of the primitive created by create_primitive(); it must
 *                expose a `conv_p_` member and the enclosing pd_t must expose
 *                a `conv_pd_` member.
 *
 * create_primitive() builds the deconvolution primitive and then the nested
 * convolution primitive from conv_pd_, which is stored in conv_p_.  When the
 * nested creation fails, the deconvolution primitive is destroyed again,
 * *primitive is reset to nullptr and the failing status is returned, so the
 * caller never receives a primitive whose conv_p_ is invalid.
 *
 * NOTE(review): the input swap below is also taken for prop_kind::backward,
 * which the backward-data descriptor in this file accepts as well -- confirm
 * that the swapped order is intended in that case. */
#define DECLARE_DECONVOLUTION_PD_t(impl_name, ...) \
    virtual pd_t *clone() const override { return new pd_t(*this); } \
    virtual status_t create_primitive(primitive_t **primitive, \
                const primitive_at_t *inputs, \
                const primitive_t **outputs) const override { \
        double ms = get_msec(); \
        using namespace prop_kind; \
        primitive_t::input_vector ins(inputs, inputs + this->n_inputs()); \
        primitive_t::output_vector outs(outputs, outputs + this->n_outputs()); \
        status_t ret = safe_ptr_assign<primitive_t>(*primitive, \
                                new (__VA_ARGS__)(this, ins, outs)); \
        /* only touch *primitive when the allocation above succeeded */ \
        if (ret == status::success) { \
            primitive_t *conv_primitive = nullptr; \
            if (utils::one_of(this->desc()->prop_kind, backward, \
                        backward_weights)) { \
                /* the nested convolution takes the two inputs swapped */ \
                primitive_at_t conv_inputs[2]; \
                conv_inputs[0] = inputs[1]; \
                conv_inputs[1] = inputs[0]; \
                ret = conv_pd_->create_primitive(&conv_primitive, \
                        conv_inputs, outputs); \
            } else { \
                ret = conv_pd_->create_primitive(&conv_primitive, \
                        inputs, outputs); \
            } \
            if (ret == status::success) { \
                ((__VA_ARGS__ *)(*primitive))->conv_p_ = conv_primitive; \
            } else { \
                /* do not hand out a primitive without a convolution */ \
                delete *primitive; \
                *primitive = nullptr; \
            } \
        } \
        ms = get_msec() - ms; \
        if (mkldnn_verbose()->level >= 2) { \
            printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
            fflush(0); \
        } \
        return ret; \
    } \
    virtual const char *name() const override { return impl_name; }

/* Upper-case alias that forwards to DECLARE_DECONVOLUTION_PD_t unchanged. */
#define DECLARE_DECONVOLUTION_PD_T(impl_name, ...) \
        DECLARE_DECONVOLUTION_PD_t(impl_name,  __VA_ARGS__)
61
62
63 namespace mkldnn {
64 namespace impl {
65 namespace cpu {
66
67 static status_t compute_blocked_format(bool with_groups,
68     const memory_desc_t *oi_md, memory_desc_t *io_md)
69 {
70     /* Computes blocking for *i*o* format from *o*i* format */
71     if (oi_md->ndims != io_md->ndims) return status::invalid_arguments;
72     blocking_desc_t oi_blk = oi_md->layout_desc.blocking,
73         &io_blk = io_md->layout_desc.blocking;
74     io_blk = oi_blk;
75     nstl::swap(io_blk.strides[0][0+with_groups], io_blk.strides[0][1+with_groups]);
76     nstl::swap(io_blk.strides[1][0+with_groups], io_blk.strides[1][1+with_groups]);
77     nstl::swap(io_blk.padding_dims[0+with_groups], io_blk.padding_dims[1+with_groups]);
78     nstl::swap(io_blk.offset_padding_to_data[0+with_groups],
79          io_blk.offset_padding_to_data[1+with_groups]);
80     nstl::swap(io_blk.block_dims[0+with_groups], io_blk.block_dims[1+with_groups]);
81     io_md->format = memory_format::blocked;
82     return status::success;
83 }
84
85 static status_t conv_descr_create(const deconvolution_desc_t *dd,
86         convolution_desc_t *cd)
87 {
88     using namespace prop_kind;
89     using namespace memory_format;
90     alg_kind_t alg_kind = ( dd->alg_kind == alg_kind::deconvolution_direct
91         ? alg_kind::convolution_direct : alg_kind::convolution_winograd );
92     prop_kind_t prop_kind;
93     const memory_desc_t *src_md, *dst_md;
94     memory_desc_t c_weights_d, d_weights_d;
95     bool with_groups;
96     if ( utils::one_of(dd->prop_kind, forward_training, forward_inference) ) {
97         prop_kind = backward_data;
98         src_md = &dd->dst_desc;
99         dst_md = &dd->src_desc;
100         d_weights_d = dd->weights_desc;
101     }
102     else if( utils::one_of(dd->prop_kind, backward, backward_data) ) {
103         prop_kind = forward_training;
104         src_md = &dd->diff_dst_desc;
105         dst_md = &dd->diff_src_desc;
106         d_weights_d = dd->weights_desc;
107     }
108     else {
109         prop_kind = dd->prop_kind;
110         src_md = &dd->diff_dst_desc;
111         dst_md = &dd->src_desc;
112         d_weights_d = dd->diff_weights_desc;
113     }
114     with_groups = d_weights_d.ndims == src_md->ndims + 1;
115
116     /* create weights desc for convolution */
117     c_weights_d = d_weights_d;
118     nstl::swap(c_weights_d.dims[with_groups + 0], c_weights_d.dims[with_groups + 1]);
119     if (c_weights_d.format != any)
120     {
121         if (utils::one_of(c_weights_d.format, gOIhw8i16o2i, OIhw8i16o2i,
122             gOIhw8o16i2o, OIhw8o16i2o, gOIhw4i16o4i, OIhw4i16o4i))
123             return unimplemented;
124         CHECK( compute_blocked_format(with_groups, &d_weights_d, &c_weights_d));
125     }
126     return conv_desc_init(cd, prop_kind, alg_kind,
127             src_md, &(c_weights_d),
128             ( (utils::one_of(dd->prop_kind, backward, backward_data))?
129             (&(dd->bias_desc)): nullptr ), dst_md, dd->strides, nullptr,
130             dd->padding[0], dd->padding[1], dd->padding_kind);
131 }
132
133 struct ref_deconvolution_fwd_t: public cpu_primitive_t {
134     struct pd_t: public cpu_deconvolution_fwd_pd_t {
135         pd_t(engine_t *engine,
136                 const deconvolution_desc_t *adesc,
137                 const primitive_attr_t *attr,
138                 const deconvolution_fwd_pd_t *hint_fwd_pd)
139             : cpu_deconvolution_fwd_pd_t(engine, adesc, attr,
140                     hint_fwd_pd)
141         {}
142
143         DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_fwd_t);
144
145         status_t init_convolution(){
146             using namespace memory_format;
147             convolution_desc_t cd;
148             status_t status;
149
150             status = conv_descr_create(this->cdesc(), &cd);
151             if (status != status::success) return status;
152
153             mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
154                 &(this->attr_), nullptr);
155             while (++it != it.end()) {
156                 conv_pd_ = *it;
157                 const memory_desc_t *md = conv_pd_->weights_pd()->desc();
158                 /* double blocked format is not supported */
159                 if (!utils::one_of(md->format, gOIhw8i16o2i, OIhw8i16o2i,
160                     gOIhw8o16i2o, OIhw8o16i2o, gOIhw4i16o4i, OIhw4i16o4i))
161                     return success;
162             }
163             return unimplemented;
164
165         };
166         virtual status_t init() override {
167             using namespace prop_kind;
168             using namespace data_type;
169             assert(this->engine()->kind() == engine_kind::cpu);
170             bool ok = true
171                 && utils::one_of(this->desc()->prop_kind, forward_training,
172                         forward_inference)
173                 && utils::everyone_is(data_type::f32,
174                         this->desc()->src_desc.data_type,
175                         this->desc()->weights_desc.data_type,
176                         this->desc()->dst_desc.data_type)
177                 && utils::one_of(this->desc()->alg_kind,
178                         alg_kind::deconvolution_direct,
179                         alg_kind::deconvolution_winograd)
180                && this->attr()->has_default_values();
181             if (ok) {
182                 CHECK(init_convolution());
183                 if (weights_pd_.desc()->format == memory_format::any)
184                 {
185                     CHECK(compute_blocked_format(with_groups(),
186                         conv_pd_->weights_pd()->desc(),
187                         &desc_.weights_desc));
188                     cpu_memory_pd_t weights(engine_, &desc_.weights_desc);
189                     weights_pd_ = weights;
190                 }
191                 if (src_pd_.desc()->format == memory_format::any)
192                     CHECK(src_pd_.set_format(conv_pd_->diff_dst_pd()->desc()->format));
193                 if (dst_pd_.desc()->format == memory_format::any)
194                     CHECK(dst_pd_.set_format(conv_pd_->diff_src_pd()->desc()->format));
195                 if (bias_pd_.desc()->format == memory_format::any)
196                     CHECK(bias_pd_.set_format(memory_format::x));
197                 return status::success;
198             }
199             else return status::unimplemented;
200         }
201         primitive_desc_t *conv_pd_;
202     };
203
204     ref_deconvolution_fwd_t(const pd_t *pd, const input_vector &inputs,
205             const output_vector &outputs)
206         : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
207
208     typedef typename prec_traits<data_type::f32>::type data_t;
209
210     virtual void execute(event_t *e) {
211         switch (conf_.desc()->prop_kind) {
212         case prop_kind::forward_training:
213         case prop_kind::forward_inference:
214             (conv_p_)->execute(e);
215             if (conf_.with_bias()) {
216                 switch (conf_.dst_pd()->desc()->format) {
217                     case memory_format::nchw :
218                     case memory_format::ncdhw :
219                         compute_fwd_bias_ncdhw();
220                         break;
221                     case memory_format::nChw8c :
222                         compute_fwd_bias_nCdhwXc<8>();
223                         break;
224                     case memory_format::nChw16c :
225                     case memory_format::nCdhw16c :
226                         compute_fwd_bias_nCdhwXc<16>();
227                         break;
228                     default:
229                         compute_fwd_bias();
230                         break;
231                 }
232             }
233             break;
234         default:
235             assert(!"invalid prop_kind");
236         }
237         e->set_state(event_t::ready);
238     }
239
240 private:
241     void compute_fwd_bias();
242     void compute_fwd_bias_ncdhw();
243     template <int blksize> void compute_fwd_bias_nCdhwXc();
244     pd_t conf_;
245     primitive_t *conv_p_;
246 };
247
248 struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
249     struct pd_t: public cpu_deconvolution_bwd_data_pd_t {
250         pd_t(engine_t *engine,
251                 const deconvolution_desc_t *adesc,
252                 const primitive_attr_t *attr,
253                 const deconvolution_fwd_pd_t *hint_fwd_pd)
254             : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
255         {}
256
257         DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_bwd_data_t);
258
259         status_t init_convolution(){
260             using namespace memory_format;
261             convolution_desc_t cd;
262             status_t status;
263
264             status = conv_descr_create(this->cdesc(), &cd);
265             if (status != status::success) return status;
266
267              mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
268                 &(this->attr_), nullptr);
269              while (++it != it.end()) {
270                 conv_pd_ = *it;
271                 const memory_desc_t *md = conv_pd_->weights_pd()->desc();
272                 /* double blocked format is not supported */
273                 if (!utils::one_of(md->format, gOIhw8i16o2i, OIhw8i16o2i,
274                     gOIhw8o16i2o, OIhw8o16i2o, gOIhw4i16o4i, OIhw4i16o4i))
275                     return success;
276             }
277             return unimplemented;
278         };
279
280         virtual status_t init() override {
281             using namespace prop_kind;
282             using namespace data_type;
283             assert(this->engine()->kind() == engine_kind::cpu);
284             bool ok = true
285                 && utils::one_of(this->desc()->prop_kind, backward,
286                         backward_data)
287                 && utils::everyone_is(data_type::f32,
288                         this->desc()->diff_src_desc.data_type,
289                         this->desc()->weights_desc.data_type,
290                         this->desc()->diff_dst_desc.data_type)
291                 && utils::one_of(this->desc()->alg_kind,
292                         alg_kind::deconvolution_direct,
293                         alg_kind::deconvolution_winograd);
294             if (ok) {
295                 CHECK(init_convolution());
296                 if (weights_pd_.desc()->format == memory_format::any)
297                 {
298                     CHECK(compute_blocked_format(with_groups(),
299                         conv_pd_->weights_pd()->desc(),
300                         &desc_.weights_desc));
301                     cpu_memory_pd_t weights(engine_, &desc_.weights_desc);
302                     weights_pd_ = weights;
303                 }
304                 if (diff_src_pd_.desc()->format == memory_format::any)
305                     CHECK(diff_src_pd_.set_format(conv_pd_->dst_pd()->desc()->format));
306                 if (diff_dst_pd_.desc()->format == memory_format::any)
307                     CHECK(diff_dst_pd_.set_format(conv_pd_->src_pd()->desc()->format));
308                 return status::success;
309             }
310             else return status::unimplemented;
311         }
312         primitive_desc_t *conv_pd_;
313     };
314     ref_deconvolution_bwd_data_t(const pd_t *pd, const input_vector &inputs,
315             const output_vector &outputs)
316         : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
317
318     virtual void execute(event_t *e) {
319         switch (conf_.desc()->prop_kind) {
320         case prop_kind::backward:
321         case prop_kind::backward_data:
322             (conv_p_)->execute(e);
323             break;
324         default:
325             assert(!"invalid prop_kind");
326         }
327         e->set_state(event_t::ready);
328     }
329
330 private:
331     pd_t conf_;
332     primitive_t *conv_p_;
333 };
334
335 struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
336     struct pd_t: public cpu_deconvolution_bwd_weights_pd_t {
337         pd_t(engine_t *engine,
338                 const deconvolution_desc_t *adesc,
339                 const primitive_attr_t *attr,
340                 const deconvolution_fwd_pd_t *hint_fwd_pd)
341             : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
342         {}
343
344         DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_bwd_weights_t);
345
346         status_t init_convolution(){
347             using namespace memory_format;
348             convolution_desc_t cd;
349             status_t status;
350
351             status = conv_descr_create(this->cdesc(), &cd);
352             if (status != status::success) return status;
353
354              mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
355                 &(this->attr_), nullptr);
356              while (++it != it.end()) {
357                 conv_pd_ = *it;
358                 const memory_desc_t *md = conv_pd_->diff_weights_pd()->desc();
359                 /* double blocked format is not supported */
360                 if (!utils::one_of(md->format, gOIhw8i16o2i, OIhw8i16o2i,
361                     gOIhw8o16i2o, OIhw8o16i2o, gOIhw4i16o4i, OIhw4i16o4i))
362                     return success;
363             }
364             return unimplemented;
365         };
366
367         virtual status_t init() override {
368             using namespace prop_kind;
369             assert(this->engine()->kind() == engine_kind::cpu);
370             bool ok = true
371                 && utils::one_of(this->desc()->prop_kind, backward,
372                         backward_weights)
373                 && utils::everyone_is(data_type::f32,
374                         this->desc()->src_desc.data_type,
375                         this->desc()->diff_weights_desc.data_type,
376                         this->desc()->diff_dst_desc.data_type)
377                 && utils::one_of(this->desc()->alg_kind,
378                         alg_kind::deconvolution_direct,
379                         alg_kind::deconvolution_winograd)
380                 && this->attr()->has_default_values();
381             if (ok) {
382                 CHECK(init_convolution());
383                 if (diff_weights_pd_.desc()->format == memory_format::any)
384                 {
385                     CHECK(compute_blocked_format(with_groups(),
386                         conv_pd_->diff_weights_pd()->desc(),
387                         &desc_.diff_weights_desc));
388                     cpu_memory_pd_t weights(engine_, &desc_.diff_weights_desc);
389                     diff_weights_pd_ = weights;
390                 }
391                 if (src_pd_.desc()->format == memory_format::any)
392                     CHECK(src_pd_.set_format(conv_pd_->diff_dst_pd()->desc()->format));
393                 if (diff_dst_pd_.desc()->format == memory_format::any)
394                     CHECK(diff_dst_pd_.set_format(conv_pd_->src_pd()->desc()->format));
395                 if (diff_bias_pd_.desc()->format == memory_format::any)
396                     CHECK(diff_bias_pd_.set_format(memory_format::x));
397                 return status::success;
398             }
399             else return status::unimplemented;
400         }
401         primitive_desc_t *conv_pd_;
402     };
403
404     ref_deconvolution_bwd_weights_t(const pd_t *pd, const input_vector &inputs,
405             const output_vector &outputs)
406         : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
407
408     typedef typename prec_traits<data_type::f32>::type data_t;
409
410     virtual void execute(event_t *e) {
411         switch (conf_.desc()->prop_kind) {
412         case prop_kind::backward:
413         case prop_kind::backward_weights:
414             (conv_p_)->execute(e);
415             if (conf_.with_bias()) {
416                 switch (conf_.diff_dst_pd()->desc()->format) {
417                     case memory_format::nchw :
418                     case memory_format::ncdhw :
419                         compute_bwd_bias_ncdhw();
420                         break;
421                     case memory_format::nChw8c :
422                         compute_bwd_bias_nCdhwXc<8>();
423                         break;
424                     case memory_format::nChw16c :
425                     case memory_format::nCdhw16c :
426                         compute_bwd_bias_nCdhwXc<16>();
427                         break;
428                     default:
429                         compute_bwd_bias();
430                         break;
431                 }
432             }
433             break;
434         default:
435             assert(!"invalid prop_kind");
436         }
437         e->set_state(event_t::ready);
438     }
439
440 private:
441     pd_t conf_;
442     primitive_t *conv_p_;
443     void compute_bwd_bias();
444     void compute_bwd_bias_ncdhw();
445     template <int blksize> void compute_bwd_bias_nCdhwXc();
446 };
447
448 }
449 }
450 }
451 #endif
452
453 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s