alg_t str2alg(const char *str) {
#define CASE(_alg) if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
+ CASE(AUTO);
CASE(DIRECT);
CASE(WINO);
#undef CASE
}
const char *alg2str(alg_t alg) {
+ if (alg == AUTO) return "auto";
if (alg == DIRECT) return "direct";
if (alg == WINO) return "wino";
assert(!"unknown algorithm");
return "unknown algorithm";
}
-merge_t str2merge(const char *str) {
-#define CASE(_mrg) if (!strcasecmp(STRINGIFY(_mrg), str)) return _mrg
- CASE(NONE);
- CASE(RELU);
-#undef CASE
- assert(!"unknown merge");
- return NONE;
-}
-
-const char *merge2str(merge_t merge) {
- if (merge == NONE) return "none";
- if (merge == RELU) return "relu";
- assert(!"unknown merge");
- return "unknown merge";
+alg_t alg_kind2alg(mkldnn_alg_kind_t alg) {
+ if (alg == mkldnn_convolution_auto) return AUTO;
+ if (alg == mkldnn_convolution_direct) return DIRECT;
+ if (alg == mkldnn_convolution_winograd) return WINO;
+ assert(!"unknown algorithm");
+ return DIRECT;
}
int str2desc(desc_t *desc, const char *str, bool is_deconv) {
* - if padding is undefined => compute trivial padding
*/
- d.g = 1; d.mb = 2; d.sd = d.sh = d.sw = 1; d.dd = d.dh = d.dw = 0; d.name = "\"wip\"";
+ d.g = 1; d.mb = 2; d.sd = d.sh = d.sw = 1; d.dd = d.dh = d.dw = 0;
+ d.has_groups = false, d.name = "\"wip\"";
+ d.pw = -1; d.ph = -1; d.pd = -1;
const char *s = str;
assert(s);
if (!strncmp(p, s, strlen(p))) { \
ok = 1; s += strlen(p); \
char *end_s; d. c = strtol(s, &end_s, 10); s += (end_s - s); \
+ if (!strncmp(p, "g", 1)) d.has_groups = true; \
/* printf("@@@debug: %s: %d\n", p, d. c); */ \
} \
} while (0)
return ((o - 1) * s - i + ((k - 1) * (d + 1) + 1)) / 2;
};
- const bool no_d = (d.id | d.kd | d.od | d.pd | d.dd) == 0 && d.sd == 1;
- const bool no_h = (d.ih | d.kh | d.oh | d.ph | d.dh) == 0 && d.sh == 1;
- const bool no_w = (d.iw | d.kw | d.ow | d.pw | d.dw) == 0 && d.sw == 1;
-
+ const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1;
+ const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1;
+ const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1;
if (!no_h) {
if (!d.ih || !d.kh) return FAIL;
-
- if (!d.oh) d.oh = compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh);
- else if (!d.ph && d.oh != compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh))
+ if (!d.oh) {
+ d.ph = 0;
+ d.oh = compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh);
+ } else if (d.ph < 0)
d.ph = compute_pad(is_deconv, d.oh, d.ih, d.kh, d.sh, d.dh);
}
if (!no_w) {
if (!d.iw || !d.kw) return FAIL;
-
- if (!d.ow) d.ow = compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw);
- else if (!d.pw && d.ow != compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw))
+ if (!d.ow) {
+ d.pw = 0;
+ d.ow = compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw);
+ } else if (d.pw < 0)
d.pw = compute_pad(is_deconv, d.ow, d.iw, d.kw, d.sw, d.dw);
}
if (!no_d && d.id) {
if (!d.id || !d.kd) return FAIL;
-
- if (!d.od) d.od = compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd);
- else if (!d.pd && d.od != compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd))
+ if (!d.od) {
+ d.pd = 0;
+ d.od = compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd);
+ } else if (d.pd < 0)
d.pd = compute_pad(is_deconv, d.od, d.id, d.kd, d.sd, d.dd);
}
-
if (no_w && no_h && d.id) {
d.iw = d.ih = d.id;
d.kw = d.kh = d.kd;
buffer += l; rem_len -= l; \
} while(0)
- if (canonical || d->g != 1) DPRINT("g%d", d->g);
+ if (canonical || d->has_groups) DPRINT("g%d", d->g);
if (canonical || d->mb != 2) DPRINT("mb%d", d->mb);
const bool half_form = (d->ih == d->iw && d->kh == d->kw && d->oh == d->ow
void prb_t::count_ops() {
if (ops > 0) return;
+ int od_t = is_deconv ? this->id : this->od;
+ int oh_t = is_deconv ? this->ih : this->oh;
+ int ow_t = is_deconv ? this->iw : this->ow;
+ int id_t = is_deconv ? this->od : this->id;
+ int ih_t = is_deconv ? this->oh : this->ih;
+ int iw_t = is_deconv ? this->ow : this->iw;
double sp_ops = 0;
- for (int od = 0; od < this->od; ++od) {
- for (int oh = 0; oh < this->oh; ++oh) {
- for (int ow = 0; ow < this->ow; ++ow) {
+ for (int od = 0; od < od_t; ++od) {
+ for (int oh = 0; oh < oh_t; ++oh) {
+ for (int ow = 0; ow < ow_t; ++ow) {
for (int kd = 0; kd < this->kd; ++kd) {
const int id = od * this->sd - this->pd + kd * (this->dd + 1);
- if (id < 0 || id >= this->id) continue;
+ if (id < 0 || id >= id_t) continue;
for (int kh = 0; kh < this->kh; ++kh) {
const int ih = oh * this->sh - this->ph + kh * (this->dh + 1);
- if (ih < 0 || ih >= this->ih) continue;
+ if (ih < 0 || ih >= ih_t) continue;
for (int kw = 0; kw < this->kw; ++kw) {
const int iw = ow * this->sw - this->pw + kw * (this->dw + 1);
- if (iw < 0 || iw >= this->iw) continue;
+ if (iw < 0 || iw >= iw_t) continue;
sp_ops += 1;
}
}
void prb2str(const prb_t *p, char *buffer, bool canonical) {
char desc_buf[max_desc_len], attr_buf[max_attr_len];
- char dir_str[32] = {0}, cfg_str[32] = {0}, alg_str[32] = {0},
- merge_str[32] = {0};
+ char dir_str[32] = {0}, cfg_str[32] = {0}, alg_str[32] = {0};
desc2str(p, desc_buf, canonical);
snprintf(dir_str, sizeof(dir_str), "--dir=%s ", dir2str(p->dir));
snprintf(cfg_str, sizeof(cfg_str), "--cfg=%s ", cfg2str(p->cfg));
snprintf(alg_str, sizeof(alg_str), "--alg=%s ", alg2str(p->alg));
- snprintf(merge_str, sizeof(merge_str), "--merge=%s ", merge2str(p->merge));
bool is_attr_def = p->attr.is_def();
if (!is_attr_def) {
int len = snprintf(attr_buf, max_attr_len, "--attr=\"");
len = (int)strnlen(attr_buf, max_attr_len);
snprintf(attr_buf + len, max_attr_len - len, "\" ");
}
- snprintf(buffer, max_prb_len, "%s%s%s%s%s%s",
+ snprintf(buffer, max_prb_len, "%s%s%s%s%s",
p->dir == FWD_B ? "" : dir_str,
p->cfg == conf_f32 ? "" : cfg_str,
p->alg == DIRECT ? "" : alg_str,
- p->merge == NONE ? "" : merge_str,
is_attr_def ? "" : attr_buf,
desc_buf);
}