1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
24 #include "src/common/mkldnn_thread.hpp"
26 #include "mkldnn_common.hpp"
27 #include "mkldnn_memory.hpp"
31 #include "conv/deconv.hpp"
37 inline static void swap(T &a, T &b) {
43 inline bool is_deconv_3d(const prb_t *p) {
// Decide whether the problem is handled as a 1D deconvolution.
// It is 1D when it is not 3D (per is_deconv_3d) and both the input height
// (ih) and the kernel height (kh) equal 1.
// init_pd uses this to pick ndims == 3 and the *_1d_dims descriptors.
47 inline bool is_deconv_1d(const prb_t *p) {
48 return !is_deconv_3d(p) && p->ih == 1 && p->kh == 1;
// Copy the weights in `wei` into `wei_tr` with the per-group output- and
// input-channel axes swapped. Both buffers are accessed as float arrays.
// - Source elements are located with wei_off_f(p, g, oc, ic, kd, kh, kw).
// - The destination is written densely as [g][ic/g][oc/g][kd][kh][kw].
// - The loop nest over (g, oc/g, ic/g, kd, kh, kw) runs through
//   mkldnn::impl::parallel_nd.
// doit uses this to build weights for the channel-swapped problem p_tr, on
// which the convolution reference routines are run. After the backward-weights
// reference it calls this again with &p_tr to map the result back.
// NOTE(review): assumes p->ic and p->oc are divisible by p->g. The index
// expression below relies on those integer divisions being exact.
51 inline int transpose_data_wei(const prb_t *p, dnn_mem_t &wei, dnn_mem_t &wei_tr) {
52 mkldnn::impl::parallel_nd(
53 p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw,
54 [&](int g, int oc, int ic, int kd, int kh, int kw) {
// Dense row-major offset with ic before oc (the transposed channel order).
// size_t arithmetic avoids int overflow for large tensors.
55 size_t idx = (((((size_t)g * p->ic / p->g + ic) * p->oc / p->g + oc)
56 * p->kd + kd) * p->kh + kh) * p->kw + kw;
57 ((float*)wei_tr)[idx] = ((float*)wei)[wei_off_f(p, g, oc, ic, kd, kh, kw)];
63 inline int init_pd(const prb_t *p, mkldnn_deconvolution_desc_t &cd,
64 mkldnn_primitive_desc_t &dpd, res_t *r) {
65 int ndims = is_deconv_3d(p) ? 5 : is_deconv_1d(p) ? 3 : 4;
67 mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
68 mkldnn_dims_t src_1d_dims = {p->mb, p->ic, p->iw};
69 mkldnn_dims_t src_2d_dims = {p->mb, p->ic, p->ih, p->iw};
70 mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
71 mkldnn_dims_t wei_1d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kw};
72 mkldnn_dims_t wei_2d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
73 mkldnn_dims_t wei_3d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw};
74 mkldnn_dims_t bia_dims = {p->oc};
75 mkldnn_dims_t dst_1d_dims = {p->mb, p->oc, p->ow};
76 mkldnn_dims_t dst_2d_dims = {p->mb, p->oc, p->oh, p->ow};
77 mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
78 DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
79 is_deconv_3d(p) ? src_3d_dims : is_deconv_1d(p) ? src_1d_dims : src_2d_dims,
80 p->cfg[SRC].dt, mkldnn_any),
82 DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + p->has_groups,
84 ? &wei_3d_dims[!p->has_groups]
86 ? &wei_1d_dims[!p->has_groups]
87 : &wei_2d_dims[!p->has_groups],
88 p->cfg[WEI].dt, mkldnn_any), WARN);
89 DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt, mkldnn_any), WARN);
90 DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
91 is_deconv_3d(p) ? dst_3d_dims : is_deconv_1d(p) ? dst_1d_dims : dst_2d_dims,
92 p->cfg[DST].dt, mkldnn_any), WARN);
94 ptrdiff_t strides_nd[] = {p->sd, p->sh, p->sw};
95 ptrdiff_t dilates_nd[] = {p->dd, p->dh, p->dw};
96 ptrdiff_t padding_nd[] = {p->pd, p->ph, p->pw};
98 auto bph = [&](int ih, int oh, int kh, int sh, int ph, int dh) {
99 return (oh - 1) * sh - ih + ((kh - 1) * (dh + 1) + 1) - ph;
102 ptrdiff_t padding_r_nd[] = {
103 bph(p->od, p->id, p->kd, p->sd, p->pd, p->dd),
104 bph(p->oh, p->ih, p->kh, p->sh, p->ph, p->dh),
105 bph(p->ow, p->iw, p->kw, p->sw, p->pw, p->dw)};
107 ptrdiff_t *strides = strides_nd + (5 - ndims);
108 ptrdiff_t *dilates = dilates_nd + (5 - ndims);
109 ptrdiff_t *padding = padding_nd + (5 - ndims);
110 ptrdiff_t *padding_r = padding_r_nd + (5 - ndims);
112 mkldnn_alg_kind_t alg = mkldnn_deconvolution_direct;
113 if (p->alg == WINO) alg = mkldnn_deconvolution_winograd;
116 case FWD_D: case FWD_B:
117 DNN_SAFE(mkldnn_dilated_deconvolution_forward_desc_init(&cd,
118 mkldnn_forward_inference, alg, &src_d, &wei_d,
119 p->dir == FWD_D ? NULL : &bia_d, &dst_d, strides,
120 dilates, padding, padding_r, mkldnn_padding_zero), WARN);
123 DNN_SAFE(mkldnn_dilated_deconvolution_backward_data_desc_init(&cd, alg,
124 &src_d, &wei_d, &dst_d, strides, dilates, padding,
125 padding_r, mkldnn_padding_zero), WARN);
127 case BWD_W: case BWD_WB:
128 DNN_SAFE(mkldnn_dilated_deconvolution_backward_weights_desc_init(&cd,
129 alg, &src_d, &wei_d, p->dir == BWD_W ? NULL : &bia_d,
130 &dst_d, strides, dilates, padding, padding_r,
131 mkldnn_padding_zero), WARN);
133 default: DNN_SAFE(mkldnn_invalid_arguments, CRIT);
136 DNN_SAFE(cd.accum_data_type == p->cfg[ACC].dt
137 ? mkldnn_success : mkldnn_unimplemented, CRIT);
139 auto mkldnn_attr = create_mkldnn_attr(p->attr, p->oc, p->scales);
141 mkldnn_status_t init_status = mkldnn_success;
142 init_status = mkldnn_primitive_desc_create_v2(&dpd, &cd, mkldnn_attr,
145 mkldnn_primitive_attr_destroy(mkldnn_attr);
147 if (init_status == mkldnn_unimplemented)
149 return r->state = UNIMPLEMENTED, OK;
151 SAFE(init_status, WARN);
153 const char *impl_str = query_impl_info(dpd);
154 if (maybe_skip(skip_impl, impl_str)) {
155 print(2, "SKIPPED: mkldnn implementation: %s\n", impl_str);
156 DNN_SAFE(mkldnn_primitive_desc_destroy(dpd), WARN);
157 return r->state = SKIPPED, OK;
159 print(5, "mkldnn implementation: %s\n", impl_str);
162 auto q = [=](mkldnn_query_t query, int index = 0) {
163 return *mkldnn_primitive_desc_query_memory_d(
164 mkldnn_primitive_desc_query_pd(dpd, query, index));
168 cd.diff_src_desc = q(mkldnn_query_diff_src_pd);
170 cd.src_desc = q(mkldnn_query_src_pd);
172 if (p->dir & FLAG_WEI)
173 cd.diff_weights_desc = q(mkldnn_query_diff_weights_pd);
175 cd.weights_desc = q(mkldnn_query_weights_pd);
177 if (p->dir & FLAG_BIA) {
178 if (p->dir & FLAG_BWD)
179 cd.diff_bias_desc = q(mkldnn_query_diff_weights_pd, 1);
181 cd.bias_desc = q(mkldnn_query_weights_pd, 1);
184 if (p->dir & FLAG_BWD)
185 cd.diff_dst_desc = q(mkldnn_query_diff_dst_pd);
187 cd.dst_desc = q(mkldnn_query_dst_pd);
191 int doit(const prb_t *p, res_t *r) {
194 bool with_groups = 1;
196 prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->attr, p->mb, true);
197 swap(p_tr.ic, p_tr.oc);
198 swap(p_tr.ih, p_tr.oh);
199 swap(p_tr.id, p_tr.od);
200 swap(p_tr.iw, p_tr.ow);
202 mkldnn_deconvolution_desc_t cd;
203 mkldnn_primitive_desc_t dpd;
204 mkldnn_primitive_t c{};
206 SAFE(init_pd(p, cd, dpd, r), WARN);
207 if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
210 auto &src_dt_d = p->dir == BWD_D ? cd.diff_src_desc : cd.src_desc;
211 auto &wei_dt_d = p->dir & FLAG_WEI ? cd.diff_weights_desc : cd.weights_desc;
212 auto &bia_dt_d = p->dir & FLAG_BWD ? cd.diff_bias_desc : cd.bias_desc;
213 auto &dst_dt_d = p->dir & FLAG_BWD ? cd.diff_dst_desc: cd.dst_desc;
214 auto wei_tr_dt_d = wei_dt_d;
215 swap(wei_tr_dt_d.dims[with_groups+0], wei_tr_dt_d.dims[with_groups+1]);
217 dnn_mem_t src_dt(src_dt_d, p->cfg[SRC].dt);
218 dnn_mem_t wei_dt(wei_dt_d, p->cfg[WEI].dt);
219 dnn_mem_t dst_dt(dst_dt_d, p->cfg[DST].dt);
220 dnn_mem_t *p_bia_dt = p->dir & FLAG_BIA
221 ? new dnn_mem_t(bia_dt_d, p->cfg[BIA].dt) : new dnn_mem_t();
222 dnn_mem_t &bia_dt = *p_bia_dt;
224 auto src_format = get_default_format(src_dt.md_.ndims, DATA);
225 auto wei_format = get_default_format(wei_dt.md_.ndims,
226 p->has_groups ? GWEI : WEI);
228 const auto fp = mkldnn_f32;
231 dnn_mem_t src_fp(src_dt_d, fp, src_format);
232 dnn_mem_t wei_fp(wei_dt_d, fp, wei_format);
233 dnn_mem_t dst_fp(dst_dt_d, fp, src_format);
234 dnn_mem_t wei_tr_fp(wei_tr_dt_d, fp, wei_format);
235 dnn_mem_t *p_bia_fp = p->dir & FLAG_BIA
236 ? new dnn_mem_t(bia_dt_d, fp, mkldnn_x) : new dnn_mem_t();
237 dnn_mem_t *p_zero_fp = new dnn_mem_t();
238 dnn_mem_t &bia_fp = *p_bia_fp, &zero_fp = *p_zero_fp;
240 /* fill memory + reorders <-> */
241 SAFE(fill_dst(p, dst_dt, dst_fp, r), WARN);
242 SAFE(fill_wei(p, wei_dt, wei_fp, r), WARN);
243 SAFE(fill_src(p, src_dt, src_fp, r), WARN);
245 SAFE(transpose_data_wei(p, wei_fp, wei_tr_fp), WARN);
246 if (p->dir & FLAG_BIA)
247 SAFE(fill_bia(p, bia_dt, bia_fp, r), WARN);
248 if (p->dir & FLAG_FWD) {
249 mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {wei_dt.p_, 0},
250 {p->dir & FLAG_BIA ? bia_dt.p_ : NULL, 0}
252 const_mkldnn_primitive_t outputs[] = { dst_dt.p_ };
253 DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
254 SAFE(execute(c), WARN);
255 if (bench_mode & CORR) {
256 compute_ref_bwd_d(&p_tr, dst_fp, wei_tr_fp, bia_fp, src_fp);
257 dnn_mem_t dst(dst_dt, fp, src_format);
258 SAFE(compare_dst(p, dst, dst_fp, r, true), WARN);
260 } else if (p->dir == BWD_D) {
261 mkldnn_primitive_at_t inputs[3] = { {dst_dt.p_, 0}, {wei_dt.p_, 0}, };
262 const_mkldnn_primitive_t outputs[] = { src_dt.p_ };
263 DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
264 SAFE(execute(c), WARN);
265 if (bench_mode & CORR) {
266 compute_ref_fwd(&p_tr, dst_fp, wei_tr_fp, zero_fp, src_fp);
267 dnn_mem_t src(src_dt, fp, src_format);
268 SAFE(compare_src(p, src, src_fp, r, true), WARN);
270 } else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
271 mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {dst_dt.p_, 0}, };
272 const_mkldnn_primitive_t outputs[] = { wei_dt.p_,
273 p->dir & FLAG_BIA ? bia_dt.p_ : NULL,
275 DNN_SAFE(mkldnn_primitive_create(&c, dpd, inputs, outputs), WARN);
276 SAFE(execute(c), WARN);
277 if (bench_mode & CORR) {
278 compute_ref_bwd_weights(&p_tr, dst_fp, wei_tr_fp, src_fp);
279 transpose_data_wei(&p_tr, wei_tr_fp, wei_fp);
280 dnn_mem_t wei(wei_dt, fp, wei_format);
281 SAFE(compare_wei(&p_tr, wei, wei_fp, r, true), WARN);
282 if (p->dir & FLAG_BIA) {
283 compute_ref_bwd_bias(p, bia_fp, dst_fp);
284 dnn_mem_t bia(bia_dt, fp, mkldnn_x);
285 SAFE(compare_bia(p, bia, bia_fp, r, true), WARN);
295 if (bench_mode & PERF) {
299 SAFE(execute(c), WARN);
301 const bool stop = false
302 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
303 || (!fix_times_per_prb
304 && t.total_ms() >= max_ms_per_prb
305 && t.times() >= min_times_per_prb);
310 DNN_SAFE_V(mkldnn_primitive_destroy(c));
311 DNN_SAFE_V(mkldnn_primitive_desc_destroy(dpd));