Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / rnn / rnn.cpp
1 /*******************************************************************************
2  * Copyright 2018 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <memory>
#include <random>

#include "mkldnn.h"

#include "src/common/mkldnn_thread.hpp"

#include "mkldnn_common.hpp"
#include "mkldnn_memory.hpp"
#include "norm.hpp"

#include "rnn/rnn.hpp"
#include "rnn/rnn_aux.hpp"
33
34 namespace rnn {
35
// Gates the library calls in this file. While it is defined, fill_memory()
// takes the element count from mem1 and asserts that it matches mem2, and
// doit() creates and executes the MKL-DNN primitives. The code only tests
// whether the macro is defined (#ifdef), never its value.
#define CALL_MKLDNN_RNN 1
37
38 mkldnn_primitive_attr_t create_mkldnn_rnn_attr(const rnn_prb_t *p) {
39     mkldnn_primitive_attr_t mkldnn_attr = NULL;
40
41     DNN_SAFE_V(mkldnn_primitive_attr_create(&mkldnn_attr));
42     if (p->attr.irmode != attr_t::round_mode_t::NEAREST)
43         DNN_SAFE_V(mkldnn_primitive_attr_set_int_output_round_mode(
44                 mkldnn_attr, (mkldnn_round_mode_t)p->attr.irmode));
45
46     if (p->scale_policy == PER_OC) {
47         DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_weights_qparams(
48                 mkldnn_attr, p->dic * p->n_gates(), 0x3, p->wei_oc_scales));
49     } else if (p->scale_policy == COMMON && p->wei_scale != 1.) {
50         DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_weights_qparams(
51                 mkldnn_attr, 1, 0, &p->wei_scale));
52     }
53
54     if (p->data_scale != 1.0 || p->data_shift != 0.0) {
55         DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_data_qparams(
56                 mkldnn_attr, p->data_scale, p->data_shift));
57     }
58
59     return mkldnn_attr;
60 }
61
62 int fill_memory(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem1,
63         dnn_mem_t &mem2) {
64 #ifdef CALL_MKLDNN_RNN
65     const size_t nelems = mem1.nelems();
66     assert(mem1.nelems() == mem2.nelems());
67 #else
68     const size_t nelems = mem2.nelems();
69 #endif
70
71     dt_conf_t c = p->cfg[kind];
72     float mean = c.f_mean, var = c.f_var, min = c.f_min, max = c.f_max;
73     mkldnn::impl::parallel(0, [&](int ithr, int nthr) {
74         size_t chunk_size = (nelems + nthr - 1) / nthr;
75         size_t idx_start = ithr * chunk_size;
76         size_t idx_end = MIN2(idx_start + chunk_size, nelems);
77         std::minstd_rand msr;
78         msr.seed((unsigned long int)kind);
79         std::normal_distribution<float> gen(mean, var);
80         msr.discard(idx_start);
81         for (size_t idx = idx_start; idx < idx_end; ++idx) {
82             auto val = (c.dt == mkldnn_f32) ? gen(msr) : round(gen(msr));
83             mem2.set_elem(idx, MAX2(MIN2(val, max), min));
84         }
85     });
86
87     mem1.reorder(mem2);
88     return OK;
89 }
90
91 inline int init_pd(const rnn_prb_t *p, mkldnn_rnn_desc_t rd[2],
92         mkldnn_primitive_desc_t rpd[2], res_t *r) {
93     const bool is_bwd = p->prop == mkldnn_backward;
94     // If we are testing backward, we have to first run forward
95     // training first in order to generate a valid workspace.
96     auto fwd_prop = is_bwd ? mkldnn_forward_training : mkldnn_forward_inference;
97     const bool is_gru_lbr = p->alg == LBR_GRU;
98     int the_stride = 1;
99     /// @todo we need to add stride support for diff_* tensors too
100     mkldnn_memory_desc_t input_d, states_d, weights_input_d, weights_states_d,
101             bias_d, dst_last_layer_d, dst_last_iteration_d, diff_input_d,
102             diff_states_d, diff_weights_input_d, diff_weights_states_d,
103             diff_bias_d, diff_last_layer_d, diff_last_iteration_d;
104
105     // dimensions with ref
106     mkldnn_dims_t input_dims = { p->n_iter, p->mb, p->slc };
107     // bidirectional = 2, s for lstm = 2, for all other = 1
108     mkldnn_dims_t weights_input_dims
109             = { p->n_layer, p->n_directions(), p->slc, p->n_gates(), p->dic };
110     mkldnn_dims_t weights_states_dims
111             = { p->n_layer, p->n_directions(), p->sic, p->n_gates(), p->dic };
112     mkldnn_dims_t bias_dims
113             = { p->n_layer, p->n_directions(), p->n_gates() + is_gru_lbr, p->dic };
114     // mkldnn_tnc
115     int lastlay_dlc = (p->direction == mkldnn_bidirectional_concat)
116             ? 2 * p->dlc
117             : p->dlc;
118     mkldnn_dims_t dst_last_layer_dims = { p->n_iter, p->mb, lastlay_dlc };
119
120     DNN_SAFE(mkldnn_memory_desc_init(
121                      &input_d, 3, input_dims, p->cfg[input].dt, mkldnn_tnc),
122             WARN);
123     input_d.layout_desc.blocking.strides[0][0] += the_stride;
124
125     mkldnn_dims_t states_dims
126             = { p->n_layer, p->n_directions(), p->n_states(), p->mb, p->sic };
127     DNN_SAFE(mkldnn_memory_desc_init(&states_d, 5, states_dims,
128                      p->cfg[states].dt, mkldnn_ldsnc),
129             WARN);
130
131     states_d.layout_desc.blocking.strides[0][3] = p->sic + the_stride;
132     states_d.layout_desc.blocking.strides[0][2]
133             = states_d.layout_desc.blocking.strides[0][3] * states_d.dims[3]
134             + the_stride;
135     for (int d = 1; d >= 0; --d)
136         states_d.layout_desc.blocking.strides[0][d]
137                 = states_d.layout_desc.blocking.strides[0][d + 1]
138                 * states_d.dims[d + 1];
139
140     DNN_SAFE(mkldnn_memory_desc_init(&weights_input_d, 5, weights_input_dims,
141                      p->cfg[weights_input].dt, mkldnn_any),
142             WARN);
143
144     DNN_SAFE(mkldnn_memory_desc_init(&weights_states_d, 5, weights_states_dims,
145                      p->cfg[weights_states].dt, mkldnn_any),
146             WARN);
147
148     DNN_SAFE(mkldnn_memory_desc_init(
149                      &bias_d, 4, bias_dims, p->cfg[bias].dt, mkldnn_any),
150             WARN);
151
152     DNN_SAFE(mkldnn_memory_desc_init(&dst_last_layer_d, 3, dst_last_layer_dims,
153                      p->cfg[dst_last_layer].dt, mkldnn_tnc),
154             WARN);
155     dst_last_layer_d.layout_desc.blocking.strides[0][0] += the_stride;
156
157     mkldnn_dims_t dst_last_iteration_dims
158             = { p->n_layer, p->n_directions(), p->n_states(), p->mb, p->dic };
159     DNN_SAFE(mkldnn_memory_desc_init(&dst_last_iteration_d, 5,
160                      dst_last_iteration_dims, p->cfg[dst_last_iteration].dt,
161                      mkldnn_ldsnc),
162             WARN);
163
164     dst_last_iteration_d.layout_desc.blocking.strides[0][3]
165             = p->sic + the_stride;
166     dst_last_iteration_d.layout_desc.blocking.strides[0][2]
167             = dst_last_iteration_d.layout_desc.blocking.strides[0][3]
168                     * dst_last_iteration_d.dims[3]
169             + the_stride;
170     for (int d = 1; d >= 0; --d)
171         dst_last_iteration_d.layout_desc.blocking.strides[0][d]
172                 = dst_last_iteration_d.layout_desc.blocking.strides[0][d + 1]
173                 * dst_last_iteration_d.dims[d + 1];
174
175     mkldnn_alg_kind_t kind = alg2kind(p->alg);
176     mkldnn_alg_kind_t f = activation2kind(p->activation);
177
178     mkldnn_rnn_cell_desc_t rcd;
179     DNN_SAFE(mkldnn_rnn_cell_desc_init(&rcd, kind, f, 0U, 0, 0), WARN);
180     // Initializing the forward pass
181     // When inference, we use forward_inference
182     // When training, we use forward_training
183     {
184         mkldnn_status_t init_status = mkldnn_success;
185         init_status = mkldnn_rnn_forward_desc_init(&rd[0], fwd_prop, &rcd,
186                          p->direction, &input_d, &states_d, &weights_input_d,
187                          &weights_states_d, &bias_d, &dst_last_layer_d,
188                          &dst_last_iteration_d);
189         if (init_status == mkldnn_unimplemented)
190             return r->state = UNIMPLEMENTED, OK;
191         else
192             SAFE(init_status, WARN);
193     }
194
195     if (is_bwd) {
196         DNN_SAFE(mkldnn_memory_desc_init(&diff_input_d, 3, input_dims,
197                          p->cfg[dst_diff_input].dt, mkldnn_any),
198                 WARN);
199         DNN_SAFE(mkldnn_memory_desc_init(&diff_states_d, 5, states_dims,
200                          p->cfg[dst_diff_states].dt, mkldnn_any),
201                 WARN);
202         DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_input_d, 5,
203                          weights_input_dims, p->cfg[dst_diff_weights_input].dt,
204                          mkldnn_any),
205                 WARN);
206         DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_states_d, 5,
207                          weights_states_dims,
208                          p->cfg[dst_diff_weights_states].dt, mkldnn_any),
209                 WARN);
210         DNN_SAFE(mkldnn_memory_desc_init(&diff_bias_d, 4, bias_dims,
211                          p->cfg[dst_diff_bias].dt, mkldnn_any),
212                 WARN);
213         DNN_SAFE(mkldnn_memory_desc_init(&diff_last_layer_d, 3,
214                          dst_last_layer_dims, p->cfg[diff_last_layer].dt,
215                          mkldnn_any),
216                 WARN);
217         DNN_SAFE(mkldnn_memory_desc_init(&diff_last_iteration_d, 5,
218                          dst_last_iteration_dims,
219                          p->cfg[diff_last_iteration].dt, mkldnn_any),
220                 WARN);
221         DNN_SAFE(mkldnn_rnn_backward_desc_init(&rd[1], p->prop, &rcd,
222                          p->direction, &input_d, &states_d, &weights_input_d,
223                          &weights_states_d, &bias_d, &dst_last_layer_d,
224                          &dst_last_iteration_d, &diff_input_d, &diff_states_d,
225                          &diff_weights_input_d, &diff_weights_states_d,
226                          &diff_bias_d, &diff_last_layer_d,
227                          &diff_last_iteration_d),
228                 WARN);
229     }
230     auto mkldnn_attr = create_mkldnn_rnn_attr(p);
231     mkldnn_status_t init_status = mkldnn_success;
232     for (int i = 0; i < 1 + (int)is_bwd; i++) {
233         init_status = mkldnn_primitive_desc_create_v2(
234                 &(rpd[i]), &(rd[i]), mkldnn_attr, engine, NULL);
235         if (init_status == mkldnn_unimplemented)
236             return r->state = UNIMPLEMENTED, OK;
237         else
238             SAFE(init_status, WARN);
239     }
240     mkldnn_primitive_attr_destroy(mkldnn_attr);
241
242     auto q = [=](mkldnn_query_t query, int rpd_idx, int index = 0) {
243         return *mkldnn_primitive_desc_query_memory_d(
244                 mkldnn_primitive_desc_query_pd(rpd[rpd_idx], query, index));
245     };
246
247     for (int i = 0; i < 1 + (int)is_bwd; i++) {
248         rd[i].src_layer_desc = q(mkldnn_query_src_pd, i);
249         rd[i].src_iter_desc = q(mkldnn_query_src_pd, i, 1);
250         rd[i].weights_layer_desc = q(mkldnn_query_weights_pd, i);
251         rd[i].weights_iter_desc = q(mkldnn_query_weights_pd, i, 1);
252         rd[i].bias_desc = q(mkldnn_query_weights_pd, i, 2);
253         rd[i].dst_layer_desc = q(mkldnn_query_dst_pd, i);
254         rd[i].dst_iter_desc = q(mkldnn_query_dst_pd, i, 1);
255     }
256     if (is_bwd) {
257         rd[1].diff_src_layer_desc = q(mkldnn_query_diff_src_pd, 1);
258         rd[1].diff_src_iter_desc = q(mkldnn_query_diff_src_pd, 1, 1);
259         rd[1].diff_weights_layer_desc = q(mkldnn_query_diff_weights_pd, 1);
260         rd[1].diff_weights_iter_desc = q(mkldnn_query_diff_weights_pd, 1, 1);
261         rd[1].diff_bias_desc = q(mkldnn_query_diff_weights_pd, 1, 2);
262         rd[1].diff_dst_layer_desc = q(mkldnn_query_diff_dst_pd, 1);
263         rd[1].diff_dst_iter_desc = q(mkldnn_query_diff_dst_pd, 1, 1);
264     }
265
266     return OK;
267 }
268
269 int doit(const rnn_prb_t *p, res_t *r) {
270     res_t res_zero{};
271     *r = res_zero;
272
273     const auto fp = mkldnn_f32;
274
275     if (p->alg != VANILLA_LSTM && p->alg != VANILLA_RNN
276         && p->alg != VANILLA_GRU && p->alg != LBR_GRU) {
277         printf("p->alg: %d\n", (int)p->alg);
278         r->state = UNIMPLEMENTED;
279         return OK;
280     }
281
282     const bool is_bwd = p->prop == mkldnn_backward;
283
284     dnn_mem_t *input_dt = nullptr;
285     dnn_mem_t *states_dt = nullptr;
286     dnn_mem_t *weights_input_dt = nullptr;
287     dnn_mem_t *weights_states_dt = nullptr;
288     dnn_mem_t *bias_dt = nullptr;
289     dnn_mem_t *dst_last_layer_dt = nullptr;
290     dnn_mem_t *dst_last_iteration_dt = nullptr;
291
292     dnn_mem_t *bwd_weights_input_dt = nullptr;
293     dnn_mem_t *bwd_weights_states_dt = nullptr;
294     dnn_mem_t *dst_diff_input_dt = nullptr;
295     dnn_mem_t *dst_diff_states_dt = nullptr;
296     dnn_mem_t *dst_diff_weights_input_dt = nullptr;
297     dnn_mem_t *dst_diff_weights_states_dt = nullptr;
298     dnn_mem_t *dst_diff_bias_dt = nullptr;
299     dnn_mem_t *diff_last_layer_dt = nullptr;
300     dnn_mem_t *diff_last_iteration_dt = nullptr;
301
302     dnn_mem_t *input_fp = nullptr;
303     dnn_mem_t *states_fp = nullptr;
304     dnn_mem_t *weights_input_fp = nullptr;
305     dnn_mem_t *weights_states_fp = nullptr;
306     dnn_mem_t *bias_fp = nullptr;
307     dnn_mem_t *dst_last_layer_fp = nullptr;
308     dnn_mem_t *dst_last_iteration_fp = nullptr;
309
310     dnn_mem_t *dst_diff_input_fp = nullptr;
311     dnn_mem_t *dst_diff_states_fp = nullptr;
312     dnn_mem_t *dst_diff_weights_input_fp = nullptr;
313     dnn_mem_t *dst_diff_weights_states_fp = nullptr;
314     dnn_mem_t *dst_diff_bias_fp = nullptr;
315     dnn_mem_t *diff_last_layer_fp = nullptr;
316     dnn_mem_t *diff_last_iteration_fp = nullptr;
317
318     dnn_mem_t *workspace_dt = nullptr;
319
320     mkldnn_rnn_desc_t rd[2];
321     mkldnn_primitive_desc_t rpd[2] = {nullptr};
322     mkldnn_primitive_t c{};
323     SAFE(init_pd(p, rd, rpd, r), WARN);
324     if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
325         return OK;
326
327     auto &input_dt_d = rd[0].src_layer_desc;
328     auto &states_dt_d = rd[0].src_iter_desc;
329     auto &weights_input_dt_d = rd[0].weights_layer_desc;
330     auto &weights_states_dt_d = rd[0].weights_iter_desc;
331     auto &bias_dt_d = rd[0].bias_desc;
332     auto &dst_last_layer_dt_d = rd[0].dst_layer_desc;
333     auto &dst_last_iteration_dt_d = rd[0].dst_iter_desc;
334
335     auto &bwd_weights_input_dt_d = rd[1].weights_layer_desc;
336     auto &bwd_weights_states_dt_d = rd[1].weights_iter_desc;
337     auto &diff_src_layer_dt_d = rd[1].diff_src_layer_desc;
338     auto &diff_src_iter_dt_d = rd[1].diff_src_iter_desc;
339     auto &diff_weights_layer_dt_d = rd[1].diff_weights_layer_desc;
340     auto &diff_weights_iter_dt_d = rd[1].diff_weights_iter_desc;
341     auto &diff_bias_dt_d = rd[1].diff_bias_desc;
342     auto &diff_dst_layer_dt_d = rd[1].diff_dst_layer_desc;
343     auto &diff_dst_iter_dt_d = rd[1].diff_dst_iter_desc;
344
345     input_dt = new dnn_mem_t(input_dt_d, p->cfg[input].dt);
346     states_dt = new dnn_mem_t(states_dt_d, p->cfg[states].dt);
347     weights_input_dt
348             = new dnn_mem_t(weights_input_dt_d, p->cfg[weights_input].dt);
349     weights_states_dt
350             = new dnn_mem_t(weights_states_dt_d, p->cfg[weights_states].dt);
351     bias_dt = new dnn_mem_t(bias_dt_d, p->cfg[bias].dt);
352     dst_last_layer_dt
353             = new dnn_mem_t(dst_last_layer_dt_d, p->cfg[dst_last_layer].dt);
354     dst_last_iteration_dt = new dnn_mem_t(
355             dst_last_iteration_dt_d, p->cfg[dst_last_iteration].dt);
356
357     if (is_bwd) {
358         bwd_weights_input_dt = new dnn_mem_t(bwd_weights_input_dt_d, fp);
359         bwd_weights_states_dt = new dnn_mem_t(bwd_weights_states_dt_d, fp);
360         dst_diff_input_dt = new dnn_mem_t(diff_src_layer_dt_d, fp);
361         dst_diff_states_dt = new dnn_mem_t(diff_src_iter_dt_d, fp);
362         dst_diff_weights_input_dt = new dnn_mem_t(diff_weights_layer_dt_d, fp);
363         dst_diff_weights_states_dt = new dnn_mem_t(diff_weights_iter_dt_d, fp);
364         dst_diff_bias_dt = new dnn_mem_t(diff_bias_dt_d, fp);
365         diff_last_layer_dt = new dnn_mem_t(diff_dst_layer_dt_d, fp);
366         diff_last_iteration_dt = new dnn_mem_t(diff_dst_iter_dt_d, fp);
367     }
368
369     input_fp = new dnn_mem_t(input_dt_d, fp, mkldnn_tnc);
370     states_fp = new dnn_mem_t(states_dt_d, fp, mkldnn_ldsnc);
371     weights_input_fp = new dnn_mem_t(weights_input_dt_d, fp, mkldnn_ldigo);
372     weights_states_fp = new dnn_mem_t(weights_states_dt_d, fp, mkldnn_ldigo);
373     bias_fp = new dnn_mem_t(bias_dt_d, fp, mkldnn_ldgo);
374     dst_last_layer_fp = new dnn_mem_t(dst_last_layer_dt_d, fp, mkldnn_tnc);
375     dst_last_iteration_fp
376             = new dnn_mem_t(dst_last_iteration_dt_d, fp, mkldnn_ldsnc);
377
378     if (is_bwd) {
379         dst_diff_input_fp = new dnn_mem_t(diff_src_layer_dt_d, fp, mkldnn_tnc);
380         dst_diff_states_fp
381                 = new dnn_mem_t(diff_src_iter_dt_d, fp, mkldnn_ldsnc);
382         dst_diff_weights_input_fp
383                 = new dnn_mem_t(diff_weights_layer_dt_d, fp, mkldnn_ldigo);
384         dst_diff_weights_states_fp
385                 = new dnn_mem_t(diff_weights_iter_dt_d, fp, mkldnn_ldigo);
386         dst_diff_bias_fp = new dnn_mem_t(diff_bias_dt_d, fp, mkldnn_ldgo);
387         diff_last_layer_fp = new dnn_mem_t(diff_dst_layer_dt_d, fp, mkldnn_tnc);
388         diff_last_iteration_fp
389                 = new dnn_mem_t(diff_dst_iter_dt_d, fp, mkldnn_ldsnc);
390
391         const auto ws_pd = mkldnn_primitive_desc_query_pd(
392                 rpd[0], mkldnn_query_workspace_pd, 0);
393         SAFE(ws_pd != NULL ? OK : FAIL, WARN);
394         workspace_dt
395                 = new dnn_mem_t(*mkldnn_primitive_desc_query_memory_d(ws_pd));
396     }
397
398     SAFE(fill_memory(p, input, *input_dt, *input_fp), WARN);
399     SAFE(fill_memory(p, states, *states_dt, *states_fp), WARN);
400     SAFE(fill_memory(p, weights_input, *weights_input_dt, *weights_input_fp),
401             WARN);
402     SAFE(fill_memory(p, weights_states, *weights_states_dt, *weights_states_fp),
403             WARN);
404     SAFE(fill_memory(p, bias, *bias_dt, *bias_fp), WARN);
405     SAFE(fill_memory(p, dst_last_layer, *dst_last_layer_dt, *dst_last_layer_fp),
406             WARN);
407     SAFE(fill_memory(p, dst_last_iteration, *dst_last_iteration_dt,
408                  *dst_last_iteration_fp),
409             WARN);
410
411     if (is_bwd) {
412         SAFE(bwd_weights_states_dt->reorder(*weights_states_dt), WARN);
413         SAFE(bwd_weights_input_dt->reorder(*weights_input_dt), WARN);
414         SAFE(fill_memory(
415                      p, dst_diff_input, *dst_diff_input_dt, *dst_diff_input_fp),
416                 WARN);
417         SAFE(fill_memory(p, dst_diff_states, *dst_diff_states_dt,
418                      *dst_diff_states_fp),
419                 WARN);
420         SAFE(fill_memory(p, dst_diff_weights_input, *dst_diff_weights_input_dt,
421                      *dst_diff_weights_input_fp),
422                 WARN);
423         SAFE(fill_memory(p, dst_diff_weights_states,
424                      *dst_diff_weights_states_dt, *dst_diff_weights_states_fp),
425                 WARN);
426         SAFE(fill_memory(
427                      p, dst_diff_bias, *dst_diff_bias_dt, *dst_diff_bias_fp),
428                 WARN);
429         SAFE(fill_memory(p, diff_last_layer, *diff_last_layer_dt,
430                      *diff_last_layer_fp),
431                 WARN);
432         SAFE(fill_memory(p, diff_last_iteration, *diff_last_iteration_dt,
433                      *diff_last_iteration_fp),
434                 WARN);
435     }
436
437     // Running the forward pass
438     {
439         mkldnn_primitive_at_t inputs[] = { { input_dt->p_, 0 },
440             { states_dt->p_, 0 }, { weights_input_dt->p_, 0 },
441             { weights_states_dt->p_, 0 }, { bias_dt->p_, 0 } };
442         const_mkldnn_primitive_t outputs[] = { dst_last_layer_dt->p_,
443             dst_last_iteration_dt->p_, workspace_dt ? workspace_dt->p_ : 0 };
444 #ifdef CALL_MKLDNN_RNN
445         DNN_SAFE(mkldnn_primitive_create(&c, rpd[0], inputs, outputs), WARN);
446         SAFE(execute(c), WARN);
447 #endif
448         if ((p->prop == mkldnn_forward) && (bench_mode & CORR)) {
449             compute_ref_fwd(p, *input_fp, *states_fp, *weights_input_fp,
450                     *weights_states_fp, *bias_fp, *dst_last_layer_fp,
451                     *dst_last_iteration_fp, p->direction);
452             dnn_mem_t dst_last_layer(*dst_last_layer_dt, fp, mkldnn_tnc);
453             dnn_mem_t dst_last_iteration(
454                     *dst_last_iteration_dt, fp, mkldnn_ldsnc);
455             SAFE(compare_dst_last_layer(
456                          p, dst_last_layer, *dst_last_layer_fp, r, true),
457                     WARN);
458             SAFE(compare_dst_last_iteration(p, dst_last_iteration,
459                          *dst_last_iteration_fp, r, true),
460                     WARN);
461         }
462     }
463
464     if (is_bwd) {
465         mkldnn_primitive_at_t inputs[] = {
466             { input_dt->p_, 0 }, { states_dt->p_, 0 },
467             { bwd_weights_input_dt->p_, 0 }, { bwd_weights_states_dt->p_, 0 },
468             { bias_dt->p_, 0 }, { dst_last_layer_dt->p_, 0 },
469             { dst_last_iteration_dt->p_, 0 }, { diff_last_layer_dt->p_, 0 },
470             { diff_last_iteration_dt->p_, 0 }, { workspace_dt->p_, 0 },
471         };
472         const_mkldnn_primitive_t outputs[] = { dst_diff_input_dt->p_,
473             dst_diff_states_dt->p_, dst_diff_weights_input_dt->p_,
474             dst_diff_weights_states_dt->p_, dst_diff_bias_dt->p_ };
475
476 #ifdef CALL_MKLDNN_RNN
477         DNN_SAFE(mkldnn_primitive_create(&c, rpd[1], inputs, outputs), WARN);
478         SAFE(execute(c), WARN);
479 #endif
480
481         if (bench_mode & CORR) {
482             compute_ref_bwd(p, *input_fp, *states_fp, *diff_last_layer_fp,
483                     *diff_last_iteration_fp, *weights_input_fp,
484                     *weights_states_fp, *bias_fp, *dst_last_layer_fp,
485                     *dst_last_iteration_fp, *dst_diff_input_fp,
486                     *dst_diff_states_fp, *dst_diff_weights_input_fp,
487                     *dst_diff_weights_states_fp, *dst_diff_bias_fp,
488                     p->direction);
489
490             dnn_mem_t dst_last_layer(*dst_last_layer_dt, fp, mkldnn_tnc);
491             dnn_mem_t dst_last_iteration(
492                     *dst_last_iteration_dt, fp, mkldnn_ldsnc);
493             SAFE(compare_dst_last_layer(
494                          p, dst_last_layer, *dst_last_layer_fp, r, true),
495                     WARN);
496             SAFE(compare_dst_last_iteration(p, dst_last_iteration,
497                          *dst_last_iteration_fp, r, true),
498                     WARN);
499
500             dnn_mem_t diff_input(*dst_diff_input_dt, fp, mkldnn_tnc);
501             dnn_mem_t diff_states(*dst_diff_states_dt, fp, mkldnn_ldsnc);
502             SAFE(compare_input(p, diff_input, *dst_diff_input_fp, r, true),
503                     WARN);
504             SAFE(compare_states(p, diff_states, *dst_diff_states_fp, r, true),
505                     WARN);
506
507             dnn_mem_t diff_weights_input(
508                     *dst_diff_weights_input_dt, fp, mkldnn_ldigo);
509             dnn_mem_t diff_weights_states(
510                     *dst_diff_weights_states_dt, fp, mkldnn_ldigo);
511             SAFE(compare_weights_input(p, diff_weights_input,
512                          *dst_diff_weights_input_fp, r, true),
513                     WARN);
514             SAFE(compare_weights_states(p, diff_weights_states,
515                          *dst_diff_weights_states_fp, r, true),
516                     WARN);
517
518             dnn_mem_t diff_bias(*dst_diff_bias_dt, fp, mkldnn_ldgo);
519             SAFE(compare_bias(p, diff_bias, *dst_diff_bias_fp, r, true), WARN);
520         }
521     }
522
523     if (bench_mode & PERF) {
524         auto &t = r->timer;
525         t.reset();
526         while (true) {
527 #ifdef CALL_MKLDNN_RNN
528             SAFE(execute(c), WARN);
529 #endif
530             t.stamp();
531             const bool stop = false
532                     || (fix_times_per_prb && t.times() >= fix_times_per_prb)
533                     || (!fix_times_per_prb && t.total_ms() >= max_ms_per_prb
534                                && t.times() >= min_times_per_prb);
535             if (stop)
536                 break;
537         }
538     }
539
540     // cleanup
541     delete input_fp;
542     delete states_fp;
543     delete weights_input_fp;
544     delete weights_states_fp;
545     delete bias_fp;
546     delete dst_last_layer_fp;
547     delete dst_last_iteration_fp;
548
549     if (is_bwd) {
550         delete bwd_weights_input_dt;
551         delete bwd_weights_states_dt;
552         delete dst_diff_input_fp;
553         delete dst_diff_states_fp;
554         delete dst_diff_weights_input_fp;
555         delete dst_diff_weights_states_fp;
556         delete dst_diff_bias_fp;
557         delete diff_last_layer_fp;
558         delete diff_last_iteration_fp;
559     }
560
561     delete input_dt;
562     delete states_dt;
563     delete weights_input_dt;
564     delete weights_states_dt;
565     delete bias_dt;
566     delete dst_last_layer_dt;
567     delete dst_last_iteration_dt;
568
569     if (is_bwd) {
570         delete dst_diff_input_dt;
571         delete dst_diff_states_dt;
572         delete dst_diff_weights_input_dt;
573         delete dst_diff_weights_states_dt;
574         delete dst_diff_bias_dt;
575         delete diff_last_layer_dt;
576         delete diff_last_iteration_dt;
577     }
578
579     delete workspace_dt;
580
581     DNN_SAFE(mkldnn_primitive_desc_destroy(rpd[0]), CRIT);
582     DNN_SAFE(mkldnn_primitive_desc_destroy(rpd[1]), CRIT);
583     DNN_SAFE(mkldnn_primitive_destroy(c), CRIT);
584
585     return OK;
586 }
587 } // namespace rnn