#include "mkldnn_common.hpp"
#include "mkldnn_memory.hpp"
-
#include "norm.hpp"
#include "conv/conv_common.hpp"
namespace conv {
-inline bool is_conv_3d(const prb_t *p)
-{
- return (p->id > 1) ? 1 : 0;
+inline bool is_conv_3d(const prb_t *p) {
+ return p->id > 1;
}
-inline bool is_conv_1d(const prb_t *p)
-{
- return (!is_conv_3d(p) && p->ih == 1 && p->kh == 1
+inline bool is_conv_1d(const prb_t *p) {
+ return !is_conv_3d(p) && p->ih == 1 && p->kh == 1
&& p->cfg[SRC].dt != mkldnn_s8 // temporary workaround until
- && p->cfg[SRC].dt != mkldnn_u8) // int8 jit supports 1d
- ? 1 : 0;
+ && p->cfg[SRC].dt != mkldnn_u8; // int8 jit supports 1d
}
-double get_trust_nz_level(const prb_t *p, data_kind_t kind, bool final_compare)
-{
+double get_trust_nz_level(const prb_t *p, data_kind_t kind,
+ bool final_compare) {
if (!final_compare)
return p->cfg[kind].f_sparsity;
- auto count_relu = [&]() {
+ auto negative_to_zero = [&]() {
+ using pk = attr_t::post_ops_t::kind_t;
const auto &po = p->attr.post_ops;
int count = 0;
- for (int i = 0; i < po.len; ++i)
- count += po.entry[i].kind == attr_t::post_ops_t::kind_t::RELU;
- count = MAX2(count, p->merge == RELU ? 1 : 0);
- return count;
+ for (int i = 0; i < po.len; ++i) {
+ auto k = po.entry[i].kind;
+ count +=
+ k == pk::RELU || k == pk::ELU || k == pk::SQRT || k == pk::BRELU;
+ }
+ return !!count;
};
double trust = 0.3; /* why? */
trust = 0.8 * p->cfg[DST].f_sparsity; /* why? */
break;
case DST:
- trust /= count_relu() == 0 ? 1 : 2;
+ trust /= negative_to_zero() == 0 ? 1 : 2;
break;
}
return trust;
}
+inline bool post_ops_require_integral_check(const prb_t *p) {
+ if (p->attr.post_ops.len == 0) return false;
+
+ using pk = attr_t::post_ops_t::kind_t;
+ const auto &ops = p->attr.post_ops;
+
+ // assumptions: at most 1 eltwise, scale = 1.
+ for (int idx = 0; idx < ops.len; ++idx) {
+ const auto &e = ops.entry[idx];
+ if (e.kind == pk::SUM || e.kind == pk::ABS) continue;
+ if (e.kind == pk::RELU && e.eltwise.alpha == 0.f) continue;
+ return true;
+ }
+
+ return false;
+}
+
inline double get_eps(const prb_t *p, const data_kind_t kind) {
+ // Winograd specifics
if (p->alg & WINO && p->dir & FLAG_WEI) {
/*This is an empirical equation derived by observing growth error
with increasing 'k' dimension in gemm of winograd*/
return p->cfg[kind].eps *
(MAX2(1, pow(10, 0.4 * log10(0.125 * p->mb * p->oh * p->ow))));
}
+
+ // post-ops specifics
+ if (post_ops_require_integral_check(p))
+ return MAX2(1e-5, p->cfg[kind].eps);
+
return p->cfg[kind].eps;
}
inline void get_result(const prb_t *p, const data_kind_t kind, res_t *r,
const diff_norm_t diff_norm) {
- bool wino_test = (p->alg & WINO)
- && (diff_norm.rel_diff(norm_t::L2) <= get_eps(p, kind));
- /* Ignoring elementwise errors for winograd,
- since large relative error in few elements(which are anyways close to zero)
- results in false positive failures*/
+ const float eps = get_eps(p, kind);
+
+ /* Ignoring element-wise errors for Winograd and in some cases of post-ops,
+ * since large relative error in few elements (which are anyways close
+ * to zero) results in false positive failures */
+
+ bool wino_test = (p->alg & WINO) && diff_norm.rel_diff(norm_t::L2) <= eps;
if (wino_test) r->errors = 0;
- r->state = r->errors ? FAILED : r->state;
+
+ bool post_ops_test = post_ops_require_integral_check(p)
+ && diff_norm.rel_diff(norm_t::L2) <= eps;
+ if (post_ops_test) r->errors = 0;
+
+ if (r->errors) r->state = FAILED;
}
inline int compare_dat(const prb_t *p, data_kind_t kind, dnn_mem_t &mem_dt,
dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
+ const bool dont_complain = false
+ || (p->alg & WINO)
+ || post_ops_require_integral_check(p);
+
size_t nelems = mem_dt.nelems();
const char *skind = data_kind2str(kind);
}
if (!ok) {
r->errors++;
- if ((!(p->alg & WINO) && r->errors < 10) || verbose >=10) {
+ if ((!dont_complain && r->errors < 10) || verbose >=10) {
int mb_or_g = 0, g_or_oc = 0, c = 0, d = 0, h = 0, w = 0;
switch (kind) {
case SRC: inv_src_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
}
diff_norm.done();
+ get_result(p, kind, r, diff_norm);
if (final_compare || r->errors) {
const int vl = r->errors ? 0 : 2;
- print(vl, "@@@ [%s] %sdiff: l0(``%g``) "
+ print(vl, "@@@ [%s] %sdiff: err:%d, l0(``%g``) "
"l1:(%g,%g,%g,``%g``) "
"l2:(%g,%g,%g,``%g``) "
"l8:(%g,%g,%g,``%g``)\n",
- skind, final_compare ? "final: " : "",
+ skind, final_compare ? "final: " : "", (int)r->errors,
diff_norm.rel_diff(norm_t::L0),
diff_norm.a_[norm_t::L1], diff_norm.b_[norm_t::L1],
diff_norm.diff_[norm_t::L1], diff_norm.rel_diff(norm_t::L1),
non_zero, (unsigned long)r->total);
}
- get_result(p, kind, r, diff_norm);
-
if (final_compare && r->state == UNTESTED)
r->state = PASSED; /* optimism */
dnn_mem_t *p_mem_00 = check_reorder
? new dnn_mem_t(mem_dt.md_, mkldnn_f32,
- get_default_format(mem_dt.md_.ndims, GWEI))
+ get_default_format(mem_dt.md_.ndims, p->has_groups ? GWEI : WEI))
: &mem_fp;
dnn_mem_t &mem_00 = *p_mem_00;
mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
int ndims = is_conv_3d(p) ? 5 : is_conv_1d(p) ? 3 : 4;
- mkldnn_dims_t src_dims = {p->mb, p->ic, p->ih, p->iw};
mkldnn_dims_t src_1d_dims = {p->mb, p->ic, p->iw};
+ mkldnn_dims_t src_2d_dims = {p->mb, p->ic, p->ih, p->iw};
mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
- mkldnn_dims_t wei_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
+
mkldnn_dims_t wei_1d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kw};
+ mkldnn_dims_t wei_2d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
mkldnn_dims_t wei_3d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw};
+
mkldnn_dims_t bia_dims = {p->oc};
- mkldnn_dims_t dst_dims = {p->mb, p->oc, p->oh, p->ow};
+
mkldnn_dims_t dst_1d_dims = {p->mb, p->oc, p->ow};
+ mkldnn_dims_t dst_2d_dims = {p->mb, p->oc, p->oh, p->ow};
mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
- is_conv_3d(p) ? src_3d_dims : is_conv_1d(p) ? src_1d_dims : src_dims,
+ is_conv_3d(p) ? src_3d_dims : is_conv_1d(p) ? src_1d_dims : src_2d_dims,
p->cfg[SRC].dt, mkldnn_any), WARN);
- DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + 1,
- is_conv_3d(p) ? wei_3d_dims : is_conv_1d(p) ? wei_1d_dims : wei_dims,
+
+ DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + p->has_groups,
+ is_conv_3d(p)
+ ? &wei_3d_dims[!p->has_groups]
+ : is_conv_1d(p)
+ ? &wei_1d_dims[!p->has_groups]
+ : &wei_2d_dims[!p->has_groups],
p->cfg[WEI].dt, mkldnn_any), WARN);
+
DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt,
mkldnn_any), WARN);
+
DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
- is_conv_3d(p) ? dst_3d_dims : is_conv_1d(p) ? dst_1d_dims : dst_dims,
+ is_conv_3d(p) ? dst_3d_dims : is_conv_1d(p) ? dst_1d_dims : dst_2d_dims,
p->cfg[DST].dt, mkldnn_any), WARN);
- int strides_nd[] = {p->sd, p->sh, p->sw};
- int dilates_nd[] = {p->dd, p->dh, p->dw};
- int padding_nd[] = {p->pd, p->ph, p->pw};
+
+ ptrdiff_t strides_nd[] = {p->sd, p->sh, p->sw};
+ ptrdiff_t dilates_nd[] = {p->dd, p->dh, p->dw};
+ ptrdiff_t padding_nd[] = {p->pd, p->ph, p->pw};
auto bph = [&](int ih, int oh, int kh, int sh, int ph, int dh) {
return (oh - 1) * sh - ih + ((kh - 1) * (dh + 1) + 1) - ph;
};
- int padding_r_nd[] = {
+ ptrdiff_t padding_r_nd[] = {
bph(p->id, p->od, p->kd, p->sd, p->pd, p->dd),
bph(p->ih, p->oh, p->kh, p->sh, p->ph, p->dh),
bph(p->iw, p->ow, p->kw, p->sw, p->pw, p->dw)};
- int *strides = strides_nd + (5 - ndims);
- int *dilates = dilates_nd + (5 - ndims);
- int *padding = padding_nd + (5 - ndims);
- int *padding_r = padding_r_nd + (5 - ndims);
+ ptrdiff_t *strides = strides_nd + (5 - ndims);
+ ptrdiff_t *dilates = dilates_nd + (5 - ndims);
+ ptrdiff_t *padding = padding_nd + (5 - ndims);
+ ptrdiff_t *padding_r = padding_r_nd + (5 - ndims);
mkldnn_alg_kind_t alg = mkldnn_convolution_direct;
if (p->alg == WINO) alg = mkldnn_convolution_winograd;
+ if (p->alg == AUTO) alg = mkldnn_convolution_auto;
switch (p->dir) {
case FWD_D: case FWD_B: case FWD_I:
auto mkldnn_attr = create_mkldnn_attr(p->attr, p->oc, p->scales);
mkldnn_status_t init_status = mkldnn_success;
- if (p->merge == RELU) {
- mkldnn_convolution_relu_desc_t crd;
- DNN_SAFE(mkldnn_convolution_relu_desc_init(&crd, &cd, 0), WARN);
- init_status = mkldnn_primitive_desc_create_v2(&cpd, &crd, mkldnn_attr,
+ init_status = mkldnn_primitive_desc_create_v2(&cpd, &cd, mkldnn_attr,
engine, NULL);
- } else {
- init_status = mkldnn_primitive_desc_create_v2(&cpd, &cd, mkldnn_attr,
- engine, NULL);
- }
mkldnn_primitive_attr_destroy(mkldnn_attr);
mkldnn_primitive_desc_query_pd(cpd, query, index));
};
+ if (p->alg == AUTO) {
+ mkldnn_convolution_desc_t *temp_conv_desc = {0};
+ DNN_SAFE(mkldnn_primitive_desc_query(cpd,
+ mkldnn_query_convolution_d, 0, &temp_conv_desc), CRIT);
+ cd.alg_kind = temp_conv_desc->alg_kind;
+ }
+
if (p->dir == BWD_D)
cd.diff_src_desc = q(mkldnn_query_diff_src_pd);
else
mkldnn_primitive_t c{};
SAFE(init_pd(p, cd, cpd, r), WARN);
+
+ prb_t *p_temp = nullptr;
+ if (p->alg == AUTO || p->alg == WINO) {
+ p_temp = new prb_t((desc_t)*p, p->dir, p->cfg,
+ p->alg, p->attr, p->mb);
+ if (p->alg == AUTO) p_temp->alg = alg_kind2alg(cd.alg_kind);
+ p_temp->cfg = auto_cfg(p_temp->alg, p->cfg);
+ p = p_temp;
+ }
+
+
if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
return OK;
dnn_mem_t &bia_dt = *p_bia_dt;
auto src_format = get_default_format(src_dt.md_.ndims, DATA);
- auto wei_format = get_default_format(wei_dt.md_.ndims, GWEI);
+ auto wei_format = get_default_format(wei_dt.md_.ndims,
+ p->has_groups ? GWEI : WEI);
const auto fp = mkldnn_f32;
dnn_mem_t src_fp(src_dt_d, fp, src_format);
if (bench_mode & CORR) {
compute_ref_fwd(p, src_fp, wei_fp, bia_fp, dst_fp);
dnn_mem_t dst(dst_dt, fp, src_format);
- SAFE(dst.reorder(dst_dt), WARN);
SAFE(compare_dst(p, dst, dst_fp, r, true), WARN);
}
} else if (p->dir == BWD_D) {
if (bench_mode & CORR) {
compute_ref_bwd_d(p, src_fp, wei_fp, bia_fp, dst_fp);
dnn_mem_t src(src_dt, fp, src_format);
- SAFE(src.reorder(src_dt), WARN);
SAFE(compare_src(p, src, src_fp, r, true), WARN);
}
} else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
if (bench_mode & CORR) {
compute_ref_bwd_w(p, src_fp, wei_fp, bia_fp, dst_fp);
dnn_mem_t wei(wei_dt, fp, wei_format);
- SAFE(wei.reorder(wei_dt), WARN);
SAFE(compare_wei(p, wei, wei_fp, r, true), WARN);
if (p->dir & FLAG_BIA) {
dnn_mem_t bia(bia_dt, fp, mkldnn_x);
- SAFE(bia.reorder(bia_dt), WARN);
SAFE(compare_bia(p, bia, bia_fp, r, true), WARN);
}
}
delete p_bia_dt;
delete p_bia_fp;
+ delete p_temp;
return OK;
}