return VANILLA_RNN;
}
+policy_t str2policy(const char *str) { // Parse a scaling-policy name into a policy_t; the comparison ignores case (strcasecmp).
+#define CASE(_plc) if (!strcasecmp(STRINGIFY(_plc), str)) return _plc
+ CASE(NONE); // each CASE returns its policy when str equals STRINGIFY(policy) — presumably the identifier text, confirm against the macro
+ CASE(COMMON);
+ CASE(PER_OC);
+#undef CASE
+ assert(!"unknown policy"); // reached only when nothing matched: always-false assert, a no-op when NDEBUG is defined
+ return NONE; // value returned for unknown input once the assert is compiled out
+}
+
+/* Returns the printable name of a scaling policy:
+ * "none", "common" or "per_oc".
+ * Any other value trips the assert; when asserts are compiled out the
+ * placeholder "unknown policy" is returned instead. */
+const char * policy2str(policy_t policy) {
+    switch (policy) {
+    case NONE: return "none";
+    case COMMON: return "common";
+    case PER_OC: return "per_oc";
+    default: break;
+    }
+    assert(!"unknown policy");
+    return "unknown policy";
+}
+
const char *alg2str(alg_t alg) {
if (alg == VANILLA_RNN)
return "VANILLA_RNN";
return alg_kind;
}
+/* Parses a propagation-kind name, ignoring case:
+ *   "FWD_D"  -> mkldnn_forward
+ *   "BWD_DW" -> mkldnn_backward
+ * The accepted spellings are exactly the ones prop2str() prints, so a
+ * problem dumped as "--prop=BWD_DW" parses back to mkldnn_backward (the
+ * previous "BWD_D" spelling did not round-trip with prop2str()).
+ * Any other input trips the assert; with asserts compiled out the function
+ * falls back to mkldnn_forward. */
+mkldnn_prop_kind_t str2prop(const char *str) {
+    if (!strcasecmp("FWD_D", str))
+        return mkldnn_forward;
+    if (!strcasecmp("BWD_DW", str))
+        return mkldnn_backward;
+    assert(!"unknown propagation");
+    return mkldnn_forward;
+}
+
+/* Gives the textual spelling of a propagation kind:
+ * "FWD_D" for mkldnn_forward and "BWD_DW" for mkldnn_backward.
+ * Any other value trips the assert; with asserts compiled out the
+ * placeholder "unknown propagation" is returned. */
+const char *prop2str(mkldnn_prop_kind_t prop) {
+    const bool is_fwd = (prop == mkldnn_forward);
+    const bool is_bwd = (prop == mkldnn_backward);
+    if (!is_fwd && !is_bwd) {
+        assert(!"unknown propagation");
+        return "unknown propagation";
+    }
+    return is_fwd ? "FWD_D" : "BWD_DW";
+}
+
mkldnn_rnn_direction_t str2direction(const char *str) {
if (!strcasecmp("left2right", str))
return mkldnn_unidirectional_left2right;
void prb2str(const rnn_prb_t *p, const res_t *res, char *buffer) {
int rem_len = max_prb_len;
- DPRINT("%s,%s,%s,", alg2str(p->alg), activation2str(p->activation),
- direction2str(p->direction));
+ DPRINT("--prop=%s --alg=%s --activation=%s --direction=%s --cfg=%s "
+ "--scaling=%s ",
+ prop2str(p->prop), alg2str(p->alg), activation2str(p->activation),
+ direction2str(p->direction), cfg2str(p->cfg),
+ policy2str(p->scale_policy));
DPRINT("l%d", p->n_layer);
DPRINT("t%d", p->n_iter);
DPRINT("mb%d", p->mb);
}
float logistic(float x) {
- return 1.0f / (1.0f + expf(-x));
+ if (x < 0) // sigmoid 1 / (1 + e^-x), split by sign so that expf() is only called with a non-positive argument
+ return (expf(x) / (1 + expf(x))); // negative x: e^x / (1 + e^x)
+ else
+ return 1.0f - (expf(-x) / (1 + expf(-x))); // otherwise: 1 - e^-x / (1 + e^-x)
}
float dlogistic(float x) {
- return x * (1 - x);
+ float tmp = logistic(x); // sigmoid value s at x
+ return tmp * (1 - tmp); // s * (1 - s): derivative of the sigmoid with respect to x
+}
+float dtanhf(float x) { // derivative of tanh evaluated at x
+ return (1 - tanhf(x)) * (1 + tanhf(x)); // factored form of 1 - tanh(x)^2
+}
+float x_m_square(float x) { // returns x - x^2, i.e. x * (1 - x)
+ return x - x * x; // NOTE(review): presumably the sigmoid derivative expressed through the sigmoid's output value — confirm against callers
}
float relu(float x) {
return x > 0 ? x : 0;
float drelu(float x) { // step function: 1.0f when x > 0, otherwise 0.0f
return float(x > 0); // the bool result of the comparison is converted to float
}
-float dtanhf(float x) {
- return (1 - x) * (1 + x);
+float one_m_square(float x) { // returns 1 - x^2
+ return 1 - x * x; // NOTE(review): presumably the tanh derivative expressed through tanh's output value — confirm against callers
}
int compare_dat(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
return compare_dat(p, dst_last_iteration, mem_dt, mem_fp, r, final_compare);
}
+void rnn_prb_t::set_qparams(float fp_min, float fp_max) { // fills data_shift, data_scale and the weights scale(s) from cfg and the fp32 range [fp_min, fp_max]
+ if (cfg == conf_f32) { // f32 configuration: zero shift and unit scales, then done
+ data_shift = 0.;
+ data_scale = 1.;
+ wei_scale = 1.;
+ return;
+ }
+
+ /* Set parameters for quantization of src and weights from fp32 data
+ * in [fp_min, fp_max] (callers presumably pass [-1, 1] — TODO confirm) to int8 data in a range specified in cfg */
+ float fp_range = fp_max - fp_min; // NOTE(review): fp_max == fp_min makes every division below a division by zero
+ float int8_src_range = cfg[input].f_max - cfg[input].f_min,
+ int8_wei_range = cfg[weights_input].f_max - cfg[weights_input].f_min;
+
+ data_shift = cfg[input].f_mean;
+ data_scale = int8_src_range / fp_range;
+
+ if (scale_policy == COMMON) { // COMMON: a single weights scale
+ wei_scale = int8_wei_range / fp_range;
+ } else if (scale_policy == PER_OC) { // PER_OC: dic * n_gates() individual scales
+ float K = int8_wei_range / fp_range;
+ int nelems = dic * n_gates();
+ for (int i = 0; i < nelems; i++) {
+ wei_oc_scales[i] = K * (1. + (float)i / nelems); // ramps from K at i == 0 towards (never reaching) 2 * K
+ }
+ } // NOTE(review): wei_scale is not assigned for PER_OC or for any other policy value — confirm this is intended
+}
+
} // namespace rnn