updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / conv / deconv.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <float.h>
20 #include <math.h>
21
22 #include "mkldnn.h"
23
24 #include "src/common/mkldnn_thread.hpp"
25
26 #include "mkldnn_common.hpp"
27 #include "mkldnn_memory.hpp"
28
29 #include "norm.hpp"
30
31 #include "conv/deconv.hpp"
32 using namespace conv;
33
34 namespace deconv {
35
36 template <typename T>
37 inline static void swap(T &a, T &b) {
38     T temp = a;
39     a = b;
40     b = temp;
41 }
42
43 inline bool is_deconv_3d(const prb_t *p) {
44     return p->id > 1;
45 }
46
47 inline bool is_deconv_1d(const prb_t *p) {
48     return !is_deconv_3d(p) && p->ih == 1 && p->kh == 1;
49 }
50
51 inline int transpose_data_wei(const prb_t *p, dnn_mem_t &wei, dnn_mem_t &wei_tr) {
52     mkldnn::impl::parallel_nd(
53         p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw,
54         [&](int g, int oc, int ic, int kd, int kh, int kw) {
55         size_t idx = (((((size_t)g * p->ic / p->g + ic) * p->oc / p->g + oc)
56         * p->kd + kd) * p->kh + kh) * p->kw + kw;
57         ((float*)wei_tr)[idx] = ((float*)wei)[wei_off_f(p, g, oc, ic, kd, kh, kw)];
58     });
59
60     return OK;
61 }
62
63 inline int init_pd(const prb_t *p, mkldnn_deconvolution_desc_t &cd,
64         mkldnn_primitive_desc_t &dpd, res_t *r) {
65     int ndims = is_deconv_3d(p) ? 5 : is_deconv_1d(p) ? 3 : 4;
66
67     mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
68     mkldnn_dims_t src_1d_dims = {p->mb, p->ic, p->iw};
69     mkldnn_dims_t src_2d_dims = {p->mb, p->ic, p->ih, p->iw};
70     mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
71     mkldnn_dims_t wei_1d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kw};
72     mkldnn_dims_t wei_2d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
73     mkldnn_dims_t wei_3d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw};
74     mkldnn_dims_t bia_dims = {p->oc};
75     mkldnn_dims_t dst_1d_dims = {p->mb, p->oc, p->ow};
76     mkldnn_dims_t dst_2d_dims = {p->mb, p->oc, p->oh, p->ow};
77     mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
78     DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
79         is_deconv_3d(p) ? src_3d_dims : is_deconv_1d(p) ? src_1d_dims : src_2d_dims,
80         p->cfg[SRC].dt, mkldnn_any),
81             WARN);
82     DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + p->has_groups,
83         is_deconv_3d(p)
84         ? &wei_3d_dims[!p->has_groups]
85         : is_deconv_1d(p)
86         ? &wei_1d_dims[!p->has_groups]
87         : &wei_2d_dims[!p->has_groups],
88         p->cfg[WEI].dt, mkldnn_any), WARN);
89     DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt, mkldnn_any), WARN);
90     DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
91         is_deconv_3d(p) ? dst_3d_dims : is_deconv_1d(p) ? dst_1d_dims : dst_2d_dims,
92         p->cfg[DST].dt, mkldnn_any), WARN);
93
94     ptrdiff_t strides_nd[] = {p->sd, p->sh, p->sw};
95     ptrdiff_t dilates_nd[] = {p->dd, p->dh, p->dw};
96     ptrdiff_t padding_nd[] = {p->pd, p->ph, p->pw};
97
98     auto bph = [&](int ih, int oh, int kh, int sh, int ph, int dh) {
99         return (oh - 1) * sh - ih + ((kh - 1) * (dh + 1) + 1) - ph;
100     };
101
102     ptrdiff_t padding_r_nd[] = {
103         bph(p->od, p->id, p->kd, p->sd, p->pd, p->dd),
104         bph(p->oh, p->ih, p->kh, p->sh, p->ph, p->dh),
105         bph(p->ow, p->iw, p->kw, p->sw, p->pw, p->dw)};
106
107     ptrdiff_t *strides = strides_nd + (5 - ndims);
108     ptrdiff_t *dilates = dilates_nd + (5 - ndims);
109     ptrdiff_t *padding = padding_nd + (5 - ndims);
110     ptrdiff_t *padding_r = padding_r_nd + (5 - ndims);
111
112     mkldnn_alg_kind_t alg = mkldnn_deconvolution_direct;
113     if (p->alg == WINO) alg = mkldnn_deconvolution_winograd;
114
115     switch (p->dir) {
116     case FWD_D: case FWD_B:
117         DNN_SAFE(mkldnn_dilated_deconvolution_forward_desc_init(&cd,
118                     mkldnn_forward_inference, alg, &src_d, &wei_d,
119                     p->dir == FWD_D ? NULL : &bia_d, &dst_d, strides,
120                     dilates, padding, padding_r, mkldnn_padding_zero), WARN);
121         break;
122     case BWD_D:
123         DNN_SAFE(mkldnn_dilated_deconvolution_backward_data_desc_init(&cd, alg,
124                     &src_d, &wei_d, &dst_d, strides, dilates, padding,
125                     padding_r, mkldnn_padding_zero), WARN);
126         break;
127     case BWD_W: case BWD_WB:
128         DNN_SAFE(mkldnn_dilated_deconvolution_backward_weights_desc_init(&cd,
129                     alg, &src_d, &wei_d, p->dir == BWD_W ? NULL : &bia_d,
130                     &dst_d, strides, dilates,  padding, padding_r,
131                     mkldnn_padding_zero), WARN);
132         break;
133     default: DNN_SAFE(mkldnn_invalid_arguments, CRIT);
134     }
135
136     DNN_SAFE(cd.accum_data_type == p->cfg[ACC].dt
137             ? mkldnn_success : mkldnn_unimplemented, CRIT);
138
139     auto mkldnn_attr = create_mkldnn_attr(p->attr, p->oc, p->scales);
140
141     mkldnn_status_t init_status = mkldnn_success;
142     init_status = mkldnn_primitive_desc_create_v2(&dpd, &cd, mkldnn_attr,
143             engine, NULL);
144
145     mkldnn_primitive_attr_destroy(mkldnn_attr);
146
147     if (init_status == mkldnn_unimplemented)
148     {
149         return r->state = UNIMPLEMENTED, OK;
150     } else
151         SAFE(init_status, WARN);
152
153     const char *impl_str = query_impl_info(dpd);
154     if (maybe_skip(skip_impl, impl_str)) {
155         print(2, "SKIPPED: mkldnn implementation: %s\n", impl_str);
156         DNN_SAFE(mkldnn_primitive_desc_destroy(dpd), WARN);
157         return r->state = SKIPPED, OK;
158     } else {
159         print(5, "mkldnn implementation: %s\n", impl_str);
160     }
161
162     auto q = [=](mkldnn_query_t query, int index = 0) {
163         return *mkldnn_primitive_desc_query_memory_d(
164                 mkldnn_primitive_desc_query_pd(dpd, query, index));
165     };
166
167     if (p->dir == BWD_D)
168         cd.diff_src_desc = q(mkldnn_query_diff_src_pd);
169     else
170         cd.src_desc = q(mkldnn_query_src_pd);
171
172     if (p->dir & FLAG_WEI)
173         cd.diff_weights_desc = q(mkldnn_query_diff_weights_pd);
174     else
175         cd.weights_desc = q(mkldnn_query_weights_pd);
176
177     if (p->dir & FLAG_BIA) {
178         if (p->dir & FLAG_BWD)
179             cd.diff_bias_desc = q(mkldnn_query_diff_weights_pd, 1);
180         else
181             cd.bias_desc = q(mkldnn_query_weights_pd, 1);
182     }
183
184     if (p->dir & FLAG_BWD)
185         cd.diff_dst_desc = q(mkldnn_query_diff_dst_pd);
186     else
187         cd.dst_desc = q(mkldnn_query_dst_pd);
188
189     return OK;
190 }
191 int doit(const prb_t *p, res_t *r) {
192     res_t res_zero{};
193     *r = res_zero;
194     bool with_groups = 1;
195
196     prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->attr, p->mb, true);
197     swap(p_tr.ic,  p_tr.oc);
198     swap(p_tr.ih,  p_tr.oh);
199     swap(p_tr.id,  p_tr.od);
200     swap(p_tr.iw,  p_tr.ow);
201
202     mkldnn_deconvolution_desc_t cd;
203     mkldnn_primitive_desc_t dpd;
204     mkldnn_primitive_t c{};
205
206     SAFE(init_pd(p, cd, dpd, r), WARN);
207     if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
208         return OK;
209
210     auto &src_dt_d = p->dir == BWD_D ? cd.diff_src_desc : cd.src_desc;
211     auto &wei_dt_d = p->dir & FLAG_WEI ? cd.diff_weights_desc : cd.weights_desc;
212     auto &bia_dt_d = p->dir & FLAG_BWD ? cd.diff_bias_desc : cd.bias_desc;
213     auto &dst_dt_d = p->dir & FLAG_BWD ? cd.diff_dst_desc: cd.dst_desc;
214     auto wei_tr_dt_d = wei_dt_d;
215     swap(wei_tr_dt_d.dims[with_groups+0], wei_tr_dt_d.dims[with_groups+1]);
216
217     dnn_mem_t src_dt(src_dt_d, p->cfg[SRC].dt);
218     dnn_mem_t wei_dt(wei_dt_d, p->cfg[WEI].dt);
219     dnn_mem_t dst_dt(dst_dt_d, p->cfg[DST].dt);
220     dnn_mem_t *p_bia_dt = p->dir & FLAG_BIA
221         ? new dnn_mem_t(bia_dt_d, p->cfg[BIA].dt) : new dnn_mem_t();
222     dnn_mem_t &bia_dt = *p_bia_dt;
223
224     auto src_format = get_default_format(src_dt.md_.ndims, DATA);
225     auto wei_format = get_default_format(wei_dt.md_.ndims,
226         p->has_groups ? GWEI : WEI);
227
228     const auto fp = mkldnn_f32;
229
230     /* memory for ref */
231     dnn_mem_t src_fp(src_dt_d, fp, src_format);
232     dnn_mem_t wei_fp(wei_dt_d, fp, wei_format);
233     dnn_mem_t dst_fp(dst_dt_d, fp, src_format);
234     dnn_mem_t wei_tr_fp(wei_tr_dt_d, fp, wei_format);
235     dnn_mem_t *p_bia_fp = p->dir & FLAG_BIA
236         ? new dnn_mem_t(bia_dt_d, fp, mkldnn_x) : new dnn_mem_t();
237     dnn_mem_t *p_zero_fp = new dnn_mem_t();
238     dnn_mem_t &bia_fp = *p_bia_fp, &zero_fp = *p_zero_fp;
239
240     /* fill memory + reorders <-> */
241     SAFE(fill_dst(p, dst_dt, dst_fp, r), WARN);
242     SAFE(fill_wei(p, wei_dt, wei_fp, r), WARN);
243     SAFE(fill_src(p, src_dt, src_fp, r), WARN);
244
245     SAFE(transpose_data_wei(p, wei_fp, wei_tr_fp), WARN);
246     if (p->dir & FLAG_BIA)
247         SAFE(fill_bia(p, bia_dt, bia_fp, r), WARN);
248     if (p->dir & FLAG_FWD) {
249         mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {wei_dt.p_, 0},
250             {p->dir & FLAG_BIA ? bia_dt.p_ : NULL, 0}
251         };
252         const_mkldnn_primitive_t outputs[] = { dst_dt.p_ };
253         DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
254         SAFE(execute(c), WARN);
255         if (bench_mode & CORR) {
256             compute_ref_bwd_d(&p_tr, dst_fp, wei_tr_fp, bia_fp, src_fp);
257             dnn_mem_t dst(dst_dt, fp, src_format);
258             SAFE(compare_dst(p, dst, dst_fp, r, true), WARN);
259         }
260     } else if (p->dir == BWD_D) {
261         mkldnn_primitive_at_t inputs[3] = { {dst_dt.p_, 0}, {wei_dt.p_, 0}, };
262         const_mkldnn_primitive_t outputs[] = { src_dt.p_ };
263         DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
264         SAFE(execute(c), WARN);
265         if (bench_mode & CORR) {
266             compute_ref_fwd(&p_tr, dst_fp, wei_tr_fp, zero_fp, src_fp);
267             dnn_mem_t src(src_dt, fp, src_format);
268             SAFE(compare_src(p, src, src_fp, r, true), WARN);
269         }
270     } else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
271         mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {dst_dt.p_, 0}, };
272         const_mkldnn_primitive_t outputs[] = { wei_dt.p_,
273             p->dir & FLAG_BIA ? bia_dt.p_ : NULL,
274         };
275         DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
276         SAFE(execute(c), WARN);
277         if (bench_mode & CORR) {
278             compute_ref_bwd_weights(&p_tr, dst_fp, wei_tr_fp, src_fp);
279             transpose_data_wei(&p_tr, wei_tr_fp, wei_fp);
280             dnn_mem_t wei(wei_dt, fp, wei_format);
281             SAFE(compare_wei(&p_tr, wei, wei_fp, r, true), WARN);
282             if (p->dir & FLAG_BIA) {
283                 compute_ref_bwd_bias(p, bia_fp, dst_fp);
284                 dnn_mem_t bia(bia_dt, fp, mkldnn_x);
285                 SAFE(compare_bia(p, bia, bia_fp, r, true), WARN);
286             }
287         }
288     } else {
289         delete p_bia_dt;
290         delete p_bia_fp;
291         delete p_zero_fp;
292         SAFE(FAIL, CRIT);
293     }
294
295     if (bench_mode & PERF) {
296         auto &t = r->timer;
297         t.reset();
298         while (true) {
299             SAFE(execute(c), WARN);
300             t.stamp();
301             const bool stop = false
302                 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
303                 || (!fix_times_per_prb
304                         && t.total_ms() >= max_ms_per_prb
305                         && t.times() >= min_times_per_prb);
306             if (stop) break;
307         }
308     }
309
310     DNN_SAFE_V(mkldnn_primitive_destroy(c));
311     DNN_SAFE_V(mkldnn_primitive_desc_destroy(dpd));
312
313     delete p_bia_dt;
314     delete p_bia_fp;
315     delete p_zero_fp;
316
317    return OK;
318 }
319
320 }