Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / rnn / rnn_aux.cpp
1 /*******************************************************************************
2  * Copyright 2018 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #include "rnn/rnn_aux.hpp"
18 #include "norm.hpp"
19
20 #define DPRINT(...)                                     \
21     do {                                                \
22         int l = snprintf(buffer, rem_len, __VA_ARGS__); \
23         buffer += l;                                    \
24         rem_len -= l;                                   \
25     } while (0)
26
27 namespace rnn {
28
29 alg_t str2alg(const char *str) {
30 #define CASE(_alg)                         \
31     if (!strcasecmp(STRINGIFY(_alg), str)) \
32     return _alg
33     CASE(VANILLA_RNN);
34     CASE(VANILLA_LSTM);
35     CASE(VANILLA_GRU);
36     CASE(GRU_LINEAR_BEFORE_RESET);
37 #undef CASE
38     assert(!"unknown algorithm");
39     return VANILLA_RNN;
40 }
41
42 const char *alg2str(alg_t alg) {
43     if (alg == VANILLA_RNN)
44         return "VANILLA_RNN";
45     if (alg == VANILLA_LSTM)
46         return "VANILLA_LSTM";
47     if (alg == VANILLA_GRU)
48         return "VANILLA_GRU";
49     if (alg == GRU_LINEAR_BEFORE_RESET)
50         return "GRU_LINEAR_BEFORE_RESET";
51     assert(!"unknown algorithm");
52     return "unknown algorithm";
53 }
54
55 mkldnn_alg_kind_t alg2kind(alg_t alg) {
56     if (alg == VANILLA_RNN)
57         return mkldnn_vanilla_rnn;
58     if (alg == VANILLA_LSTM)
59         return mkldnn_vanilla_lstm;
60     if (alg == VANILLA_GRU)
61         return mkldnn_vanilla_gru;
62     if (alg == GRU_LINEAR_BEFORE_RESET)
63         return mkldnn_gru_linear_before_reset;
64     assert(!"unknown algorithm");
65     return mkldnn_alg_kind_undef;
66 }
67
68 activation_t str2activation(const char *str) {
69 #define CASE(_act)                         \
70     if (!strcasecmp(STRINGIFY(_act), str)) \
71     return _act
72     CASE(RELU);
73     CASE(LOGISTIC);
74     CASE(TANH);
75 #undef CASE
76     assert(!"unknown activation");
77     return TANH;
78 }
79
80 const char *activation2str(activation_t act) {
81     const char *str = "unknown activation";
82     switch (act) {
83     case RELU: str = "RELU"; break;
84     case LOGISTIC: str = "LOGISTIC"; break;
85     case TANH: str = "TANH"; break;
86     default: assert(!"unknown activation");
87     }
88     return str;
89 }
90
91 mkldnn_alg_kind_t activation2kind(activation_t act) {
92     mkldnn_alg_kind_t alg_kind = mkldnn_alg_kind_undef;
93     switch (act) {
94     case RELU: alg_kind = mkldnn_eltwise_relu; break;
95     case LOGISTIC: alg_kind = mkldnn_eltwise_logistic; break;
96     case TANH: alg_kind = mkldnn_eltwise_tanh; break;
97     default: assert(!"unknown activation");
98     }
99     return alg_kind;
100 }
101
102 const char *direction2str(mkldnn_rnn_direction_t direction) {
103
104     // typedef enum {
105     //     mkldnn_unidirectional_left2right,
106     //     mkldnn_unidirectional_right2left,
107     //     mkldnn_unidirectional = mkldnn_unidirectional_left2right,
108     //     mkldnn_bidirectional_concat,
109     //     mkldnn_bidirectional_sum,
110     // } mkldnn_rnn_direction_t;
111
112     if (direction == mkldnn_unidirectional_left2right)
113         return "left2right";
114     if (direction == mkldnn_unidirectional_right2left)
115         return "right2left";
116     if (direction == mkldnn_bidirectional_concat)
117         return "concat";
118     if (direction == mkldnn_bidirectional_sum)
119         return "sum";
120     assert(!"unknown direction");
121     return "unknown direction";
122 }
123
124 void prb2str(const rnn_prb_t *p, const res_t *res, char *buffer) {
125     int rem_len = max_prb_len;
126
127     DPRINT("%s(%s,%s)", alg2str(p->alg), activation2str(p->activation),
128             direction2str(p->direction));
129     DPRINT("l%d", p->n_layer);
130     DPRINT("t%d", p->n_iter);
131     DPRINT("m%d", p->mb);
132     DPRINT("sic%d", p->sic);
133     DPRINT("slc%d", p->slc);
134     DPRINT("dic%d", p->dic);
135     DPRINT("dlc%d", p->dlc);
136     DPRINT("n\"%s\"", p->name);
137 }
138
139 void init_buffer(float *buf, int size, float value) {
140     for (int i = 0; i < size; i++)
141         buf[i] = value;
142 }
143
144 void gemm(const char *transa, const char *transb, int m, int n, int k,
145         //    float a[m][k], float b[k][n], float c[m][n],
146         const float *a, int lda, const float *b, int ldb, float *c, int ldc,
147         float beta) {
148
149     const bool tr_a = transa && (*transa == 'T' || *transa == 't');
150     const bool tr_b = transb && (*transb == 'T' || *transb == 't');
151
152     array_offset_calculator<const float> pa(a, tr_a ? k : m, lda);
153     array_offset_calculator<const float> pb(b, tr_b ? n : k, ldb);
154     array_offset_calculator<float> pc(c, m, ldc);
155
156     print(80, "gemm(m:%d, n:%d, k:%d, lda:%d, ldb:%d, ldc:%d beta:%f)\n", m, n,
157             k, lda, ldb, ldc, beta);
158 #pragma omp parallel for collapse(2)
159     for (int im = 0; im < m; im++) {
160         for (int in = 0; in < n; in++) {
161             // if beta == 0 the initialize pc by 0. Multiplication of
162             // uninitialized value even by zero can lead to nan
163             float c_elem = (beta == 0.) ? 0. : pc(im, in) * beta;
164             for (int ik = 0; ik < k; ik++) {
165                 const float a_elem = tr_a ? pa(ik, im) : pa(im, ik);
166                 const float b_elem = tr_b ? pb(in, ik) : pb(ik, in);
167                 c_elem += a_elem * b_elem;
168             }
169             pc(im, in) = c_elem;
170         }
171     }
172 }
173
174 float logistic(float x) {
175     return 1.0f / (1.0f + expf(-x));
176 }
177 float dlogistic(float x) {
178     float tmp = logistic(x);
179     return tmp * (1 - tmp);
180 }
181 float relu(float x) {
182     return x > 0 ? x : 0;
183 }
184 float drelu(float x) {
185     return float(x > 0);
186 }
187 float dtanhf(float x) {
188     return (1 - (tanhf(x) * tanhf(x)));
189 }
190
191 int compare_dat(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
192         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
193     size_t nelems = mem_dt.nelems();
194
195     const char *skind = rnn_data_kind2str(kind);
196
197     int in = 0, below = 0, above = 0;
198     int in_ok = 0, below_ok = 0, above_ok = 0;
199     int non_zero = 0;
200
201     diff_norm_t diff_norm;
202
203     r->errors = 0;
204     r->total = nelems;
205
206     for (size_t i = 0; i < nelems; ++i) {
207         const float dt = ((float *)mem_dt)[i];
208         const float fp0 = ((float *)mem_fp)[i];
209
210         float fp = fp0;
211
212         const float diff = fabsf(fp - dt);
213         const float rel_diff = diff / (fabsf(fp) > FLT_MIN ? fabsf(fp) : 1);
214
215         bool ok = true;
216         if (fp < p->cfg_[kind].min) {
217             diff_norm.update(p->cfg_[kind].min, dt);
218             ok = dt == p->cfg_[kind].min;
219             below += 1;
220             below_ok += ok;
221         } else if (fp > p->cfg_[kind].max) {
222             diff_norm.update(p->cfg_[kind].max, dt);
223             ok = dt == p->cfg_[kind].max;
224             above += 1;
225             above_ok += ok;
226         } else {
227             diff_norm.update(fp, dt);
228             ok = (fabs(fp) > 1e-5 ? rel_diff : diff) <= p->cfg_[kind].eps;
229             in += 1;
230             in_ok += ok;
231         }
232         if (!ok) {
233             r->errors++;
234             if (r->errors < 10 || verbose >= 10) {
235                 int n = 0, t = 0, c = 0, s = 0, l = 0, d = 0, w = 0, ic = 0,
236                     oc = 0, b = 0;
237                 switch (kind) {
238                 case input:
239                     inv_ntc_off_f(p, i, n, t, c);
240                     print(0, "%lu, %s, [%s][%d,%d,%d] "
241                              "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
242                             (unsigned long)i,
243                             final_compare == false ? "REORDER " : "", skind, n,
244                             t, c, fp, fp0, dt, diff, rel_diff);
245                     break;
246                 case states:
247                     inv_ldsnc_off_f(p, i, l, d, s, n, c);
248                     print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
249                              "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
250                             (unsigned long)i,
251                             final_compare == false ? "REORDER " : "", skind, l,
252                             d, s, n, c, fp, fp0, dt, diff, rel_diff);
253                     break;
254                 case weights_input:
255                     inv_ldigo_off_f(p, i, l, d, w, ic, oc);
256                     print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
257                              "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
258                             (unsigned long)i,
259                             final_compare == false ? "REORDER " : "", skind, l,
260                             d, w, ic, oc, fp, fp0, dt, diff, rel_diff);
261                     break;
262                 case weights_states:
263                     inv_ldigo_off_f(p, i, l, d, w, ic, oc);
264                     print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
265                              "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
266                             (unsigned long)i,
267                             final_compare == false ? "REORDER " : "", skind, l,
268                             d, w, ic, oc, fp, fp0, dt, diff, rel_diff);
269                     break;
270                 case bias:
271                     inv_ldgo_off_f(p, i, l, d, b, c);
272                     print(0, "%lu, %s, [%s][%d,%d,%d,%d] "
273                              "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
274                             (unsigned long)i,
275                             final_compare == false ? "REORDER " : "", skind, l,
276                             d, b, c, fp, fp0, dt, diff, rel_diff);
277                     break;
278                 case dst_last_layer:
279                     inv_tnc_off_f(p, i, s, t, n, c);
280                     print(0, "%lu, %s, [%s][%d,%d,%d,%d] "
281                              "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
282                             (unsigned long)i,
283                             final_compare == false ? "REORDER " : "", skind, s,
284                             t, n, c, fp, fp0, dt, diff, rel_diff);
285                     break;
286                 case dst_last_iteration:
287                     inv_ldsnc_off_f(p, i, l, d, s, n, c);
288                     print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d "
289                              "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
290                             (unsigned long)i,
291                             final_compare == false ? "REORDER " : "", skind, l,
292                             d, s, n, c, fp, fp0, dt, diff, rel_diff);
293                     break;
294                 default: assert("unknown data kind"); return FAIL;
295                 }
296             }
297         }
298
299 #if 1
300         /* for debug purposes only: dump the output */
301         if (final_compare && verbose >= 50) {
302             int n = 0, t = 0, c = 0, s = 0, l = 0, d = 0, w = 0, ic = 0, oc = 0,
303                 b = 0;
304
305             switch (kind) {
306             case input:
307                 inv_ntc_off_f(p, i, n, t, c);
308                 print(0, "[%4lu][%s][%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
309                         (unsigned long)i, skind, n, t, c, fp, fp0, dt);
310                 break;
311             case states:
312                 inv_ldsnc_off_f(p, i, l, d, s, n, c);
313                 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
314                         (unsigned long)i, skind, l, d, s, n, c, fp, fp0, dt);
315                 break;
316             case weights_input:
317                 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
318                 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
319                         (unsigned long)i, skind, l, d, w, ic, oc, fp, fp0, dt);
320                 break;
321             case weights_states:
322                 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
323                 break;
324                 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
325                         (unsigned long)i, skind, l, d, w, ic, oc, fp, fp0, dt);
326             case bias:
327                 inv_ldgo_off_f(p, i, l, d, b, c);
328                 break;
329                 print(0, "[%4lu][%s][%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
330                         (unsigned long)i, skind, l, d, b, c, fp, fp0, dt);
331             case dst_last_layer:
332                 inv_tnc_off_f(p, i, s, t, n, c);
333                 print(0, "[%4lu][%s][%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
334                         (unsigned long)i, skind, n, t, c, fp, fp0, dt);
335                 break;
336             case dst_last_iteration:
337                 inv_ldsnc_off_f(p, i, l, d, s, n, c);
338                 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
339                         (unsigned long)i, skind, l, d, s, n, c, fp, fp0, dt);
340                 break;
341             default:
342                 print(0, "[%4lu][unknown] fp:%8g fp0:%8g dt:%8g\n",
343                         (unsigned long)i, fp, fp0, dt);
344                 break;
345             }
346         }
347 #endif
348
349         non_zero += fp != 0;
350     }
351
352     diff_norm.done();
353
354     if (final_compare || r->errors) {
355         const int vl = r->errors ? 0 : 2;
356         print(vl,
357                 "@@@ [%s] %sdiff: l0(``%g``) "
358                 "l1:(%g,%g,%g,``%g``) "
359                 "l2:(%g,%g,%g,``%g``) "
360                 "l8:(%g,%g,%g,``%g``)\n",
361                 skind, final_compare ? "final: " : "",
362                 diff_norm.rel_diff(norm_t::L0), diff_norm.a_[norm_t::L1],
363                 diff_norm.b_[norm_t::L1], diff_norm.diff_[norm_t::L1],
364                 diff_norm.rel_diff(norm_t::L1), diff_norm.a_[norm_t::L2],
365                 diff_norm.b_[norm_t::L2], diff_norm.diff_[norm_t::L2],
366                 diff_norm.rel_diff(norm_t::L2), diff_norm.a_[norm_t::L8],
367                 diff_norm.b_[norm_t::L8], diff_norm.diff_[norm_t::L8],
368                 diff_norm.rel_diff(norm_t::L8));
369     }
370
371     // const double trust_rg_level = 0.3;
372     //??        const double trust_nz_level = get_trust_nz_level(p, kind,
373     // final_compare);
374
375     // const double trust_rg = (double)in / r->total;
376     // const double trust_nz = (double)non_zero / r->total;
377
378     // const bool no_trust = true /* ...in the test ...at all */
379     // && final_compare
380     //??            && (trust_rg < trust_rg_level || trust_nz <
381     // trust_nz_level)
382     //;
383
384     // const bool dump = verbose >= 20
385     // || (verbose >= 10 && (trust_rg < 1. || trust_nz < 1.));
386     /*??
387     if (dump) {
388         print(0, "@@@ [%s] %strust range:%.2f nz:%.2f "
389             "(level range:%.2f nz:%.2f). "
390             "in:%d (ok:%d) below:%d (ok:%d) above:%d (ok:%d) nz:%d "
391             "total:%lu\n", skind, final_compare ? "final: " : "",
392             trust_rg, trust_nz, trust_rg_level, trust_nz_level, in, in_ok,
393             below, below_ok, above, above_ok, non_zero,
394             (unsigned long)r->total);
395     }
396     */
397
398     /*??
399     if (no_trust) {
400         r->state = MISTRUSTED;
401         print(0, "@@@ [%s] test-bug: trust is too low. "
402             "range:%.2f (?<%.2f) nz:%.2f (?<%.2f) (nz: %d total: %lu)\n",
403             skind, trust_rg, trust_rg_level, trust_nz, trust_nz_level,
404             non_zero, (unsigned long)r->total);
405     }*/
406
407     if (r->errors)
408         r->state = FAILED;
409
410     if (final_compare && r->state == UNTESTED)
411         r->state = PASSED; /* optimism */
412
413     return r->state == FAILED ? FAIL : OK;
414 }
415
416 int compare_input(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
417         res_t *r, bool final_compare = false) {
418     return compare_dat(p, input, mem_dt, mem_fp, r, final_compare);
419 }
420 int compare_states(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
421         res_t *r, bool final_compare = false) {
422     return compare_dat(p, states, mem_dt, mem_fp, r, final_compare);
423 }
424 int compare_weights_input(const rnn_prb_t *p, dnn_mem_t &mem_dt,
425         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
426     return compare_dat(p, weights_input, mem_dt, mem_fp, r, final_compare);
427 }
428 int compare_weights_states(const rnn_prb_t *p, dnn_mem_t &mem_dt,
429         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
430     return compare_dat(p, weights_states, mem_dt, mem_fp, r, final_compare);
431 }
432 int compare_bias(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
433         res_t *r, bool final_compare = false) {
434     return compare_dat(p, bias, mem_dt, mem_fp, r, final_compare);
435 }
436 int compare_dst_last_layer(const rnn_prb_t *p, dnn_mem_t &mem_dt,
437         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
438     return compare_dat(p, dst_last_layer, mem_dt, mem_fp, r, final_compare);
439 }
440 int compare_dst_last_iteration(const rnn_prb_t *p, dnn_mem_t &mem_dt,
441         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
442     return compare_dat(p, dst_last_iteration, mem_dt, mem_fp, r, final_compare);
443 }
444
445 } // namespace rnn