1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "rnn/rnn_aux.hpp"
22 int l = snprintf(buffer, rem_len, __VA_ARGS__); \
/* Parses an algorithm name into an alg_t value. The visible lines show a
 * case-insensitive comparison (strcasecmp) of `str` against a stringified
 * enumerator name and an assertion failure with "unknown algorithm".
 * NOTE(review): this listing omits several lines of the function (the CASE
 * macro header, the remaining CASE entries and the fallback return) --
 * confirm against the full file before changing anything here. */
29 alg_t str2alg(const char *str) {
// Part of the helper macro: case-insensitive match of `str` against the
// spelled-out name of `_alg`.
31 if (!strcasecmp(STRINGIFY(_alg), str)) \
36 CASE(GRU_LINEAR_BEFORE_RESET);
// Presumably reached only when none of the names matched; aborts when
// assertions are enabled.
38 assert(!"unknown algorithm");
/* Maps an alg_t value to its name as a C string literal. A value that
 * matches none of the comparisons trips the assertion and, when assertions
 * are compiled out, yields "unknown algorithm".
 * NOTE(review): the return statements that belong to the VANILLA_RNN and
 * VANILLA_GRU comparisons are not present in this listing -- confirm against
 * the full file. */
42 const char *alg2str(alg_t alg) {
43 if (alg == VANILLA_RNN)
45 if (alg == VANILLA_LSTM)
46 return "VANILLA_LSTM";
47 if (alg == VANILLA_GRU)
49 if (alg == GRU_LINEAR_BEFORE_RESET)
50 return "GRU_LINEAR_BEFORE_RESET";
// No comparison matched: fail loudly in debug builds, return a placeholder
// string otherwise.
51 assert(!"unknown algorithm");
52 return "unknown algorithm";
55 mkldnn_alg_kind_t alg2kind(alg_t alg) {
56 if (alg == VANILLA_RNN)
57 return mkldnn_vanilla_rnn;
58 if (alg == VANILLA_LSTM)
59 return mkldnn_vanilla_lstm;
60 if (alg == VANILLA_GRU)
61 return mkldnn_vanilla_gru;
62 if (alg == GRU_LINEAR_BEFORE_RESET)
63 return mkldnn_gru_linear_before_reset;
64 assert(!"unknown algorithm");
65 return mkldnn_alg_kind_undef;
/* Parses an activation name into an activation_t value. The visible lines
 * show a case-insensitive comparison (strcasecmp) of `str` against a
 * stringified enumerator name and an assertion failure with
 * "unknown activation".
 * NOTE(review): the macro header, the individual entries and the fallback
 * return are not present in this listing -- confirm against the full file. */
68 activation_t str2activation(const char *str) {
// Part of the helper macro: case-insensitive match of `str` against the
// spelled-out name of `_act`.
70 if (!strcasecmp(STRINGIFY(_act), str)) \
76 assert(!"unknown activation");
/* Maps an activation_t value to its name as a C string literal. `str` starts
 * out as "unknown activation" and is overwritten for RELU, LOGISTIC and TANH;
 * any other value trips the assertion.
 * NOTE(review): the switch header and the final return are not present in
 * this listing -- confirm against the full file. */
80 const char *activation2str(activation_t act) {
// Placeholder that survives only when no case below matches.
81 const char *str = "unknown activation";
83 case RELU: str = "RELU"; break;
84 case LOGISTIC: str = "LOGISTIC"; break;
85 case TANH: str = "TANH"; break;
86 default: assert(!"unknown activation");
/* Maps an activation_t value to the matching mkldnn eltwise algorithm kind.
 * `alg_kind` starts out as mkldnn_alg_kind_undef and is overwritten for RELU,
 * LOGISTIC and TANH; any other value trips the assertion.
 * NOTE(review): the switch header and the final return are not present in
 * this listing -- confirm against the full file. */
91 mkldnn_alg_kind_t activation2kind(activation_t act) {
// Placeholder that survives only when no case below matches.
92 mkldnn_alg_kind_t alg_kind = mkldnn_alg_kind_undef;
94 case RELU: alg_kind = mkldnn_eltwise_relu; break;
95 case LOGISTIC: alg_kind = mkldnn_eltwise_logistic; break;
96 case TANH: alg_kind = mkldnn_eltwise_tanh; break;
97 default: assert(!"unknown activation");
/* Maps an mkldnn_rnn_direction_t value to a C string. The four enumerators
 * checked are left2right, right2left, bidirectional_concat and
 * bidirectional_sum; anything else trips the assertion and, when assertions
 * are compiled out, yields "unknown direction".
 * NOTE(review): the return statement of every comparison is missing from
 * this listing, so the exact strings cannot be read off here -- confirm
 * against the full file. */
102 const char *direction2str(mkldnn_rnn_direction_t direction) {
// Enumerators of mkldnn_rnn_direction_t, reproduced for reference:
105 // mkldnn_unidirectional_left2right,
106 // mkldnn_unidirectional_right2left,
107 // mkldnn_unidirectional = mkldnn_unidirectional_left2right,
108 // mkldnn_bidirectional_concat,
109 // mkldnn_bidirectional_sum,
110 // } mkldnn_rnn_direction_t;
112 if (direction == mkldnn_unidirectional_left2right)
114 if (direction == mkldnn_unidirectional_right2left)
116 if (direction == mkldnn_bidirectional_concat)
118 if (direction == mkldnn_bidirectional_sum)
// No comparison matched: fail loudly in debug builds, return a placeholder
// string otherwise.
120 assert(!"unknown direction");
121 return "unknown direction";
/* Writes a textual description of problem `p` into `buffer` through a
 * sequence of DPRINT calls: first "alg(activation,direction)", then the
 * integer fields tagged l, t, m, sic, slc, dic and dlc, and finally the
 * problem name in double quotes. `rem_len` starts at max_prb_len.
 * `res` is not referenced by any visible line.
 * NOTE(review): DPRINT is defined outside the visible lines; presumably it
 * advances `buffer` and shrinks `rem_len` by the snprintf return value after
 * each call -- confirm, and check how truncation is handled. */
124 void prb2str(const rnn_prb_t *p, const res_t *res, char *buffer) {
125 int rem_len = max_prb_len;
// Algorithm, activation and direction, each converted to its string form.
127 DPRINT("%s(%s,%s)", alg2str(p->alg), activation2str(p->activation),
128 direction2str(p->direction));
// Integer fields of the problem, each prefixed by a short tag.
129 DPRINT("l%d", p->n_layer);
130 DPRINT("t%d", p->n_iter);
131 DPRINT("m%d", p->mb);
132 DPRINT("sic%d", p->sic);
133 DPRINT("slc%d", p->slc);
134 DPRINT("dic%d", p->dic);
135 DPRINT("dlc%d", p->dlc);
// Problem name wrapped in literal double quotes.
136 DPRINT("n\"%s\"", p->name);
/* Iterates over the first `size` indices of `buf`.
 * NOTE(review): the loop body is not present in this listing; judging by the
 * name and the `value` parameter it presumably stores `value` into every
 * element -- confirm against the full file. */
139 void init_buffer(float *buf, int size, float value) {
140 for (int i = 0; i < size; i++)
/* Reference matrix multiply on float data. For every (im, in) pair with
 * im < m and in < n it starts from c(im, in) * beta (or from 0 when beta is
 * exactly 0) and adds the sum over ik < k of op(a)(im, ik) * op(b)(ik, in).
 * op() transposes the operand when the matching trans string is non-NULL and
 * starts with 'T' or 't'. The matrices are addressed through
 * array_offset_calculator views built from the raw pointers and lda/ldb/ldc.
 * NOTE(review): the end of the parameter list (where `beta` is presumably
 * declared) and the write-back of the accumulated value into `c` are not
 * present in this listing -- confirm against the full file. */
144 void gemm(const char *transa, const char *transb, int m, int n, int k,
145 // float a[m][k], float b[k][n], float c[m][n],
146 const float *a, int lda, const float *b, int ldb, float *c, int ldc,
// A NULL flag is treated as "not transposed".
149 const bool tr_a = transa && (*transa == 'T' || *transa == 't');
150 const bool tr_b = transb && (*transb == 'T' || *transb == 't');
// The first extent passed to each view depends on whether that operand is
// transposed.
152 array_offset_calculator<const float> pa(a, tr_a ? k : m, lda);
153 array_offset_calculator<const float> pb(b, tr_b ? n : k, ldb);
154 array_offset_calculator<float> pc(c, m, ldc);
156 print(80, "gemm(m:%d, n:%d, k:%d, lda:%d, ldb:%d, ldc:%d beta:%f)\n", m, n,
157 k, lda, ldb, ldc, beta);
// The two outer loops are collapsed into a single OpenMP parallel loop.
158 #pragma omp parallel for collapse(2)
159 for (int im = 0; im < m; im++) {
160 for (int in = 0; in < n; in++) {
161 // if beta == 0 the initialize pc by 0. Multiplication of
162 // uninitialized value even by zero can lead to nan
163 float c_elem = (beta == 0.) ? 0. : pc(im, in) * beta;
164 for (int ik = 0; ik < k; ik++) {
// Swap the index order when the operand is transposed.
165 const float a_elem = tr_a ? pa(ik, im) : pa(im, ik);
166 const float b_elem = tr_b ? pb(in, ik) : pb(ik, in);
167 c_elem += a_elem * b_elem;
/* Logistic (sigmoid) function, 1 / (1 + e^(-x)), evaluated in single
 * precision. */
float logistic(float x) {
    const float denom = 1.0f + expf(-x);
    return 1.0f / denom;
}
177 float dlogistic(float x) {
178 float tmp = logistic(x);
179 return tmp * (1 - tmp);
/* Rectified linear unit: returns x when x is strictly positive, 0 for
 * everything else. */
float relu(float x) {
    if (x > 0)
        return x;
    return 0;
}
/* Presumably the derivative of relu() with respect to x, judging by the name.
 * NOTE(review): the function body is not present in this listing -- confirm
 * against the full file. */
184 float drelu(float x) {
/* Derivative of tanhf at x: 1 - tanhf(x)^2.
 * tanhf(x) is evaluated once and squared rather than being computed twice;
 * the result is the same. */
float dtanhf(float x) {
    const float th = tanhf(x);
    return (1 - (th * th));
}
/* Element-wise comparison of the float contents of `mem_dt` against
 * `mem_fp` for the data kind `kind`, using the min/max/eps limits in
 * p->cfg_[kind].
 * Mismatching elements are reported via print(), norms of the difference are
 * accumulated in a diff_norm_t, and a summary is printed at the end.
 * When `final_compare` is set and r->state is still UNTESTED, the state
 * becomes PASSED.
 * Returns FAIL when r->state == FAILED, OK otherwise.
 * NOTE(review): this listing omits many lines of the function (among them
 * the definition of `fp`, the `ok` declaration, the error bookkeeping, the
 * switch headers and most case labels). The comments below describe only
 * what the visible lines show. */
191 int compare_dat(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
192 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
// The element count is taken from mem_dt.
193 size_t nelems = mem_dt.nelems();
// Printable name of the data kind, used in every message below.
195 const char *skind = rnn_data_kind2str(kind);
// Counters that appear in the trust report near the end; the statements
// that increment them are not visible in this listing.
197 int in = 0, below = 0, above = 0;
198 int in_ok = 0, below_ok = 0, above_ok = 0;
201 diff_norm_t diff_norm;
206 for (size_t i = 0; i < nelems; ++i) {
// dt comes from mem_dt, fp0 from mem_fp. `fp`, used below, is derived on
// lines that are not visible here.
207 const float dt = ((float *)mem_dt)[i];
208 const float fp0 = ((float *)mem_fp)[i];
// Absolute error, and error relative to |fp|; the divisor falls back to 1
// when |fp| is not greater than FLT_MIN.
212 const float diff = fabsf(fp - dt);
213 const float rel_diff = diff / (fabsf(fp) > FLT_MIN ? fabsf(fp) : 1);
// fp below the configured minimum: dt must equal the minimum exactly.
216 if (fp < p->cfg_[kind].min) {
217 diff_norm.update(p->cfg_[kind].min, dt);
218 ok = dt == p->cfg_[kind].min;
// fp above the configured maximum: dt must equal the maximum exactly.
221 } else if (fp > p->cfg_[kind].max) {
222 diff_norm.update(p->cfg_[kind].max, dt);
223 ok = dt == p->cfg_[kind].max;
// Otherwise: relative error when |fp| > 1e-5, absolute error for smaller
// values, checked against the configured eps.
227 diff_norm.update(fp, dt);
228 ok = (fabs(fp) > 1e-5 ? rel_diff : diff) <= p->cfg_[kind].eps;
// Per-element details are printed only while fewer than 10 errors have been
// recorded, or always when verbose >= 10.
234 if (r->errors < 10 || verbose >= 10) {
235 int n = 0, t = 0, c = 0, s = 0, l = 0, d = 0, w = 0, ic = 0,
239 inv_ntc_off_f(p, i, n, t, c);
240 print(0, "%lu, %s, [%s][%d,%d,%d] "
241 "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
243 final_compare == false ? "REORDER " : "", skind, n,
244 t, c, fp, fp0, dt, diff, rel_diff);
247 inv_ldsnc_off_f(p, i, l, d, s, n, c);
248 print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
249 "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
251 final_compare == false ? "REORDER " : "", skind, l,
252 d, s, n, c, fp, fp0, dt, diff, rel_diff);
255 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
256 print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
257 "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
259 final_compare == false ? "REORDER " : "", skind, l,
260 d, w, ic, oc, fp, fp0, dt, diff, rel_diff);
263 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
264 print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d] "
265 "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
267 final_compare == false ? "REORDER " : "", skind, l,
268 d, w, ic, oc, fp, fp0, dt, diff, rel_diff);
271 inv_ldgo_off_f(p, i, l, d, b, c);
272 print(0, "%lu, %s, [%s][%d,%d,%d,%d] "
273 "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
275 final_compare == false ? "REORDER " : "", skind, l,
276 d, b, c, fp, fp0, dt, diff, rel_diff);
279 inv_tnc_off_f(p, i, s, t, n, c);
280 print(0, "%lu, %s, [%s][%d,%d,%d,%d] "
281 "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
283 final_compare == false ? "REORDER " : "", skind, s,
284 t, n, c, fp, fp0, dt, diff, rel_diff);
286 case dst_last_iteration:
287 inv_ldsnc_off_f(p, i, l, d, s, n, c);
// NOTE(review): the index list in the format string below is missing its
// closing ']' ("[%d,%d,%d,%d,%d "); the sibling messages close the bracket.
288 print(0, "%lu, %s, [%s][%d,%d,%d,%d,%d "
289 "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
291 final_compare == false ? "REORDER " : "", skind, l,
292 d, s, n, c, fp, fp0, dt, diff, rel_diff);
// NOTE(review): assert() on a string literal is always true, so this
// assertion can never fire; the other asserts in this file negate the
// literal (assert(!"...")). Only the `return FAIL` takes effect.
294 default: assert("unknown data kind"); return FAIL;
300 /* for debug purposes only: dump the output */
301 if (final_compare && verbose >= 50) {
302 int n = 0, t = 0, c = 0, s = 0, l = 0, d = 0, w = 0, ic = 0, oc = 0,
307 inv_ntc_off_f(p, i, n, t, c);
308 print(0, "[%4lu][%s][%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
309 (unsigned long)i, skind, n, t, c, fp, fp0, dt);
312 inv_ldsnc_off_f(p, i, l, d, s, n, c);
313 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
314 (unsigned long)i, skind, l, d, s, n, c, fp, fp0, dt);
317 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
318 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
319 (unsigned long)i, skind, l, d, w, ic, oc, fp, fp0, dt);
322 inv_ldigo_off_f(p, i, l, d, w, ic, oc);
324 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
325 (unsigned long)i, skind, l, d, w, ic, oc, fp, fp0, dt);
327 inv_ldgo_off_f(p, i, l, d, b, c);
329 print(0, "[%4lu][%s][%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
330 (unsigned long)i, skind, l, d, b, c, fp, fp0, dt);
// NOTE(review): inv_tnc_off_f fills s, t, n and c, but the dump below prints
// only n, t, c with three %d, whereas the error report for the same helper
// above prints all four indices -- confirm which is intended.
332 inv_tnc_off_f(p, i, s, t, n, c);
333 print(0, "[%4lu][%s][%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
334 (unsigned long)i, skind, n, t, c, fp, fp0, dt);
336 case dst_last_iteration:
337 inv_ldsnc_off_f(p, i, l, d, s, n, c);
338 print(0, "[%4lu][%s][%d,%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
339 (unsigned long)i, skind, l, d, s, n, c, fp, fp0, dt);
342 print(0, "[%4lu][unknown] fp:%8g fp0:%8g dt:%8g\n",
343 (unsigned long)i, fp, fp0, dt);
// Norm summary: produced for the final comparison or whenever errors were
// recorded. `vl` is 0 when there are errors and 2 otherwise.
354 if (final_compare || r->errors) {
355 const int vl = r->errors ? 0 : 2;
357 "@@@ [%s] %sdiff: l0(``%g``) "
358 "l1:(%g,%g,%g,``%g``) "
359 "l2:(%g,%g,%g,``%g``) "
360 "l8:(%g,%g,%g,``%g``)\n",
361 skind, final_compare ? "final: " : "",
362 diff_norm.rel_diff(norm_t::L0), diff_norm.a_[norm_t::L1],
363 diff_norm.b_[norm_t::L1], diff_norm.diff_[norm_t::L1],
364 diff_norm.rel_diff(norm_t::L1), diff_norm.a_[norm_t::L2],
365 diff_norm.b_[norm_t::L2], diff_norm.diff_[norm_t::L2],
366 diff_norm.rel_diff(norm_t::L2), diff_norm.a_[norm_t::L8],
367 diff_norm.b_[norm_t::L8], diff_norm.diff_[norm_t::L8],
368 diff_norm.rel_diff(norm_t::L8));
// NOTE(review): trust_rg, trust_nz and their levels are defined only in the
// commented-out lines visible below, yet the print statements after them use
// these names -- confirm whether that section is actually compiled.
371 // const double trust_rg_level = 0.3;
372 //?? const double trust_nz_level = get_trust_nz_level(p, kind,
375 // const double trust_rg = (double)in / r->total;
376 // const double trust_nz = (double)non_zero / r->total;
378 // const bool no_trust = true /* ...in the test ...at all */
380 //?? && (trust_rg < trust_rg_level || trust_nz <
384 // const bool dump = verbose >= 20
385 // || (verbose >= 10 && (trust_rg < 1. || trust_nz < 1.));
388 print(0, "@@@ [%s] %strust range:%.2f nz:%.2f "
389 "(level range:%.2f nz:%.2f). "
390 "in:%d (ok:%d) below:%d (ok:%d) above:%d (ok:%d) nz:%d "
391 "total:%lu\n", skind, final_compare ? "final: " : "",
392 trust_rg, trust_nz, trust_rg_level, trust_nz_level, in, in_ok,
393 below, below_ok, above, above_ok, non_zero,
394 (unsigned long)r->total);
400 r->state = MISTRUSTED;
401 print(0, "@@@ [%s] test-bug: trust is too low. "
402 "range:%.2f (?<%.2f) nz:%.2f (?<%.2f) (nz: %d total: %lu)\n",
403 skind, trust_rg, trust_rg_level, trust_nz, trust_nz_level,
404 non_zero, (unsigned long)r->total);
// A final comparison that never changed the state counts as a pass.
410 if (final_compare && r->state == UNTESTED)
411 r->state = PASSED; /* optimism */
// Only an explicit FAILED state is reported as FAIL to the caller.
413 return r->state == FAILED ? FAIL : OK;
416 int compare_input(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
417 res_t *r, bool final_compare = false) {
418 return compare_dat(p, input, mem_dt, mem_fp, r, final_compare);
420 int compare_states(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
421 res_t *r, bool final_compare = false) {
422 return compare_dat(p, states, mem_dt, mem_fp, r, final_compare);
424 int compare_weights_input(const rnn_prb_t *p, dnn_mem_t &mem_dt,
425 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
426 return compare_dat(p, weights_input, mem_dt, mem_fp, r, final_compare);
428 int compare_weights_states(const rnn_prb_t *p, dnn_mem_t &mem_dt,
429 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
430 return compare_dat(p, weights_states, mem_dt, mem_fp, r, final_compare);
432 int compare_bias(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
433 res_t *r, bool final_compare = false) {
434 return compare_dat(p, bias, mem_dt, mem_fp, r, final_compare);
436 int compare_dst_last_layer(const rnn_prb_t *p, dnn_mem_t &mem_dt,
437 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
438 return compare_dat(p, dst_last_layer, mem_dt, mem_fp, r, final_compare);
440 int compare_dst_last_iteration(const rnn_prb_t *p, dnn_mem_t &mem_dt,
441 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
442 return compare_dat(p, dst_last_iteration, mem_dt, mem_fp, r, final_compare);