case mkldnn_gOIw8i16o2i:
case mkldnn_goihw:
case mkldnn_hwigo:
+ case mkldnn_giohw:
case mkldnn_hwigo_s8s8:
case mkldnn_gOIhw8i8o:
case mkldnn_gOIhw16i16o:
case mkldnn_gOhwi16o:
case mkldnn_Goihw8g:
case mkldnn_Goihw16g:
+ case mkldnn_Goihw16g_s8s8:
case mkldnn_gOhIw16o4i:
case mkldnn_goidhw:
case mkldnn_gOIdhw16i16o:
#define CASE(_knd) if (!strcasecmp(STRINGIFY(_knd), str)) return _knd
CASE(SUM);
CASE(RELU);
+ CASE(TANH);
+ CASE(ELU);
+ CASE(SQUARE);
+ CASE(ABS);
+ CASE(SQRT);
+ CASE(LINEAR);
+ CASE(BRELU);
+ CASE(SRELU);
+ CASE(LOGISTIC);
#undef CASE
assert(!"unknown attr::post_ops::kind");
return KIND_TOTAL;
}
const char *attr_t::post_ops_t::kind2str(attr_t::post_ops_t::kind_t kind) {
- if (kind == SUM) return "sum";
- if (kind == RELU) return "relu";
+#define CASE(_knd, str) if (kind == _knd) return str
+ CASE(SUM, "sum");
+ CASE(RELU, "relu");
+ CASE(TANH, "tanh");
+ CASE(ELU, "elu");
+ CASE(SQUARE, "square");
+ CASE(ABS, "abs");
+ CASE(SQRT, "sqrt");
+ CASE(LINEAR, "linear");
+ CASE(BRELU, "brelu");
+ CASE(SRELU, "srelu");
+ CASE(LOGISTIC, "logistic");
+#undef CASE
assert(!"unknown attr::post_ops::kind");
return "unknown attr::post_ops::kind";
}
+mkldnn_alg_kind_t attr_t::post_ops_t::kind2mkldnn_kind(
+ attr_t::post_ops_t::kind_t kind) {
+#define CASE(_knd, _mknd) if (kind == _knd) return _mknd
+ CASE(RELU, mkldnn_eltwise_relu);
+ CASE(TANH, mkldnn_eltwise_tanh);
+ CASE(ELU, mkldnn_eltwise_elu);
+ CASE(SQUARE, mkldnn_eltwise_square);
+ CASE(ABS, mkldnn_eltwise_abs);
+ CASE(SQRT, mkldnn_eltwise_sqrt);
+ CASE(LINEAR, mkldnn_eltwise_linear);
+ CASE(BRELU, mkldnn_eltwise_bounded_relu);
+ CASE(SRELU, mkldnn_eltwise_soft_relu);
+ CASE(LOGISTIC, mkldnn_eltwise_logistic);
+#undef CASE
+ assert(!"unknown attr::post_ops::kind");
+ return mkldnn_alg_kind_undef;
+}
+
int attr_t::post_ops_t::from_str(const char *str, const char **end_s) {
*this = post_ops_t();
} else {
e.sum.scale = 1.f;
}
- } else if (k == RELU) {
+ } else {
+ e.eltwise.alg = kind2mkldnn_kind(k);
e.eltwise.scale = 1.f;
e.eltwise.alpha = e.eltwise.beta = 0.f;
+
+ for (int i = 0; i < 3; ++i) {
+ // :alpha:beta:scale
+ float &val = i == 0 ? e.eltwise.alpha
+ : i == 1 ? e.eltwise.beta : e.eltwise.scale;
+ if (*s == ':') {
+ char *end;
+ val = strtof(++s, &end);
+ if (end == s) return FAIL;
+ s = end;
+ } else {
+ break;
+ }
+ }
+
+ if (e.eltwise.scale <= 0) return FAIL;
}
break;
buffer += sprintf(buffer, "%s:%g", kind2str(e.kind), e.sum.scale);
break;
case RELU:
- buffer += sprintf(buffer, "%s", kind2str(e.kind));
+ case TANH:
+ case ELU:
+ case SQUARE:
+ case ABS:
+ case SQRT:
+ case LINEAR:
+ case BRELU:
+ case SRELU:
+ case LOGISTIC:
+ buffer += sprintf(buffer, "%s:%g", kind2str(e.kind), e.eltwise.alpha);
+ if (e.eltwise.beta != 0.f || e.eltwise.scale != 1.f)
+ buffer += sprintf(buffer, ":%g:%g", e.eltwise.beta, e.eltwise.scale);
break;
default:
assert(!"unknown kind");
DNN_SAFE_V(mkldnn_post_ops_append_sum(ops, e.sum.scale));
break;
case attr_t::post_ops_t::RELU:
+ case attr_t::post_ops_t::TANH:
+ case attr_t::post_ops_t::ELU:
+ case attr_t::post_ops_t::SQUARE:
+ case attr_t::post_ops_t::ABS:
+ case attr_t::post_ops_t::SQRT:
+ case attr_t::post_ops_t::LINEAR:
+ case attr_t::post_ops_t::BRELU:
+ case attr_t::post_ops_t::SRELU:
+ case attr_t::post_ops_t::LOGISTIC:
DNN_SAFE_V(mkldnn_post_ops_append_eltwise(ops, e.eltwise.scale,
- mkldnn_eltwise_relu, e.eltwise.alpha,
- e.eltwise.beta));
+ e.eltwise.alg, e.eltwise.alpha, e.eltwise.beta));
break;
default:
assert(!"unknown attr::post_ops::kind");