1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
24 #include "mkldnn_common.hpp"
25 #include "mkldnn_memory.hpp"
29 #include "conv/deconv.hpp"
/* Exchanges the values held by `a` and `b` in place.
 * Used in this file to trade input/output sizes and channel counts when
 * building the transposed problem, and to swap two weight dimensions. */
inline static void swap(int &a, int &b)
{
    const int held = a;
    a = b;
    b = held;
}
40 inline bool is_deconv_3d(const prb_t *p)
42 return (p->id > 1) ? 1 : 0;
45 inline int transpose_data_wei(const prb_t *p, dnn_mem_t &wei, dnn_mem_t &wei_tr) {
46 # pragma omp parallel for collapse(5)
47 for (int g = 0; g < p->g; ++g)
48 for (int oc = 0; oc < p->oc / p->g; ++oc)
49 for (int ic = 0; ic < p->ic / p->g; ++ic)
50 for (int kd = 0; kd < p->kd; ++kd)
51 for (int kh = 0; kh < p->kh; ++kh)
52 for (int kw = 0; kw < p->kw; ++kw)
54 size_t idx = (((((size_t)g * p->ic / p->g + ic) * p->oc / p->g + oc)
55 * p->kd + kd) * p->kh + kh) * p->kw + kw;
56 ((float*)wei_tr)[idx] = ((float*)wei)[wei_off_f(p, g, oc, ic, kd, kh, kw)];
62 inline int init_pd(const prb_t *p, mkldnn_deconvolution_desc_t &cd,
63 mkldnn_primitive_desc_t &dpd, res_t *r) {
64 int ndims = is_deconv_3d(p) ? 5 : 4;
66 mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
67 mkldnn_dims_t src_dims = {p->mb, p->ic, p->ih, p->iw};
68 mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
69 mkldnn_dims_t wei_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
70 mkldnn_dims_t wei_3d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw};
71 mkldnn_dims_t bia_dims = {p->oc};
72 mkldnn_dims_t dst_dims = {p->mb, p->oc, p->oh, p->ow};
73 mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
75 assert(p->cfg[SRC].dt == p->cfg[DST].dt);
77 DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
78 is_deconv_3d(p) ? src_3d_dims : src_dims, p->cfg[SRC].dt, mkldnn_any), WARN);
79 DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + 1,
80 is_deconv_3d(p) ? wei_3d_dims : wei_dims, p->cfg[WEI].dt, mkldnn_any), WARN);
81 DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt, mkldnn_any), WARN);
82 DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
83 is_deconv_3d(p) ? dst_3d_dims : dst_dims, p->cfg[DST].dt, mkldnn_any), WARN);
84 int strides_2d[] = {p->sh, p->sw};
85 int padding_2d[] = {p->ph, p->pw};
86 int strides_3d[] = {p->sd, p->sh, p->sw};
87 int padding_3d[] = {p->pd, p->ph, p->pw};
89 auto bph = [&](int ih, int oh, int kh, int sh, int ph, int dh) {
90 return (oh - 1) * sh - ih + ((kh - 1) * (dh + 1) + 1) - ph;
92 int padding_r_3d[] = {
93 bph(p->od, p->id, p->kd, p->sd, p->pd, p->dd),
94 bph(p->oh, p->ih, p->kh, p->sh, p->ph, p->dh),
95 bph(p->ow, p->iw, p->kw, p->sw, p->pw, p->dw)};
96 int padding_r_2d[] = {
97 bph(p->oh, p->ih, p->kh, p->sh, p->ph, p->dh),
98 bph(p->ow, p->iw, p->kw, p->sw, p->pw, p->dw)};
100 int *strides = is_deconv_3d(p) ? strides_3d : strides_2d;
101 int *padding = is_deconv_3d(p) ? padding_3d : padding_2d;
102 int *padding_r = is_deconv_3d(p) ? padding_r_3d : padding_r_2d;
103 mkldnn_alg_kind_t alg = mkldnn_deconvolution_direct;
104 if (p->alg == WINO) alg = mkldnn_deconvolution_winograd;
107 case FWD_D: case FWD_B:
108 DNN_SAFE(mkldnn_deconvolution_forward_desc_init(&cd,
109 mkldnn_forward_inference, alg, &src_d, &wei_d,
110 p->dir == FWD_D ? NULL : &bia_d, &dst_d, strides,
111 padding, padding_r, mkldnn_padding_zero), WARN);
114 DNN_SAFE(mkldnn_deconvolution_backward_data_desc_init(&cd, alg,
115 &src_d, &wei_d, &dst_d, strides, padding,
116 padding_r, mkldnn_padding_zero), WARN);
118 case BWD_W: case BWD_WB:
119 DNN_SAFE(mkldnn_deconvolution_backward_weights_desc_init(&cd,
120 alg, &src_d, &wei_d, p->dir == BWD_W ? NULL : &bia_d,
121 &dst_d, strides, padding, padding_r,
122 mkldnn_padding_zero), WARN);
124 default: DNN_SAFE(mkldnn_invalid_arguments, CRIT);
127 DNN_SAFE(cd.accum_data_type == p->cfg[ACC].dt
128 ? mkldnn_success : mkldnn_unimplemented, CRIT);
130 auto mkldnn_attr = create_mkldnn_attr(p->attr, p->oc, p->scales);
132 mkldnn_status_t init_status = mkldnn_success;
133 init_status = mkldnn_primitive_desc_create_v2(&dpd, &cd, mkldnn_attr,
136 mkldnn_primitive_attr_destroy(mkldnn_attr);
138 if (init_status == mkldnn_unimplemented)
140 return r->state = UNIMPLEMENTED, OK;
142 SAFE(init_status, WARN);
144 const char *impl_str = query_impl_info(dpd);
145 if (maybe_skip(skip_impl, impl_str)) {
146 print(2, "SKIPPED: mkldnn implementation: %s\n", impl_str);
147 DNN_SAFE(mkldnn_primitive_desc_destroy(dpd), WARN);
148 return r->state = SKIPPED, OK;
150 print(5, "mkldnn implementation: %s\n", impl_str);
153 auto q = [=](mkldnn_query_t query, int index = 0) {
154 return *mkldnn_primitive_desc_query_memory_d(
155 mkldnn_primitive_desc_query_pd(dpd, query, index));
159 cd.diff_src_desc = q(mkldnn_query_diff_src_pd);
161 cd.src_desc = q(mkldnn_query_src_pd);
163 if (p->dir & FLAG_WEI)
164 cd.diff_weights_desc = q(mkldnn_query_diff_weights_pd);
166 cd.weights_desc = q(mkldnn_query_weights_pd);
168 if (p->dir & FLAG_BIA) {
169 if (p->dir & FLAG_BWD)
170 cd.diff_bias_desc = q(mkldnn_query_diff_weights_pd, 1);
172 cd.bias_desc = q(mkldnn_query_weights_pd, 1);
175 if (p->dir & FLAG_BWD)
176 cd.diff_dst_desc = q(mkldnn_query_diff_dst_pd);
178 cd.dst_desc = q(mkldnn_query_dst_pd);
182 int doit(const prb_t *p, res_t *r) {
185 bool with_groups = 1;
187 prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->merge, p->attr, p->mb);
188 swap(p_tr.ic, p_tr.oc);
189 swap(p_tr.ih, p_tr.oh);
190 swap(p_tr.id, p_tr.od);
191 swap(p_tr.iw, p_tr.ow);
193 mkldnn_deconvolution_desc_t cd;
194 mkldnn_primitive_desc_t dpd;
195 mkldnn_primitive_t c{};
197 SAFE(init_pd(p, cd, dpd, r), WARN);
198 if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
201 auto &src_dt_d = p->dir == BWD_D ? cd.diff_src_desc : cd.src_desc;
202 auto &wei_dt_d = p->dir & FLAG_WEI ? cd.diff_weights_desc : cd.weights_desc;
203 auto &bia_dt_d = p->dir & FLAG_BWD ? cd.diff_bias_desc : cd.bias_desc;
204 auto &dst_dt_d = p->dir & FLAG_BWD ? cd.diff_dst_desc: cd.dst_desc;
205 auto wei_tr_dt_d = wei_dt_d;
206 swap(wei_tr_dt_d.dims[with_groups+0], wei_tr_dt_d.dims[with_groups+1]);
208 dnn_mem_t src_dt(src_dt_d, p->cfg[SRC].dt);
209 dnn_mem_t wei_dt(wei_dt_d, p->cfg[WEI].dt);
210 dnn_mem_t dst_dt(dst_dt_d, p->cfg[DST].dt);
211 dnn_mem_t *p_bia_dt = p->dir & FLAG_BIA
212 ? new dnn_mem_t(bia_dt_d, p->cfg[BIA].dt) : new dnn_mem_t();
213 dnn_mem_t &bia_dt = *p_bia_dt;
215 auto src_format = is_deconv_3d(p) ? mkldnn_ncdhw : mkldnn_nchw;
216 auto wei_format = is_deconv_3d(p) ? mkldnn_goidhw : mkldnn_goihw;
218 const auto fp = mkldnn_f32;
221 dnn_mem_t src_fp(src_dt_d, fp, src_format);
222 dnn_mem_t wei_fp(wei_dt_d, fp, wei_format);
223 dnn_mem_t dst_fp(dst_dt_d, fp, src_format);
224 dnn_mem_t wei_tr_fp(wei_tr_dt_d, fp, wei_format);
225 dnn_mem_t *p_bia_fp = p->dir & FLAG_BIA
226 ? new dnn_mem_t(bia_dt_d, fp, mkldnn_x) : new dnn_mem_t();
227 dnn_mem_t *p_zero_fp = new dnn_mem_t();
228 dnn_mem_t &bia_fp = *p_bia_fp, &zero_fp = *p_zero_fp;
230 /* fill memory + reorders <-> */
231 SAFE(fill_src(&p_tr, dst_dt, dst_fp, r), WARN);
232 SAFE(fill_wei(p, wei_dt, wei_fp, r), WARN);
233 SAFE(fill_dst(&p_tr, src_dt, src_fp, r), WARN);
235 SAFE(transpose_data_wei(p, wei_fp, wei_tr_fp), WARN);
236 if (p->dir & FLAG_BIA)
237 SAFE(fill_bia(p, bia_dt, bia_fp, r), WARN);
238 if (p->dir & FLAG_FWD) {
239 mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {wei_dt.p_, 0},
240 {p->dir & FLAG_BIA ? bia_dt.p_ : NULL, 0}
242 const_mkldnn_primitive_t outputs[] = { dst_dt.p_ };
243 DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
244 SAFE(execute(c), WARN);
245 if (bench_mode & CORR) {
246 compute_ref_bwd_d(&p_tr, dst_fp, wei_tr_fp, src_fp);
247 dnn_mem_t dst(dst_dt, fp, src_format);
248 SAFE(dst.reorder(dst_dt), WARN);
249 if (p->dir & FLAG_BIA) {
250 compute_bias_fwd(p, bia_fp, dst_fp);
252 SAFE(compare_dst(p, dst, dst_fp, r, true), WARN);
254 } else if (p->dir == BWD_D) {
255 mkldnn_primitive_at_t inputs[3] = { {dst_dt.p_, 0}, {wei_dt.p_, 0}, };
256 const_mkldnn_primitive_t outputs[] = { src_dt.p_ };
257 DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
258 SAFE(execute(c), WARN);
259 if (bench_mode & CORR) {
260 compute_ref_fwd(&p_tr, dst_fp, wei_tr_fp, zero_fp, src_fp);
261 dnn_mem_t src(src_dt, fp, src_format);
262 SAFE(src.reorder(src_dt), WARN);
263 SAFE(compare_src(p, src, src_fp, r, true), WARN);
265 } else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
266 mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {dst_dt.p_, 0}, };
267 const_mkldnn_primitive_t outputs[] = { wei_dt.p_,
268 p->dir & FLAG_BIA ? bia_dt.p_ : NULL,
270 DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
271 SAFE(execute(c), WARN);
272 if (bench_mode & CORR) {
273 compute_ref_bwd_weights(&p_tr, dst_fp, wei_tr_fp, src_fp);
274 transpose_data_wei(&p_tr, wei_tr_fp, wei_fp);
275 dnn_mem_t wei(wei_dt, fp, wei_format);
276 SAFE(wei.reorder(wei_dt), WARN);
277 SAFE(compare_wei(&p_tr, wei, wei_fp, r, true), WARN);
278 if (p->dir & FLAG_BIA) {
279 compute_ref_bwd_bias(p, bia_fp, dst_fp);
280 dnn_mem_t bia(bia_dt, fp, mkldnn_x);
281 SAFE(bia.reorder(bia_dt), WARN);
282 SAFE(compare_bia(p, bia, bia_fp, r, true), WARN);
292 if (bench_mode & PERF) {
296 SAFE(execute(c), WARN);
298 const bool stop = false
299 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
300 || (!fix_times_per_prb
301 && t.total_ms() >= max_ms_per_prb
302 && t.times() >= min_times_per_prb);