Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / conv / deconv.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <float.h>
20 #include <math.h>
21
22 #include "mkldnn.h"
23
24 #include "mkldnn_common.hpp"
25 #include "mkldnn_memory.hpp"
26
27 #include "norm.hpp"
28
29 #include "conv/deconv.hpp"
30 using namespace conv;
31
32 namespace deconv {
33
34 inline static void swap(int &a, int &b)
35 {
36     int temp = a;
37     a = b;
38     b = temp;
39 }
40 inline bool is_deconv_3d(const prb_t *p)
41 {
42     return (p->id > 1) ? 1 : 0;
43 }
44
45 inline int transpose_data_wei(const prb_t *p, dnn_mem_t &wei, dnn_mem_t &wei_tr) {
46 #   pragma omp parallel for collapse(5)
47     for (int g = 0; g < p->g; ++g)
48     for (int oc = 0; oc < p->oc / p->g; ++oc)
49     for (int ic = 0; ic < p->ic / p->g; ++ic)
50     for (int kd = 0; kd < p->kd; ++kd)
51     for (int kh = 0; kh < p->kh; ++kh)
52     for (int kw = 0; kw < p->kw; ++kw)
53     {
54         size_t idx = (((((size_t)g * p->ic / p->g + ic) * p->oc / p->g + oc)
55         * p->kd + kd) * p->kh + kh) * p->kw + kw;
56         ((float*)wei_tr)[idx] = ((float*)wei)[wei_off_f(p, g, oc, ic, kd, kh, kw)];
57     }
58
59     return OK;
60 }
61
62 inline int init_pd(const prb_t *p, mkldnn_deconvolution_desc_t &cd,
63         mkldnn_primitive_desc_t &dpd, res_t *r) {
64     int ndims = is_deconv_3d(p) ? 5 : 4;
65
66     mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
67     mkldnn_dims_t src_dims = {p->mb, p->ic, p->ih, p->iw};
68     mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
69     mkldnn_dims_t wei_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
70     mkldnn_dims_t wei_3d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw};
71     mkldnn_dims_t bia_dims = {p->oc};
72     mkldnn_dims_t dst_dims = {p->mb, p->oc, p->oh, p->ow};
73     mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
74
75     assert(p->cfg[SRC].dt == p->cfg[DST].dt);
76
77     DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
78         is_deconv_3d(p) ? src_3d_dims : src_dims, p->cfg[SRC].dt, mkldnn_any), WARN);
79     DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + 1,
80         is_deconv_3d(p) ? wei_3d_dims : wei_dims, p->cfg[WEI].dt, mkldnn_any), WARN);
81     DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt, mkldnn_any), WARN);
82     DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
83         is_deconv_3d(p) ? dst_3d_dims : dst_dims, p->cfg[DST].dt, mkldnn_any), WARN);
84     int strides_2d[] = {p->sh, p->sw};
85     int padding_2d[] = {p->ph, p->pw};
86     int strides_3d[] = {p->sd, p->sh, p->sw};
87     int padding_3d[] = {p->pd, p->ph, p->pw};
88
89     auto bph = [&](int ih, int oh, int kh, int sh, int ph, int dh) {
90         return (oh - 1) * sh - ih + ((kh - 1) * (dh + 1) + 1) - ph;
91     };
92     int padding_r_3d[] = {
93         bph(p->od, p->id, p->kd, p->sd, p->pd, p->dd),
94         bph(p->oh, p->ih, p->kh, p->sh, p->ph, p->dh),
95         bph(p->ow, p->iw, p->kw, p->sw, p->pw, p->dw)};
96     int padding_r_2d[] = {
97         bph(p->oh, p->ih, p->kh, p->sh, p->ph, p->dh),
98         bph(p->ow, p->iw, p->kw, p->sw, p->pw, p->dw)};
99
100     int *strides = is_deconv_3d(p) ? strides_3d : strides_2d;
101     int *padding = is_deconv_3d(p) ? padding_3d : padding_2d;
102     int *padding_r = is_deconv_3d(p) ? padding_r_3d : padding_r_2d;
103     mkldnn_alg_kind_t alg = mkldnn_deconvolution_direct;
104     if (p->alg == WINO) alg = mkldnn_deconvolution_winograd;
105
106     switch (p->dir) {
107     case FWD_D: case FWD_B:
108         DNN_SAFE(mkldnn_deconvolution_forward_desc_init(&cd,
109                     mkldnn_forward_inference, alg, &src_d, &wei_d,
110                     p->dir == FWD_D ? NULL : &bia_d, &dst_d, strides,
111                     padding, padding_r, mkldnn_padding_zero), WARN);
112         break;
113     case BWD_D:
114         DNN_SAFE(mkldnn_deconvolution_backward_data_desc_init(&cd, alg,
115                     &src_d, &wei_d, &dst_d, strides, padding,
116                     padding_r, mkldnn_padding_zero), WARN);
117         break;
118     case BWD_W: case BWD_WB:
119         DNN_SAFE(mkldnn_deconvolution_backward_weights_desc_init(&cd,
120                     alg, &src_d, &wei_d, p->dir == BWD_W ? NULL : &bia_d,
121                     &dst_d, strides,  padding, padding_r,
122                     mkldnn_padding_zero), WARN);
123         break;
124     default: DNN_SAFE(mkldnn_invalid_arguments, CRIT);
125     }
126
127     DNN_SAFE(cd.accum_data_type == p->cfg[ACC].dt
128             ? mkldnn_success : mkldnn_unimplemented, CRIT);
129
130     auto mkldnn_attr = create_mkldnn_attr(p->attr, p->oc, p->scales);
131
132     mkldnn_status_t init_status = mkldnn_success;
133     init_status = mkldnn_primitive_desc_create_v2(&dpd, &cd, mkldnn_attr,
134             engine, NULL);
135
136     mkldnn_primitive_attr_destroy(mkldnn_attr);
137
138     if (init_status == mkldnn_unimplemented)
139     {
140         return r->state = UNIMPLEMENTED, OK;
141     } else
142         SAFE(init_status, WARN);
143
144     const char *impl_str = query_impl_info(dpd);
145     if (maybe_skip(skip_impl, impl_str)) {
146         print(2, "SKIPPED: mkldnn implementation: %s\n", impl_str);
147         DNN_SAFE(mkldnn_primitive_desc_destroy(dpd), WARN);
148         return r->state = SKIPPED, OK;
149     } else {
150         print(5, "mkldnn implementation: %s\n", impl_str);
151     }
152
153     auto q = [=](mkldnn_query_t query, int index = 0) {
154         return *mkldnn_primitive_desc_query_memory_d(
155                 mkldnn_primitive_desc_query_pd(dpd, query, index));
156     };
157
158     if (p->dir == BWD_D)
159         cd.diff_src_desc = q(mkldnn_query_diff_src_pd);
160     else
161         cd.src_desc = q(mkldnn_query_src_pd);
162
163     if (p->dir & FLAG_WEI)
164         cd.diff_weights_desc = q(mkldnn_query_diff_weights_pd);
165     else
166         cd.weights_desc = q(mkldnn_query_weights_pd);
167
168     if (p->dir & FLAG_BIA) {
169         if (p->dir & FLAG_BWD)
170             cd.diff_bias_desc = q(mkldnn_query_diff_weights_pd, 1);
171         else
172             cd.bias_desc = q(mkldnn_query_weights_pd, 1);
173     }
174
175     if (p->dir & FLAG_BWD)
176         cd.diff_dst_desc = q(mkldnn_query_diff_dst_pd);
177     else
178         cd.dst_desc = q(mkldnn_query_dst_pd);
179
180     return OK;
181 }
182 int doit(const prb_t *p, res_t *r) {
183     res_t res_zero{};
184     *r = res_zero;
185     bool with_groups = 1;
186
187     prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->merge, p->attr, p->mb);
188     swap(p_tr.ic,  p_tr.oc);
189     swap(p_tr.ih,  p_tr.oh);
190     swap(p_tr.id,  p_tr.od);
191     swap(p_tr.iw,  p_tr.ow);
192
193     mkldnn_deconvolution_desc_t cd;
194     mkldnn_primitive_desc_t dpd;
195     mkldnn_primitive_t c{};
196
197     SAFE(init_pd(p, cd, dpd, r), WARN);
198     if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
199         return OK;
200
201     auto &src_dt_d = p->dir == BWD_D ? cd.diff_src_desc : cd.src_desc;
202     auto &wei_dt_d = p->dir & FLAG_WEI ? cd.diff_weights_desc : cd.weights_desc;
203     auto &bia_dt_d = p->dir & FLAG_BWD ? cd.diff_bias_desc : cd.bias_desc;
204     auto &dst_dt_d = p->dir & FLAG_BWD ? cd.diff_dst_desc: cd.dst_desc;
205     auto wei_tr_dt_d = wei_dt_d;
206     swap(wei_tr_dt_d.dims[with_groups+0], wei_tr_dt_d.dims[with_groups+1]);
207
208     dnn_mem_t src_dt(src_dt_d, p->cfg[SRC].dt);
209     dnn_mem_t wei_dt(wei_dt_d, p->cfg[WEI].dt);
210     dnn_mem_t dst_dt(dst_dt_d, p->cfg[DST].dt);
211     dnn_mem_t *p_bia_dt = p->dir & FLAG_BIA
212         ? new dnn_mem_t(bia_dt_d, p->cfg[BIA].dt) : new dnn_mem_t();
213     dnn_mem_t &bia_dt = *p_bia_dt;
214
215     auto src_format = is_deconv_3d(p) ? mkldnn_ncdhw : mkldnn_nchw;
216     auto wei_format = is_deconv_3d(p) ? mkldnn_goidhw : mkldnn_goihw;
217
218     const auto fp = mkldnn_f32;
219
220     /* memory for ref */
221     dnn_mem_t src_fp(src_dt_d, fp, src_format);
222     dnn_mem_t wei_fp(wei_dt_d, fp, wei_format);
223     dnn_mem_t dst_fp(dst_dt_d, fp, src_format);
224     dnn_mem_t wei_tr_fp(wei_tr_dt_d, fp, wei_format);
225     dnn_mem_t *p_bia_fp = p->dir & FLAG_BIA
226         ? new dnn_mem_t(bia_dt_d, fp, mkldnn_x) : new dnn_mem_t();
227     dnn_mem_t *p_zero_fp = new dnn_mem_t();
228     dnn_mem_t &bia_fp = *p_bia_fp, &zero_fp = *p_zero_fp;
229
230     /* fill memory + reorders <-> */
231     SAFE(fill_src(&p_tr, dst_dt, dst_fp, r), WARN);
232     SAFE(fill_wei(p, wei_dt, wei_fp, r), WARN);
233     SAFE(fill_dst(&p_tr, src_dt, src_fp, r), WARN);
234
235     SAFE(transpose_data_wei(p, wei_fp, wei_tr_fp), WARN);
236     if (p->dir & FLAG_BIA)
237         SAFE(fill_bia(p, bia_dt, bia_fp, r), WARN);
238     if (p->dir & FLAG_FWD) {
239         mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {wei_dt.p_, 0},
240             {p->dir & FLAG_BIA ? bia_dt.p_ : NULL, 0}
241         };
242         const_mkldnn_primitive_t outputs[] = { dst_dt.p_ };
243         DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
244         SAFE(execute(c), WARN);
245         if (bench_mode & CORR) {
246             compute_ref_bwd_d(&p_tr, dst_fp, wei_tr_fp, src_fp);
247             dnn_mem_t dst(dst_dt, fp, src_format);
248             SAFE(dst.reorder(dst_dt), WARN);
249             if (p->dir & FLAG_BIA) {
250                 compute_bias_fwd(p, bia_fp, dst_fp);
251             }
252            SAFE(compare_dst(p, dst, dst_fp, r, true), WARN);
253         }
254     } else if (p->dir == BWD_D) {
255         mkldnn_primitive_at_t inputs[3] = { {dst_dt.p_, 0}, {wei_dt.p_, 0}, };
256         const_mkldnn_primitive_t outputs[] = { src_dt.p_ };
257         DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
258         SAFE(execute(c), WARN);
259         if (bench_mode & CORR) {
260             compute_ref_fwd(&p_tr, dst_fp, wei_tr_fp, zero_fp, src_fp);
261             dnn_mem_t src(src_dt, fp, src_format);
262             SAFE(src.reorder(src_dt), WARN);
263             SAFE(compare_src(p, src, src_fp, r, true), WARN);
264         }
265     } else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
266         mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {dst_dt.p_, 0}, };
267         const_mkldnn_primitive_t outputs[] = { wei_dt.p_,
268             p->dir & FLAG_BIA ? bia_dt.p_ : NULL,
269         };
270         DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
271         SAFE(execute(c), WARN);
272         if (bench_mode & CORR) {
273             compute_ref_bwd_weights(&p_tr, dst_fp, wei_tr_fp, src_fp);
274             transpose_data_wei(&p_tr, wei_tr_fp, wei_fp);
275             dnn_mem_t wei(wei_dt, fp, wei_format);
276             SAFE(wei.reorder(wei_dt), WARN);
277             SAFE(compare_wei(&p_tr, wei, wei_fp, r, true), WARN);
278             if (p->dir & FLAG_BIA) {
279                 compute_ref_bwd_bias(p, bia_fp, dst_fp);
280                 dnn_mem_t bia(bia_dt, fp, mkldnn_x);
281                 SAFE(bia.reorder(bia_dt), WARN);
282                 SAFE(compare_bia(p, bia, bia_fp, r, true), WARN);
283             }
284         }
285     } else {
286         delete p_bia_dt;
287         delete p_bia_fp;
288         delete p_zero_fp;
289         SAFE(FAIL, CRIT);
290     }
291
292     if (bench_mode & PERF) {
293         auto &t = r->timer;
294         t.reset();
295         while (true) {
296             SAFE(execute(c), WARN);
297             t.stamp();
298             const bool stop = false
299                 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
300                 || (!fix_times_per_prb
301                         && t.total_ms() >= max_ms_per_prb
302                         && t.times() >= min_times_per_prb);
303             if (stop) break;
304         }
305     }
306
307     delete p_bia_dt;
308     delete p_bia_fp;
309     delete p_zero_fp;
310
311    return OK;
312 }
313
314 }