Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / rnn / rnn_aux.cpp
1 /*******************************************************************************
2  * Copyright 2018 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #include "rnn/rnn_aux.hpp"
18 #include "norm.hpp"
19
20 #define DPRINT(...)                                     \
21     do {                                                \
22         int l = snprintf(buffer, rem_len, __VA_ARGS__); \
23         buffer += l;                                    \
24         rem_len -= l;                                   \
25     } while (0)
26
27 namespace rnn {
28
29 alg_t str2alg(const char *str) {
30 #define CASE(_alg)                         \
31     if (!strcasecmp(STRINGIFY(_alg), str)) \
32     return _alg
33     CASE(VANILLA_RNN);
34     CASE(VANILLA_LSTM);
35     CASE(VANILLA_GRU);
36     CASE(LBR_GRU);
37 #undef CASE
38     assert(!"unknown algorithm");
39     return VANILLA_RNN;
40 }
41
42 policy_t str2policy(const char *str) {
43 #define CASE(_plc) if (!strcasecmp(STRINGIFY(_plc), str)) return _plc
44     CASE(NONE);
45     CASE(COMMON);
46     CASE(PER_OC);
47 #undef CASE
48     assert(!"unknown policy");
49     return NONE;
50 }
51
52 const char * policy2str(policy_t policy) {
53     if (policy == NONE) return "none";
54     if (policy == COMMON) return "common";
55     if (policy == PER_OC) return "per_oc";
56     assert(!"unknown policy");
57     return "unknown policy";
58 }
59
60 const char *alg2str(alg_t alg) {
61     if (alg == VANILLA_RNN)
62         return "VANILLA_RNN";
63     if (alg == VANILLA_LSTM)
64         return "VANILLA_LSTM";
65     if (alg == VANILLA_GRU)
66         return "VANILLA_GRU";
67     if (alg == LBR_GRU)
68         return "LBR_GRU";
69     assert(!"unknown algorithm");
70     return "unknown algorithm";
71 }
72
73 mkldnn_alg_kind_t alg2kind(alg_t alg) {
74     if (alg == VANILLA_RNN)
75         return mkldnn_vanilla_rnn;
76     if (alg == VANILLA_LSTM)
77         return mkldnn_vanilla_lstm;
78     if (alg == VANILLA_GRU)
79         return mkldnn_vanilla_gru;
80     if (alg == LBR_GRU)
81         return mkldnn_gru_linear_before_reset;
82     assert(!"unknown algorithm");
83     return mkldnn_alg_kind_undef;
84 }
85
86 activation_t str2activation(const char *str) {
87 #define CASE(_act)                         \
88     if (!strcasecmp(STRINGIFY(_act), str)) \
89     return _act
90     CASE(RELU);
91     CASE(LOGISTIC);
92     CASE(TANH);
93 #undef CASE
94     assert(!"unknown activation");
95     return TANH;
96 }
97
98 const char *activation2str(activation_t act) {
99     const char *str = "unknown activation";
100     switch (act) {
101     case RELU: str = "RELU"; break;
102     case LOGISTIC: str = "LOGISTIC"; break;
103     case TANH: str = "TANH"; break;
104     default: assert(!"unknown activation");
105     }
106     return str;
107 }
108
109 mkldnn_alg_kind_t activation2kind(activation_t act) {
110     mkldnn_alg_kind_t alg_kind = mkldnn_alg_kind_undef;
111     switch (act) {
112     case RELU: alg_kind = mkldnn_eltwise_relu; break;
113     case LOGISTIC: alg_kind = mkldnn_eltwise_logistic; break;
114     case TANH: alg_kind = mkldnn_eltwise_tanh; break;
115     default: assert(!"unknown activation");
116     }
117     return alg_kind;
118 }
119
120 mkldnn_prop_kind_t str2prop(const char *str) {
121     if (!strcasecmp("FWD_D", str))
122         return mkldnn_forward;
123     if (!strcasecmp("BWD_D", str))
124         return mkldnn_backward;
125     assert(!"unknown propagation");
126     return mkldnn_forward;
127 }
128
129 const char *prop2str(mkldnn_prop_kind_t prop) {
130     if (prop == mkldnn_forward)
131         return "FWD_D";
132     if (prop == mkldnn_backward)
133         return "BWD_DW";
134     assert(!"unknown propagation");
135     return "unknown propagation";
136
137 }
138
139 mkldnn_rnn_direction_t str2direction(const char *str) {
140     if (!strcasecmp("left2right", str))
141         return mkldnn_unidirectional_left2right;
142     if (!strcasecmp("right2left", str))
143         return mkldnn_unidirectional_right2left;
144     if (!strcasecmp("concat", str))
145         return mkldnn_bidirectional_concat;
146     if (!strcasecmp("sum", str))
147         return mkldnn_bidirectional_sum;
148     assert(!"unknown direction");
149     return mkldnn_unidirectional_left2right;
150 }
151
152 const char *direction2str(mkldnn_rnn_direction_t direction) {
153     if (direction == mkldnn_unidirectional_left2right)
154         return "left2right";
155     if (direction == mkldnn_unidirectional_right2left)
156         return "right2left";
157     if (direction == mkldnn_bidirectional_concat)
158         return "concat";
159     if (direction == mkldnn_bidirectional_sum)
160         return "sum";
161     assert(!"unknown direction");
162     return "unknown direction";
163 }
164
165 int str2desc(rnn_desc_t *desc, const char *str) {
166     rnn_desc_t d{0};
167
168     /* canonical form:
169      * lXtXmXsicXslcXdicXdlc
170      *
171      * where: X is number, S - string
172      * note: symbol `_` is ignored
173      *
174      * implicit rules:
175      *  - default values:
176      *      l = 1, t = 1, mb = 2, S="wip"
177      *  - if slc/dlc/dic is undefined => slc/dlc/dic = sic
178      */
179
180     d.n_layer = 1;
181     d.n_iter = 1;
182     d.mb = 2;
183     d.name = "\"wip\"";
184
185     const char *s = str;
186     assert(s);
187
188 #   define CASE_NN(p, c) do { \
189         if (!strncmp(p, s, strlen(p))) { \
190             ok = 1; s += strlen(p); \
191             char *end_s; d. c = strtol(s, &end_s, 10); s += (end_s - s); \
192         } \
193     } while (0)
194 #   define CASE_N(c) CASE_NN(#c, c)
195     while (*s) {
196         int ok = 0;
197         CASE_NN("l", n_layer);
198         CASE_NN("t", n_iter);
199         CASE_N(mb);
200         CASE_N(sic);
201         CASE_N(slc);
202         CASE_N(dic);
203         CASE_N(dlc);
204         if (*s == 'n') { d.name = s + 1; break; }
205         if (*s == '_') ++s;
206         if (!ok) return FAIL;
207     }
208 #   undef CASE_NN
209 #   undef CASE_N
210
211     if (d.sic == 0) return FAIL;
212     if (d.slc == 0) d.slc = d.sic;
213     if (d.dlc == 0) d.dlc = d.sic;
214     if (d.dic == 0) d.dic = d.sic;
215
216     *desc = d;
217
218     return OK;
219 }
220
221
222 void prb2str(const rnn_prb_t *p, const res_t *res, char *buffer) {
223     int rem_len = max_prb_len;
224
225     DPRINT("--prop=%s --alg=%s --activation=%s --direction=%s --cfg=%s "
226            "--scaling=%s ",
227             prop2str(p->prop), alg2str(p->alg), activation2str(p->activation),
228             direction2str(p->direction), cfg2str(p->cfg),
229             policy2str(p->scale_policy));
230     DPRINT("l%d", p->n_layer);
231     DPRINT("t%d", p->n_iter);
232     DPRINT("mb%d", p->mb);
233     DPRINT("sic%d", p->sic);
234     DPRINT("slc%d", p->slc);
235     DPRINT("dic%d", p->dic);
236     DPRINT("dlc%d", p->dlc);
237     DPRINT("n\"%s\"", p->name);
238 }
239
240 void init_buffer(float *buf, int size, float value) {
241     for (int i = 0; i < size; i++)
242         buf[i] = value;
243 }
244
245 float logistic(float x) {
246     if (x < 0)
247         return (expf(x) / (1 + expf(x)));
248     else
249         return 1.0f - (expf(-x) / (1 + expf(-x)));
250 }
251 float dlogistic(float x) {
252     float tmp = logistic(x);
253     return tmp * (1 - tmp);
254 }
255 float dtanhf(float x) {
256     return (1 - tanhf(x)) * (1 + tanhf(x));
257 }
258 float x_m_square(float x) {
259     return x - x * x;
260 }
261 float relu(float x) {
262     return x > 0 ? x : 0;
263 }
264 float drelu(float x) {
265     return float(x > 0);
266 }
267 float one_m_square(float x) {
268     return 1 - x * x;
269 }
270
271 int compare_dat(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
272         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
273     size_t nelems = mem_dt.nelems();
274
275     const char *skind = rnn_data_kind2str(kind);
276
277     diff_norm_t diff_norm;
278
279     r->errors = 0;
280     r->total = nelems;
281
282     for (size_t i = 0; i < nelems; ++i) {
283         const float dt = ((float *)mem_dt)[i];
284         const float fp = ((float *)mem_fp)[i];
285         diff_norm.update(fp, dt);
286
287         const float diff = fabsf(fp - dt);
288         const float rel_diff = diff / (fabsf(fp) > FLT_MIN ? fabsf(fp) : 1);
289
290         const bool ok = (fabs(fp) > 1e-5 ? rel_diff : diff) <= p->cfg[kind].eps;
291
292         if (!ok) {
293             r->errors++;
294             if (r->errors < 10 || verbose >= 10) {
295                 int n = 0, t = 0, c = 0, s = 0, l = 0, d = 0, w = 0, ic = 0,
296                     oc = 0, b = 0;
297                 switch (kind) {
298                 case input:
299                     inv_ntc_off_f(p, i, n, t, c);
300                     print(0, "%lu, %s, [%s][%d,%d,%d] "
301                              "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
302                             (unsigned long)i,
303                             final_compare == false ? "REORDER " : "", skind, n,
304                             t, c, fp, dt, diff, rel_diff);
305                     break;
306                 case states:
307                     inv_ldsnc_off_f(p, i, l, d, s, n, c);
308                     print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
309                              "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
310                             (unsigned long)i,
311                             final_compare == false ? "REORDER " : "", skind, l,
312                             d, s, n, c, fp, dt, diff, rel_diff);
313                     break;
314                 case weights_input:
315                     inv_ldigo_off_f(p, i, l, d, w, ic, oc);
316                     print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
317                              "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
318                             (unsigned long)i,
319                             final_compare == false ? "REORDER " : "", skind, l,
320                             d, w, ic, oc, fp, dt, diff, rel_diff);
321                     break;
322                 case weights_states:
323                     inv_ldigo_off_f(p, i, l, d, w, ic, oc);
324                     print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
325                              "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
326                             (unsigned long)i,
327                             final_compare == false ? "REORDER " : "", skind, l,
328                             d, w, ic, oc, fp, dt, diff, rel_diff);
329                     break;
330                 case bias:
331                     inv_ldgo_off_f(p, i, l, d, b, c);
332                     print(0, "%lu, %s, [%s][%d,%d,%d,%d] "
333                              "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
334                             (unsigned long)i,
335                             final_compare == false ? "REORDER " : "", skind, l,
336                             d, b, c, fp,  dt, diff, rel_diff);
337                     break;
338                 case dst_last_layer:
339                     inv_tnc_off_f(p, i, s, t, n, c);
340                     print(0, "%lu, %s, [%s][%d,%d,%d,%d] "
341                              "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
342                             (unsigned long)i,
343                             final_compare == false ? "REORDER " : "", skind, s,
344                             t, n, c, fp, dt, diff, rel_diff);
345                     break;
346                 case dst_last_iteration:
347                     inv_ldsnc_off_f(p, i, l, d, s, n, c);
348                     print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d "
349                              "fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
350                             (unsigned long)i,
351                             final_compare == false ? "REORDER " : "", skind, l,
352                             d, s, n, c, fp, dt, diff, rel_diff);
353                     break;
354                 default: assert("unknown data kind"); return FAIL;
355                 }
356             }
357         }
358
359 #if 1
360         /* for debug purposes only: dump the output */
361         if (final_compare && verbose >= 50) {
362             int n = 0, t = 0, c = 0, s = 0, l = 0, d = 0, w = 0, ic = 0, oc = 0,
363                 b = 0;
364
365             switch (kind) {
366             case input:
367                 inv_ntc_off_f(p, i, n, t, c);
368                 print(0, "[%4lu][%s][%d,%d,%d] fp:%8g dt:%8g\n",
369                         (unsigned long)i, skind, n, t, c, fp, dt);
370                 break;
371             case states:
372                 inv_ldsnc_off_f(p, i, l, d, s, n, c);
373                 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g dt:%8g\n",
374                         (unsigned long)i, skind, l, d, s, n, c, fp, dt);
375                 break;
376             case weights_input:
377                 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
378                 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g dt:%8g\n",
379                         (unsigned long)i, skind, l, d, w, ic, oc, fp, dt);
380                 break;
381             case weights_states:
382                 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
383                 break;
384                 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g dt:%8g\n",
385                         (unsigned long)i, skind, l, d, w, ic, oc, fp, dt);
386             case bias:
387                 inv_ldgo_off_f(p, i, l, d, b, c);
388                 break;
389                 print(0, "[%4lu][%s][%d,%d,%d,%d] fp:%8g dt:%8g\n",
390                         (unsigned long)i, skind, l, d, b, c, fp, dt);
391             case dst_last_layer:
392                 inv_tnc_off_f(p, i, s, t, n, c);
393                 print(0, "[%4lu][%s][%d,%d,%d] fp:%8g dt:%8g\n",
394                         (unsigned long)i, skind, n, t, c, fp, dt);
395                 break;
396             case dst_last_iteration:
397                 inv_ldsnc_off_f(p, i, l, d, s, n, c);
398                 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g dt:%8g\n",
399                         (unsigned long)i, skind, l, d, s, n, c, fp, dt);
400                 break;
401             default:
402                 print(0, "[%4lu][unknown] fp:%8g dt:%8g\n",
403                         (unsigned long)i, fp, dt);
404                 break;
405             }
406         }
407 #endif
408     }
409
410     diff_norm.done();
411
412     if (final_compare || r->errors) {
413         const int vl = r->errors ? 0 : 2;
414         print(vl,
415                 "@@@ [%s] %sdiff: l0(``%g``) "
416                 "l1:(%g,%g,%g,``%g``) "
417                 "l2:(%g,%g,%g,``%g``) "
418                 "l8:(%g,%g,%g,``%g``)\n",
419                 skind, final_compare ? "final: " : "",
420                 diff_norm.rel_diff(norm_t::L0), diff_norm.a_[norm_t::L1],
421                 diff_norm.b_[norm_t::L1], diff_norm.diff_[norm_t::L1],
422                 diff_norm.rel_diff(norm_t::L1), diff_norm.a_[norm_t::L2],
423                 diff_norm.b_[norm_t::L2], diff_norm.diff_[norm_t::L2],
424                 diff_norm.rel_diff(norm_t::L2), diff_norm.a_[norm_t::L8],
425                 diff_norm.b_[norm_t::L8], diff_norm.diff_[norm_t::L8],
426                 diff_norm.rel_diff(norm_t::L8));
427     }
428
429     if (r->errors)
430         r->state = FAILED;
431
432     if (final_compare && r->state == UNTESTED)
433         r->state = PASSED; /* optimism */
434
435     return r->state == FAILED ? FAIL : OK;
436 }
437
438 int compare_input(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
439         res_t *r, bool final_compare = false) {
440     return compare_dat(p, input, mem_dt, mem_fp, r, final_compare);
441 }
442 int compare_states(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
443         res_t *r, bool final_compare = false) {
444     return compare_dat(p, states, mem_dt, mem_fp, r, final_compare);
445 }
446 int compare_weights_input(const rnn_prb_t *p, dnn_mem_t &mem_dt,
447         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
448     return compare_dat(p, weights_input, mem_dt, mem_fp, r, final_compare);
449 }
450 int compare_weights_states(const rnn_prb_t *p, dnn_mem_t &mem_dt,
451         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
452     return compare_dat(p, weights_states, mem_dt, mem_fp, r, final_compare);
453 }
454 int compare_bias(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
455         res_t *r, bool final_compare = false) {
456     return compare_dat(p, bias, mem_dt, mem_fp, r, final_compare);
457 }
458 int compare_dst_last_layer(const rnn_prb_t *p, dnn_mem_t &mem_dt,
459         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
460     return compare_dat(p, dst_last_layer, mem_dt, mem_fp, r, final_compare);
461 }
462 int compare_dst_last_iteration(const rnn_prb_t *p, dnn_mem_t &mem_dt,
463         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
464     return compare_dat(p, dst_last_iteration, mem_dt, mem_fp, r, final_compare);
465 }
466
467 void rnn_prb_t::set_qparams(float fp_min, float fp_max) {
468     if (cfg == conf_f32) {
469         data_shift = 0.;
470         data_scale = 1.;
471         wei_scale = 1.;
472         return;
473     }
474
475     /* Set parameters for quantization of src and weights from fp32 data
476      * in [-1, 1] to int8 data in a range specified in cfg */
477     float fp_range = fp_max - fp_min;
478     float int8_src_range = cfg[input].f_max - cfg[input].f_min,
479           int8_wei_range = cfg[weights_input].f_max - cfg[weights_input].f_min;
480
481     data_shift = cfg[input].f_mean;
482     data_scale = int8_src_range / fp_range;
483
484     if (scale_policy == COMMON) {
485         wei_scale = int8_wei_range / fp_range;
486     } else if (scale_policy == PER_OC) {
487         float K = int8_wei_range / fp_range;
488         int nelems = dic * n_gates();
489         for (int i = 0; i < nelems; i++) {
490             wei_oc_scales[i] = K * (1. + (float)i / nelems);
491         }
492     }
493 }
494
495 } // namespace rnn