1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "rnn/rnn_aux.hpp"
22 int l = snprintf(buffer, rem_len, __VA_ARGS__); \
/* Parses an algorithm name into an alg_t value.
 * Only fragments of this definition are visible here: str is compared with
 * STRINGIFY(_alg) using strcasecmp (i.e. ignoring case), and the assert fires
 * when no comparison matched. The assert is compiled out under NDEBUG.
 * NOTE(review): the macro that introduces _alg and the value returned after
 * the assert are not part of this view - confirm against the full file. */
29 alg_t str2alg(const char *str) {
31 if (!strcasecmp(STRINGIFY(_alg), str)) \
38 assert(!"unknown algorithm");
/* Parses a scale-policy name into a policy_t value.
 * The local CASE(_plc) helper returns _plc when str equals STRINGIFY(_plc),
 * compared with strcasecmp (case is ignored). If nothing matched, the assert
 * fires; it is compiled out under NDEBUG.
 * NOTE(review): the CASE(...) invocations and the fallback return are not
 * part of this view - confirm against the full file. */
42 policy_t str2policy(const char *str) {
43 #define CASE(_plc) if (!strcasecmp(STRINGIFY(_plc), str)) return _plc
48 assert(!"unknown policy");
52 const char * policy2str(policy_t policy) {
53 if (policy == NONE) return "none";
54 if (policy == COMMON) return "common";
55 if (policy == PER_OC) return "per_oc";
56 assert(!"unknown policy");
57 return "unknown policy";
/* Maps an alg_t value to its printable name.
 * Falls through to the assert (debug builds only) and then returns
 * "unknown algorithm" when no check matched.
 * NOTE(review): in this view only the VANILLA_LSTM check is followed by its
 * return statement; the VANILLA_RNN and VANILLA_GRU checks are immediately
 * followed by the next statement, so as written they would guard the wrong
 * line. Confirm the missing returns against the full file. */
60 const char *alg2str(alg_t alg) {
61 if (alg == VANILLA_RNN)
63 if (alg == VANILLA_LSTM)
64 return "VANILLA_LSTM";
65 if (alg == VANILLA_GRU)
69 assert(!"unknown algorithm");
70 return "unknown algorithm";
/* Translates the benchmark's alg_t into the library's mkldnn_alg_kind_t.
 * VANILLA_RNN, VANILLA_LSTM and VANILLA_GRU map to the corresponding
 * mkldnn_vanilla_* kinds. */
73 mkldnn_alg_kind_t alg2kind(alg_t alg) {
74 if (alg == VANILLA_RNN)
75 return mkldnn_vanilla_rnn;
76 if (alg == VANILLA_LSTM)
77 return mkldnn_vanilla_lstm;
78 if (alg == VANILLA_GRU)
79 return mkldnn_vanilla_gru;
/* NOTE(review): no guard is visible in front of the next return, so in this
 * view every remaining value maps to mkldnn_gru_linear_before_reset and the
 * assert / mkldnn_alg_kind_undef fallback below can never be reached.
 * Presumably a check for a fourth algorithm belongs here - confirm. */
81 return mkldnn_gru_linear_before_reset;
82 assert(!"unknown algorithm");
83 return mkldnn_alg_kind_undef;
/* Parses an activation name into an activation_t value.
 * Only fragments are visible: str is compared with STRINGIFY(_act) using
 * strcasecmp (case is ignored), and the assert fires in debug builds when no
 * comparison matched.
 * NOTE(review): the macro that introduces _act and the fallback return value
 * are not part of this view - confirm against the full file. */
86 activation_t str2activation(const char *str) {
88 if (!strcasecmp(STRINGIFY(_act), str)) \
94 assert(!"unknown activation");
98 const char *activation2str(activation_t act) {
99 const char *str = "unknown activation";
101 case RELU: str = "RELU"; break;
102 case LOGISTIC: str = "LOGISTIC"; break;
103 case TANH: str = "TANH"; break;
104 default: assert(!"unknown activation");
109 mkldnn_alg_kind_t activation2kind(activation_t act) {
110 mkldnn_alg_kind_t alg_kind = mkldnn_alg_kind_undef;
112 case RELU: alg_kind = mkldnn_eltwise_relu; break;
113 case LOGISTIC: alg_kind = mkldnn_eltwise_logistic; break;
114 case TANH: alg_kind = mkldnn_eltwise_tanh; break;
115 default: assert(!"unknown activation");
120 mkldnn_prop_kind_t str2prop(const char *str) {
121 if (!strcasecmp("FWD_D", str))
122 return mkldnn_forward;
123 if (!strcasecmp("BWD_D", str))
124 return mkldnn_backward;
125 assert(!"unknown propagation");
126 return mkldnn_forward;
129 const char *prop2str(mkldnn_prop_kind_t prop) {
130 if (prop == mkldnn_forward)
132 if (prop == mkldnn_backward)
134 assert(!"unknown propagation");
135 return "unknown propagation";
139 mkldnn_rnn_direction_t str2direction(const char *str) {
140 if (!strcasecmp("left2right", str))
141 return mkldnn_unidirectional_left2right;
142 if (!strcasecmp("right2left", str))
143 return mkldnn_unidirectional_right2left;
144 if (!strcasecmp("concat", str))
145 return mkldnn_bidirectional_concat;
146 if (!strcasecmp("sum", str))
147 return mkldnn_bidirectional_sum;
148 assert(!"unknown direction");
149 return mkldnn_unidirectional_left2right;
152 const char *direction2str(mkldnn_rnn_direction_t direction) {
153 if (direction == mkldnn_unidirectional_left2right)
155 if (direction == mkldnn_unidirectional_right2left)
157 if (direction == mkldnn_bidirectional_concat)
159 if (direction == mkldnn_bidirectional_sum)
161 assert(!"unknown direction");
162 return "unknown direction";
/* Parses a textual problem descriptor into *desc.
 * Visible behavior: each CASE_NN(prefix, field) attempt checks whether the
 * remaining input starts with the prefix (strncmp), skips it, reads a base-10
 * integer with strtol into d.field and advances past the digits. A token that
 * starts with 'n' makes the rest of the string the problem name and stops
 * parsing. Returns FAIL when a token matched no prefix or when sic is 0;
 * slc, dlc and dic default to sic when they were left at 0.
 * NOTE(review): strtol's result is not checked for "no digits consumed" or
 * for overflow before being stored - confirm whether callers rely on that.
 * NOTE(review): the format line below reads "mX" while the default note and
 * prb2str() use "mb"; presumably the format should say "mbX" - confirm. */
165 int str2desc(rnn_desc_t *desc, const char *str) {
169 * lXtXmXsicXslcXdicXdlc
171 * where: X is number, S - string
172 * note: symbol `_` is ignored
176 * l = 1, t = 1, mb = 2, S="wip"
177 * - if slc/dlc/dic is undefined => slc/dlc/dic = sic
188 # define CASE_NN(p, c) do { \
189 if (!strncmp(p, s, strlen(p))) { \
190 ok = 1; s += strlen(p); \
191 char *end_s; d. c = strtol(s, &end_s, 10); s += (end_s - s); \
194 # define CASE_N(c) CASE_NN(#c, c)
197 CASE_NN("l", n_layer);
198 CASE_NN("t", n_iter);
204 if (*s == 'n') { d.name = s + 1; break; }
206 if (!ok) return FAIL;
/* sic is mandatory; the other channel counts inherit it when unset. */
211 if (d.sic == 0) return FAIL;
212 if (d.slc == 0) d.slc = d.sic;
213 if (d.dlc == 0) d.dlc = d.sic;
214 if (d.dic == 0) d.dic = d.sic;
/* Formats problem p into buffer as a command-line style string: the option
 * flags first, then the descriptor fields l/t/mb/sic/slc/dic/dlc and the
 * quoted problem name prefixed with 'n'.
 * rem_len starts at max_prb_len; DPRINT is presumably the snprintf-based
 * helper defined earlier in this file that appends to buffer and shrinks
 * rem_len - confirm. The res argument is not used in the visible lines.
 * NOTE(review): the first format string has five %s conversions in this view
 * but six arguments are passed; the remainder of the format is presumably on
 * a line that is not visible here - confirm. */
222 void prb2str(const rnn_prb_t *p, const res_t *res, char *buffer) {
223 int rem_len = max_prb_len;
225 DPRINT("--prop=%s --alg=%s --activation=%s --direction=%s --cfg=%s "
227 prop2str(p->prop), alg2str(p->alg), activation2str(p->activation),
228 direction2str(p->direction), cfg2str(p->cfg),
229 policy2str(p->scale_policy));
/* Descriptor part, in the same field order str2desc() documents. */
230 DPRINT("l%d", p->n_layer);
231 DPRINT("t%d", p->n_iter);
232 DPRINT("mb%d", p->mb);
233 DPRINT("sic%d", p->sic);
234 DPRINT("slc%d", p->slc);
235 DPRINT("dic%d", p->dic);
236 DPRINT("dlc%d", p->dlc);
237 DPRINT("n\"%s\"", p->name);
/* Sets the first `size` elements of `buf` to `value`.
 *
 * A non-positive `size` leaves the buffer untouched. The caller must provide
 * at least `size` writable floats. */
void init_buffer(float *buf, int size, float value) {
    for (int i = 0; i < size; i++)
        buf[i] = value;
}
/* Logistic (sigmoid) function 1 / (1 + e^-x), evaluated in a numerically
 * stable way.
 *
 * The formula is chosen by the sign of x so that expf() is only ever called
 * with a non-positive argument: for x < 0 use e^x / (1 + e^x), otherwise use
 * 1 - e^-x / (1 + e^-x). This avoids the inf / inf = NaN result the first
 * form produces once expf(x) overflows for large positive x. */
float logistic(float x) {
    if (x < 0)
        return (expf(x) / (1 + expf(x)));
    else
        return 1.0f - (expf(-x) / (1 + expf(-x)));
}
251 float dlogistic(float x) {
252 float tmp = logistic(x);
253 return tmp * (1 - tmp);
/* Derivative of tanh at x, written in the factored form
 * (1 - tanh(x)) * (1 + tanh(x)). */
float dtanhf(float x) {
    const float lower = 1 - tanhf(x);
    const float upper = 1 + tanhf(x);
    return lower * upper;
}
258 float x_m_square(float x) {
/* Rectified linear unit: passes strictly positive inputs through unchanged
 * and maps everything else to 0. */
float relu(float x) {
    if (x > 0)
        return x;
    return 0;
}
264 float drelu(float x) {
267 float one_m_square(float x) {
/* Compares mem_dt with mem_fp element by element, both read as float arrays
 * of mem_dt.nelems() entries, for the data kind `kind`.
 * Every pair is fed to diff_norm; an element passes when its difference is
 * within p->cfg[kind].eps. Mismatches and, at high verbosity, all elements
 * are printed with indices decoded by the inv_*_off_f helper chosen per kind.
 * A norm summary is prepared when final_compare is set or errors exist.
 * Returns FAIL when r->state is FAILED, OK otherwise.
 * NOTE(review): the switch/case skeleton, the leading print arguments and the
 * code that updates r->errors / r->state are not part of this view. */
271 int compare_dat(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
272 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
273 size_t nelems = mem_dt.nelems();
275 const char *skind = rnn_data_kind2str(kind);
277 diff_norm_t diff_norm;
282 for (size_t i = 0; i < nelems; ++i) {
283 const float dt = ((float *)mem_dt)[i];
284 const float fp = ((float *)mem_fp)[i];
285 diff_norm.update(fp, dt);
/* Absolute difference, and the difference relative to |fp|; the divisor
 * falls back to 1 when |fp| is not above FLT_MIN. */
287 const float diff = fabsf(fp - dt);
288 const float rel_diff = diff / (fabsf(fp) > FLT_MIN ? fabsf(fp) : 1);
/* Use the relative difference only when |fp| exceeds 1e-5; otherwise the
 * absolute difference is compared against the tolerance. */
290 const bool ok = (fabs(fp) > 1e-5 ? rel_diff : diff) <= p->cfg[kind].eps;
/* Report only the first few mismatches unless verbosity is at least 10. */
294 if (r->errors < 10 || verbose >= 10) {
295 int n = 0, t = 0, c = 0, s = 0, l = 0, d = 0, w = 0, ic = 0,
299 inv_ntc_off_f(p, i, n, t, c);
300 print(0, "%lu, %s, [%s][%d,%d,%d] "
301 "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
303 final_compare == false ? "REORDER " : "", skind, n,
304 t, c, fp, dt, diff, rel_diff);
307 inv_ldsnc_off_f(p, i, l, d, s, n, c);
308 print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
309 "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
311 final_compare == false ? "REORDER " : "", skind, l,
312 d, s, n, c, fp, dt, diff, rel_diff);
315 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
316 print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
317 "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
319 final_compare == false ? "REORDER " : "", skind, l,
320 d, w, ic, oc, fp, dt, diff, rel_diff);
323 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
324 print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
325 "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
327 final_compare == false ? "REORDER " : "", skind, l,
328 d, w, ic, oc, fp, dt, diff, rel_diff);
331 inv_ldgo_off_f(p, i, l, d, b, c);
332 print(0, "%lu, %s, [%s][%d,%d,%d,%d] "
333 "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
335 final_compare == false ? "REORDER " : "", skind, l,
336 d, b, c, fp, dt, diff, rel_diff);
339 inv_tnc_off_f(p, i, s, t, n, c);
340 print(0, "%lu, %s, [%s][%d,%d,%d,%d] "
341 "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
343 final_compare == false ? "REORDER " : "", skind, s,
344 t, n, c, fp, dt, diff, rel_diff);
/* NOTE(review): the format string of this branch opens '[' before the
 * indices but never closes it; every other branch prints ']' after the last
 * %d. Looks like a typo in the message - confirm and add the bracket. */
346 case dst_last_iteration:
347 inv_ldsnc_off_f(p, i, l, d, s, n, c);
348 print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d "
349 "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
351 final_compare == false ? "REORDER " : "", skind, l,
352 d, s, n, c, fp, dt, diff, rel_diff);
/* NOTE(review): assert("unknown data kind") tests a non-null string literal,
 * so it can never fire. The other asserts in this file negate the literal
 * (assert(!"...")); this one presumably should too - confirm. */
354 default: assert("unknown data kind"); return FAIL;
360 /* for debug purposes only: dump the output */
361 if (final_compare && verbose >= 50) {
362 int n = 0, t = 0, c = 0, s = 0, l = 0, d = 0, w = 0, ic = 0, oc = 0,
367 inv_ntc_off_f(p, i, n, t, c);
368 print(0, "[%4lu][%s][%d,%d,%d] fp:%8g dt:%8g\n",
369 (unsigned long)i, skind, n, t, c, fp, dt);
372 inv_ldsnc_off_f(p, i, l, d, s, n, c);
373 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g dt:%8g\n",
374 (unsigned long)i, skind, l, d, s, n, c, fp, dt);
377 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
378 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g dt:%8g\n",
379 (unsigned long)i, skind, l, d, w, ic, oc, fp, dt);
382 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
384 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g dt:%8g\n",
385 (unsigned long)i, skind, l, d, w, ic, oc, fp, dt);
387 inv_ldgo_off_f(p, i, l, d, b, c);
389 print(0, "[%4lu][%s][%d,%d,%d,%d] fp:%8g dt:%8g\n",
390 (unsigned long)i, skind, l, d, b, c, fp, dt);
/* NOTE(review): inv_tnc_off_f fills s, t, n and c, but this dump prints only
 * n, t, c while the mismatch report above prints s, t, n, c - confirm which
 * index set is intended. */
392 inv_tnc_off_f(p, i, s, t, n, c);
393 print(0, "[%4lu][%s][%d,%d,%d] fp:%8g dt:%8g\n",
394 (unsigned long)i, skind, n, t, c, fp, dt);
396 case dst_last_iteration:
397 inv_ldsnc_off_f(p, i, l, d, s, n, c);
398 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g dt:%8g\n",
399 (unsigned long)i, skind, l, d, s, n, c, fp, dt);
402 print(0, "[%4lu][unknown] fp:%8g dt:%8g\n",
403 (unsigned long)i, fp, dt);
/* Norm summary: vl is 0 when errors were recorded and 2 otherwise
 * (presumably the verbosity level passed to the summary print - confirm). */
412 if (final_compare || r->errors) {
413 const int vl = r->errors ? 0 : 2;
415 "@@@ [%s] %sdiff: l0(``%g``) "
416 "l1:(%g,%g,%g,``%g``) "
417 "l2:(%g,%g,%g,``%g``) "
418 "l8:(%g,%g,%g,``%g``)\n",
419 skind, final_compare ? "final: " : "",
420 diff_norm.rel_diff(norm_t::L0), diff_norm.a_[norm_t::L1],
421 diff_norm.b_[norm_t::L1], diff_norm.diff_[norm_t::L1],
422 diff_norm.rel_diff(norm_t::L1), diff_norm.a_[norm_t::L2],
423 diff_norm.b_[norm_t::L2], diff_norm.diff_[norm_t::L2],
424 diff_norm.rel_diff(norm_t::L2), diff_norm.a_[norm_t::L8],
425 diff_norm.b_[norm_t::L8], diff_norm.diff_[norm_t::L8],
426 diff_norm.rel_diff(norm_t::L8));
/* A final comparison that left the state untouched counts as a pass. */
432 if (final_compare && r->state == UNTESTED)
433 r->state = PASSED; /* optimism */
435 return r->state == FAILED ? FAIL : OK;
438 int compare_input(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
439 res_t *r, bool final_compare = false) {
440 return compare_dat(p, input, mem_dt, mem_fp, r, final_compare);
442 int compare_states(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
443 res_t *r, bool final_compare = false) {
444 return compare_dat(p, states, mem_dt, mem_fp, r, final_compare);
446 int compare_weights_input(const rnn_prb_t *p, dnn_mem_t &mem_dt,
447 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
448 return compare_dat(p, weights_input, mem_dt, mem_fp, r, final_compare);
450 int compare_weights_states(const rnn_prb_t *p, dnn_mem_t &mem_dt,
451 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
452 return compare_dat(p, weights_states, mem_dt, mem_fp, r, final_compare);
454 int compare_bias(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
455 res_t *r, bool final_compare = false) {
456 return compare_dat(p, bias, mem_dt, mem_fp, r, final_compare);
458 int compare_dst_last_layer(const rnn_prb_t *p, dnn_mem_t &mem_dt,
459 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
460 return compare_dat(p, dst_last_layer, mem_dt, mem_fp, r, final_compare);
462 int compare_dst_last_iteration(const rnn_prb_t *p, dnn_mem_t &mem_dt,
463 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
464 return compare_dat(p, dst_last_iteration, mem_dt, mem_fp, r, final_compare);
467 void rnn_prb_t::set_qparams(float fp_min, float fp_max) {
468 if (cfg == conf_f32) {
475 /* Set parameters for quantization of src and weights from fp32 data
476 * in [-1, 1] to int8 data in a range specified in cfg */
477 float fp_range = fp_max - fp_min;
478 float int8_src_range = cfg[input].f_max - cfg[input].f_min,
479 int8_wei_range = cfg[weights_input].f_max - cfg[weights_input].f_min;
481 data_shift = cfg[input].f_mean;
482 data_scale = int8_src_range / fp_range;
484 if (scale_policy == COMMON) {
485 wei_scale = int8_wei_range / fp_range;
486 } else if (scale_policy == PER_OC) {
487 float K = int8_wei_range / fp_range;
488 int nelems = dic * n_gates();
489 for (int i = 0; i < nelems; i++) {
490 wei_oc_scales[i] = K * (1. + (float)i / nelems);