namespace deconv {
// Exchanges the values of `a` and `b` through a temporary copy.
// The removed lines are the int-only overload; the added lines make the
// same helper a template over an arbitrary type T.
-inline static void swap(int &a, int &b)
-{
- int temp = a;
+template <typename T>
+inline static void swap(T &a, T &b) {
+ T temp = a;
a = b;
b = temp;
}
// Predicate used by callers to choose between the 4-D and 5-D code paths.
// NOTE(review): the removed line returned true when either p->id > 1 or
// p->od > 1; the added line tests only p->id (presumably the input depth).
// A problem with id == 1 but od > 1 is therefore no longer reported as 3D --
// confirm this narrowing is intended.
-inline bool is_deconv_3d(const prb_t *p)
-{
- return (p->id > 1 || p->od > 1) ? 1 : 0;
+
+inline bool is_deconv_3d(const prb_t *p) {
+ return p->id > 1;
}
inline int transpose_data_wei(const prb_t *p, dnn_mem_t &wei, dnn_mem_t &wei_tr) {
int ndims = is_deconv_3d(p) ? 5 : 4;
mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
- mkldnn_dims_t src_dims = {p->mb, p->ic, p->ih, p->iw};
+ mkldnn_dims_t src_2d_dims = {p->mb, p->ic, p->ih, p->iw};
mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
- mkldnn_dims_t wei_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
+ mkldnn_dims_t wei_2d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
mkldnn_dims_t wei_3d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw};
mkldnn_dims_t bia_dims = {p->oc};
- mkldnn_dims_t dst_dims = {p->mb, p->oc, p->oh, p->ow};
+ mkldnn_dims_t dst_2d_dims = {p->mb, p->oc, p->oh, p->ow};
mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
-
DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
- is_deconv_3d(p) ? src_3d_dims : src_dims, p->cfg[SRC].dt, mkldnn_any), WARN);
- DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + 1,
- is_deconv_3d(p) ? wei_3d_dims : wei_dims, p->cfg[WEI].dt, mkldnn_any), WARN);
+ is_deconv_3d(p) ? src_3d_dims : src_2d_dims, p->cfg[SRC].dt, mkldnn_any), WARN);
+ DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + p->has_groups,
+ is_deconv_3d(p)
+ ? &wei_3d_dims[!p->has_groups]
+ : &wei_2d_dims[!p->has_groups],
+ p->cfg[WEI].dt, mkldnn_any), WARN);
DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt, mkldnn_any), WARN);
DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
- is_deconv_3d(p) ? dst_3d_dims : dst_dims, p->cfg[DST].dt, mkldnn_any), WARN);
- int strides_2d[] = {p->sh, p->sw};
- int dilates_2d[] = {p->dh, p->dw};
- int padding_2d[] = {p->ph, p->pw};
- int strides_3d[] = {p->sd, p->sh, p->sw};
- int dilates_3d[] = {p->dd, p->dh, p->dw};
- int padding_3d[] = {p->pd, p->ph, p->pw};
+ is_deconv_3d(p) ? dst_3d_dims : dst_2d_dims, p->cfg[DST].dt, mkldnn_any), WARN);
+
+ ptrdiff_t strides_nd[] = {p->sd, p->sh, p->sw};
+ ptrdiff_t dilates_nd[] = {p->dd, p->dh, p->dw};
+ ptrdiff_t padding_nd[] = {p->pd, p->ph, p->pw};
auto bph = [&](int ih, int oh, int kh, int sh, int ph, int dh) {
return (oh - 1) * sh - ih + ((kh - 1) * (dh + 1) + 1) - ph;
};
- int padding_r_3d[] = {
+
+ ptrdiff_t padding_r_nd[] = {
bph(p->od, p->id, p->kd, p->sd, p->pd, p->dd),
bph(p->oh, p->ih, p->kh, p->sh, p->ph, p->dh),
bph(p->ow, p->iw, p->kw, p->sw, p->pw, p->dw)};
- int padding_r_2d[] = {
- bph(p->oh, p->ih, p->kh, p->sh, p->ph, p->dh),
- bph(p->ow, p->iw, p->kw, p->sw, p->pw, p->dw)};
- int *strides = is_deconv_3d(p) ? strides_3d : strides_2d;
- int *dilates = is_deconv_3d(p) ? dilates_3d : dilates_2d;
- int *padding = is_deconv_3d(p) ? padding_3d : padding_2d;
- int *padding_r = is_deconv_3d(p) ? padding_r_3d : padding_r_2d;
+ ptrdiff_t *strides = strides_nd + (5 - ndims);
+ ptrdiff_t *dilates = dilates_nd + (5 - ndims);
+ ptrdiff_t *padding = padding_nd + (5 - ndims);
+ ptrdiff_t *padding_r = padding_r_nd + (5 - ndims);
+
mkldnn_alg_kind_t alg = mkldnn_deconvolution_direct;
if (p->alg == WINO) alg = mkldnn_deconvolution_winograd;
*r = res_zero;
bool with_groups = 1;
- prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->merge, p->attr, p->mb);
+ prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->attr, p->mb, true);
swap(p_tr.ic, p_tr.oc);
swap(p_tr.ih, p_tr.oh);
swap(p_tr.id, p_tr.od);
? new dnn_mem_t(bia_dt_d, p->cfg[BIA].dt) : new dnn_mem_t();
dnn_mem_t &bia_dt = *p_bia_dt;
- auto src_format = is_deconv_3d(p) ? mkldnn_ncdhw : mkldnn_nchw;
- auto wei_format = is_deconv_3d(p) ? mkldnn_goidhw : mkldnn_goihw;
+ auto src_format = get_default_format(src_dt.md_.ndims, DATA);
+ auto wei_format = get_default_format(wei_dt.md_.ndims,
+ p->has_groups ? GWEI : WEI);
const auto fp = mkldnn_f32;
if (bench_mode & CORR) {
compute_ref_bwd_d(&p_tr, dst_fp, wei_tr_fp, bia_fp, src_fp);
dnn_mem_t dst(dst_dt, fp, src_format);
- SAFE(dst.reorder(dst_dt), WARN);
SAFE(compare_dst(p, dst, dst_fp, r, true), WARN);
}
} else if (p->dir == BWD_D) {
if (bench_mode & CORR) {
compute_ref_fwd(&p_tr, dst_fp, wei_tr_fp, zero_fp, src_fp);
dnn_mem_t src(src_dt, fp, src_format);
- SAFE(src.reorder(src_dt), WARN);
SAFE(compare_src(p, src, src_fp, r, true), WARN);
}
} else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
compute_ref_bwd_weights(&p_tr, dst_fp, wei_tr_fp, src_fp);
transpose_data_wei(&p_tr, wei_tr_fp, wei_fp);
dnn_mem_t wei(wei_dt, fp, wei_format);
- SAFE(wei.reorder(wei_dt), WARN);
SAFE(compare_wei(&p_tr, wei, wei_fp, r, true), WARN);
if (p->dir & FLAG_BIA) {
compute_ref_bwd_bias(p, bia_fp, dst_fp);
dnn_mem_t bia(bia_dt, fp, mkldnn_x);
- SAFE(bia.reorder(bia_dt), WARN);
SAFE(compare_bia(p, bia, bia_fp, r, true), WARN);
}
}