1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
25 #include "src/common/mkldnn_thread.hpp"
27 #include "mkldnn_common.hpp"
28 #include "mkldnn_memory.hpp"
31 #include "rnn/rnn.hpp"
32 #include "rnn/rnn_aux.hpp"
36 #define CALL_MKLDNN_RNN 1
// Creates and configures an mkldnn primitive attribute for the given RNN
// problem: integer output rounding mode plus weights/data quantization
// parameters. NOTE(review): the function's return and closing braces are
// outside this view of the file.
38 mkldnn_primitive_attr_t create_mkldnn_rnn_attr(const rnn_prb_t *p) {
39 mkldnn_primitive_attr_t mkldnn_attr = NULL;
41 DNN_SAFE_V(mkldnn_primitive_attr_create(&mkldnn_attr));
// Only override the integer rounding mode when it differs from the
// default (NEAREST).
42 if (p->attr.irmode != attr_t::round_mode_t::NEAREST)
43 DNN_SAFE_V(mkldnn_primitive_attr_set_int_output_round_mode(
44 mkldnn_attr, (mkldnn_round_mode_t)p->attr.irmode));
// Per-output-channel weights scales: dic * n_gates() entries; mask 0x3
// selects the two innermost (gates, output-channel) weight dimensions.
46 if (p->scale_policy == PER_OC) {
47 DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_weights_qparams(
48 mkldnn_attr, p->dic * p->n_gates(), 0x3, p->wei_oc_scales));
// Single common weights scale; the call is skipped when the scale is the
// identity (1.0).
49 } else if (p->scale_policy == COMMON && p->wei_scale != 1.) {
50 DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_weights_qparams(
51 mkldnn_attr, 1, 0, &p->wei_scale));
// Data (activations) quantization: scale and shift, set only when
// non-identity.
54 if (p->data_scale != 1.0 || p->data_shift != 0.0) {
55 DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_data_qparams(
56 mkldnn_attr, p->data_scale, p->data_shift));
// Fills the destination memory with normally-distributed random values for
// the given data kind, clamped to the config's [f_min, f_max] range;
// non-f32 data types receive rounded values. Filling is parallelized with
// a fixed per-kind seed for reproducibility. NOTE(review): the rest of the
// signature (mem2 parameter), the #else/#endif lines and the function's
// closing are outside this view.
62 int fill_memory(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem1,
64 #ifdef CALL_MKLDNN_RNN
// When the library path is compiled in, size from mem1 and sanity-check
// that both buffers agree.
65 const size_t nelems = mem1.nelems();
66 assert(mem1.nelems() == mem2.nelems());
// (This declaration belongs to the #else branch; the #else directive
// itself is not visible in this view.)
68 const size_t nelems = mem2.nelems();
// Per-kind distribution parameters (mean/variance) and clamping range.
71 dt_conf_t c = p->cfg[kind];
72 float mean = c.f_mean, var = c.f_var, min = c.f_min, max = c.f_max;
// Split the element range evenly across threads; each thread reseeds the
// shared-per-thread engine with the data kind and discards up to its
// start index so the generated stream does not depend on thread count.
73 mkldnn::impl::parallel(0, [&](int ithr, int nthr) {
74 size_t chunk_size = (nelems + nthr - 1) / nthr;
75 size_t idx_start = ithr * chunk_size;
76 size_t idx_end = MIN2(idx_start + chunk_size, nelems);
// NOTE(review): `msr` is declared outside this view. std::normal_distribution
// may consume a variable number of engine values per sample, so discard()
// by element index only approximates stream alignment across threads --
// confirm this determinism assumption holds for the implementation used.
78 msr.seed((unsigned long int)kind);
79 std::normal_distribution<float> gen(mean, var);
80 msr.discard(idx_start);
81 for (size_t idx = idx_start; idx < idx_end; ++idx) {
// Round for non-f32 data types, then clamp into the valid range.
82 auto val = (c.dt == mkldnn_f32) ? gen(msr) : round(gen(msr));
83 mem2.set_elem(idx, MAX2(MIN2(val, max), min));
// Initializes the forward RNN descriptor/primitive-descriptor (rd[0]/rpd[0])
// and, for backward problems, the backward pair (rd[1]/rpd[1]). Activation
// memory descriptors are given deliberately non-dense strides (via
// `the_stride`, whose definition is outside this view) to exercise strided
// layouts; weights and diff tensors use mkldnn_any so the implementation
// chooses its preferred layout. Returns OK with r->state = UNIMPLEMENTED
// when the library does not support the configuration.
91 inline int init_pd(const rnn_prb_t *p, mkldnn_rnn_desc_t rd[2],
92 mkldnn_primitive_desc_t rpd[2], res_t *r) {
93 const bool is_bwd = p->prop == mkldnn_backward;
94 // If we are testing backward, we have to first run forward
95 // training first in order to generate a valid workspace.
96 auto fwd_prop = is_bwd ? mkldnn_forward_training : mkldnn_forward_inference;
// Linear-before-reset GRU carries one extra bias gate (see bias_dims).
97 const bool is_gru_lbr = p->alg == LBR_GRU;
99 /// @todo we need to add stride support for diff_* tensors too
100 mkldnn_memory_desc_t input_d, states_d, weights_input_d, weights_states_d,
101 bias_d, dst_last_layer_d, dst_last_iteration_d, diff_input_d,
102 diff_states_d, diff_weights_input_d, diff_weights_states_d,
103 diff_bias_d, diff_last_layer_d, diff_last_iteration_d;
105 // dimensions with ref
106 mkldnn_dims_t input_dims = { p->n_iter, p->mb, p->slc };
107 // bidirectional = 2, s for lstm = 2, for all other = 1
108 mkldnn_dims_t weights_input_dims
109 = { p->n_layer, p->n_directions(), p->slc, p->n_gates(), p->dic };
110 mkldnn_dims_t weights_states_dims
111 = { p->n_layer, p->n_directions(), p->sic, p->n_gates(), p->dic };
// LBR GRU has n_gates() + 1 bias entries per layer/direction.
112 mkldnn_dims_t bias_dims
113 = { p->n_layer, p->n_directions(), p->n_gates() + is_gru_lbr, p->dic };
// Concatenated bidirectional output doubles the last-layer channel count.
// NOTE(review): the expression's continuation lines are not visible here.
115 int lastlay_dlc = (p->direction == mkldnn_bidirectional_concat)
118 mkldnn_dims_t dst_last_layer_dims = { p->n_iter, p->mb, lastlay_dlc };
120 DNN_SAFE(mkldnn_memory_desc_init(
121 &input_d, 3, input_dims, p->cfg[input].dt, mkldnn_tnc),
// Pad the outermost stride to make the src layout non-dense.
123 input_d.layout_desc.blocking.strides[0][0] += the_stride;
125 mkldnn_dims_t states_dims
126 = { p->n_layer, p->n_directions(), p->n_states(), p->mb, p->sic };
127 DNN_SAFE(mkldnn_memory_desc_init(&states_d, 5, states_dims,
128 p->cfg[states].dt, mkldnn_ldsnc),
// Rebuild the blocking strides from the innermost (minibatch) dimension
// outward, inserting `the_stride` of padding on the channel dimension.
131 states_d.layout_desc.blocking.strides[0][3] = p->sic + the_stride;
132 states_d.layout_desc.blocking.strides[0][2]
133 = states_d.layout_desc.blocking.strides[0][3] * states_d.dims[3]
// Propagate the padded stride out through dims 1 and 0.
135 for (int d = 1; d >= 0; --d)
136 states_d.layout_desc.blocking.strides[0][d]
137 = states_d.layout_desc.blocking.strides[0][d + 1]
138 * states_d.dims[d + 1];
// Weights and bias: let the implementation pick the layout (mkldnn_any).
140 DNN_SAFE(mkldnn_memory_desc_init(&weights_input_d, 5, weights_input_dims,
141 p->cfg[weights_input].dt, mkldnn_any),
144 DNN_SAFE(mkldnn_memory_desc_init(&weights_states_d, 5, weights_states_dims,
145 p->cfg[weights_states].dt, mkldnn_any),
148 DNN_SAFE(mkldnn_memory_desc_init(
149 &bias_d, 4, bias_dims, p->cfg[bias].dt, mkldnn_any),
152 DNN_SAFE(mkldnn_memory_desc_init(&dst_last_layer_d, 3, dst_last_layer_dims,
153 p->cfg[dst_last_layer].dt, mkldnn_tnc),
// Same outer-stride padding trick as for the input descriptor.
155 dst_last_layer_d.layout_desc.blocking.strides[0][0] += the_stride;
157 mkldnn_dims_t dst_last_iteration_dims
158 = { p->n_layer, p->n_directions(), p->n_states(), p->mb, p->dic };
159 DNN_SAFE(mkldnn_memory_desc_init(&dst_last_iteration_d, 5,
160 dst_last_iteration_dims, p->cfg[dst_last_iteration].dt,
// Padded strides for the last-iteration output, mirroring states_d.
// NOTE(review): this uses p->sic although the innermost dim here is
// p->dic -- looks suspicious, but the surrounding lines are not all
// visible; confirm against the full file before changing.
164 dst_last_iteration_d.layout_desc.blocking.strides[0][3]
165 = p->sic + the_stride;
166 dst_last_iteration_d.layout_desc.blocking.strides[0][2]
167 = dst_last_iteration_d.layout_desc.blocking.strides[0][3]
168 * dst_last_iteration_d.dims[3]
170 for (int d = 1; d >= 0; --d)
171 dst_last_iteration_d.layout_desc.blocking.strides[0][d]
172 = dst_last_iteration_d.layout_desc.blocking.strides[0][d + 1]
173 * dst_last_iteration_d.dims[d + 1];
// Translate benchdnn's algorithm/activation enums into library kinds.
175 mkldnn_alg_kind_t kind = alg2kind(p->alg);
176 mkldnn_alg_kind_t f = activation2kind(p->activation);
178 mkldnn_rnn_cell_desc_t rcd;
179 DNN_SAFE(mkldnn_rnn_cell_desc_init(&rcd, kind, f, 0U, 0, 0), WARN);
180 // Initializing the forward pass
181 // When inference, we use forward_inference
182 // When training, we use forward_training
184 mkldnn_status_t init_status = mkldnn_success;
185 init_status = mkldnn_rnn_forward_desc_init(&rd[0], fwd_prop, &rcd,
186 p->direction, &input_d, &states_d, &weights_input_d,
187 &weights_states_d, &bias_d, &dst_last_layer_d,
188 &dst_last_iteration_d);
// An unsupported configuration is a valid (non-failing) test outcome.
189 if (init_status == mkldnn_unimplemented)
190 return r->state = UNIMPLEMENTED, OK;
192 SAFE(init_status, WARN);
// Backward path: all diff tensors use mkldnn_any (see @todo above about
// adding stride support for them).
196 DNN_SAFE(mkldnn_memory_desc_init(&diff_input_d, 3, input_dims,
197 p->cfg[dst_diff_input].dt, mkldnn_any),
199 DNN_SAFE(mkldnn_memory_desc_init(&diff_states_d, 5, states_dims,
200 p->cfg[dst_diff_states].dt, mkldnn_any),
202 DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_input_d, 5,
203 weights_input_dims, p->cfg[dst_diff_weights_input].dt,
206 DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_states_d, 5,
208 p->cfg[dst_diff_weights_states].dt, mkldnn_any),
210 DNN_SAFE(mkldnn_memory_desc_init(&diff_bias_d, 4, bias_dims,
211 p->cfg[dst_diff_bias].dt, mkldnn_any),
213 DNN_SAFE(mkldnn_memory_desc_init(&diff_last_layer_d, 3,
214 dst_last_layer_dims, p->cfg[diff_last_layer].dt,
217 DNN_SAFE(mkldnn_memory_desc_init(&diff_last_iteration_d, 5,
218 dst_last_iteration_dims,
219 p->cfg[diff_last_iteration].dt, mkldnn_any),
221 DNN_SAFE(mkldnn_rnn_backward_desc_init(&rd[1], p->prop, &rcd,
222 p->direction, &input_d, &states_d, &weights_input_d,
223 &weights_states_d, &bias_d, &dst_last_layer_d,
224 &dst_last_iteration_d, &diff_input_d, &diff_states_d,
225 &diff_weights_input_d, &diff_weights_states_d,
226 &diff_bias_d, &diff_last_layer_d,
227 &diff_last_iteration_d),
// Create the primitive descriptors with the quantization attribute
// (forward always; backward too when is_bwd).
230 auto mkldnn_attr = create_mkldnn_rnn_attr(p);
231 mkldnn_status_t init_status = mkldnn_success;
232 for (int i = 0; i < 1 + (int)is_bwd; i++) {
233 init_status = mkldnn_primitive_desc_create_v2(
234 &(rpd[i]), &(rd[i]), mkldnn_attr, engine, NULL);
235 if (init_status == mkldnn_unimplemented)
236 return r->state = UNIMPLEMENTED, OK;
238 SAFE(init_status, WARN);
// NOTE(review): on the UNIMPLEMENTED early-returns above the attribute is
// not destroyed -- confirm whether that leak is acceptable for the
// harness or fixed in lines not visible here.
240 mkldnn_primitive_attr_destroy(mkldnn_attr);
// Helper: query the implementation-chosen memory descriptor for a given
// primitive-descriptor slot.
242 auto q = [=](mkldnn_query_t query, int rpd_idx, int index = 0) {
243 return *mkldnn_primitive_desc_query_memory_d(
244 mkldnn_primitive_desc_query_pd(rpd[rpd_idx], query, index));
// Copy the chosen layouts back into rd so the caller can allocate
// buffers that match what the implementation expects.
247 for (int i = 0; i < 1 + (int)is_bwd; i++) {
248 rd[i].src_layer_desc = q(mkldnn_query_src_pd, i);
249 rd[i].src_iter_desc = q(mkldnn_query_src_pd, i, 1);
250 rd[i].weights_layer_desc = q(mkldnn_query_weights_pd, i);
251 rd[i].weights_iter_desc = q(mkldnn_query_weights_pd, i, 1);
252 rd[i].bias_desc = q(mkldnn_query_weights_pd, i, 2);
253 rd[i].dst_layer_desc = q(mkldnn_query_dst_pd, i);
254 rd[i].dst_iter_desc = q(mkldnn_query_dst_pd, i, 1);
// Same for the backward-only diff descriptors.
257 rd[1].diff_src_layer_desc = q(mkldnn_query_diff_src_pd, 1);
258 rd[1].diff_src_iter_desc = q(mkldnn_query_diff_src_pd, 1, 1);
259 rd[1].diff_weights_layer_desc = q(mkldnn_query_diff_weights_pd, 1);
260 rd[1].diff_weights_iter_desc = q(mkldnn_query_diff_weights_pd, 1, 1);
261 rd[1].diff_bias_desc = q(mkldnn_query_diff_weights_pd, 1, 2);
262 rd[1].diff_dst_layer_desc = q(mkldnn_query_diff_dst_pd, 1);
263 rd[1].diff_dst_iter_desc = q(mkldnn_query_diff_dst_pd, 1, 1);
// Top-level driver for one RNN test case: builds descriptors, allocates
// library-layout ("_dt") and reference-layout f32 ("_fp") buffers, fills
// them, runs the library forward (and backward when requested), compares
// against the reference implementation in CORR mode, times in PERF mode,
// then releases everything. NOTE(review): several lines of this function
// (returns, #endif lines, loop bodies, closing braces) are outside this
// view.
269 int doit(const rnn_prb_t *p, res_t *r) {
273 const auto fp = mkldnn_f32;
// Bail out early on algorithms the reference implementation does not
// cover.
275 if (p->alg != VANILLA_LSTM && p->alg != VANILLA_RNN
276 && p->alg != VANILLA_GRU && p->alg != LBR_GRU) {
277 printf("p->alg: %d\n", (int)p->alg);
278 r->state = UNIMPLEMENTED;
282 const bool is_bwd = p->prop == mkldnn_backward;
// Library-layout tensors (data type per config).
284 dnn_mem_t *input_dt = nullptr;
285 dnn_mem_t *states_dt = nullptr;
286 dnn_mem_t *weights_input_dt = nullptr;
287 dnn_mem_t *weights_states_dt = nullptr;
288 dnn_mem_t *bias_dt = nullptr;
289 dnn_mem_t *dst_last_layer_dt = nullptr;
290 dnn_mem_t *dst_last_iteration_dt = nullptr;
// Backward-only library-layout tensors.
292 dnn_mem_t *bwd_weights_input_dt = nullptr;
293 dnn_mem_t *bwd_weights_states_dt = nullptr;
294 dnn_mem_t *dst_diff_input_dt = nullptr;
295 dnn_mem_t *dst_diff_states_dt = nullptr;
296 dnn_mem_t *dst_diff_weights_input_dt = nullptr;
297 dnn_mem_t *dst_diff_weights_states_dt = nullptr;
298 dnn_mem_t *dst_diff_bias_dt = nullptr;
299 dnn_mem_t *diff_last_layer_dt = nullptr;
300 dnn_mem_t *diff_last_iteration_dt = nullptr;
// f32 reference-layout tensors used by compute_ref_* and comparisons.
302 dnn_mem_t *input_fp = nullptr;
303 dnn_mem_t *states_fp = nullptr;
304 dnn_mem_t *weights_input_fp = nullptr;
305 dnn_mem_t *weights_states_fp = nullptr;
306 dnn_mem_t *bias_fp = nullptr;
307 dnn_mem_t *dst_last_layer_fp = nullptr;
308 dnn_mem_t *dst_last_iteration_fp = nullptr;
310 dnn_mem_t *dst_diff_input_fp = nullptr;
311 dnn_mem_t *dst_diff_states_fp = nullptr;
312 dnn_mem_t *dst_diff_weights_input_fp = nullptr;
313 dnn_mem_t *dst_diff_weights_states_fp = nullptr;
314 dnn_mem_t *dst_diff_bias_fp = nullptr;
315 dnn_mem_t *diff_last_layer_fp = nullptr;
316 dnn_mem_t *diff_last_iteration_fp = nullptr;
318 dnn_mem_t *workspace_dt = nullptr;
320 mkldnn_rnn_desc_t rd[2];
321 mkldnn_primitive_desc_t rpd[2] = {nullptr};
322 mkldnn_primitive_t c{};
323 SAFE(init_pd(p, rd, rpd, r), WARN);
// init_pd may legitimately decide the case is skipped/unsupported.
324 if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
// Aliases to the implementation-chosen memory descriptors that init_pd
// copied back into rd.
327 auto &input_dt_d = rd[0].src_layer_desc;
328 auto &states_dt_d = rd[0].src_iter_desc;
329 auto &weights_input_dt_d = rd[0].weights_layer_desc;
330 auto &weights_states_dt_d = rd[0].weights_iter_desc;
331 auto &bias_dt_d = rd[0].bias_desc;
332 auto &dst_last_layer_dt_d = rd[0].dst_layer_desc;
333 auto &dst_last_iteration_dt_d = rd[0].dst_iter_desc;
335 auto &bwd_weights_input_dt_d = rd[1].weights_layer_desc;
336 auto &bwd_weights_states_dt_d = rd[1].weights_iter_desc;
337 auto &diff_src_layer_dt_d = rd[1].diff_src_layer_desc;
338 auto &diff_src_iter_dt_d = rd[1].diff_src_iter_desc;
339 auto &diff_weights_layer_dt_d = rd[1].diff_weights_layer_desc;
340 auto &diff_weights_iter_dt_d = rd[1].diff_weights_iter_desc;
341 auto &diff_bias_dt_d = rd[1].diff_bias_desc;
342 auto &diff_dst_layer_dt_d = rd[1].diff_dst_layer_desc;
343 auto &diff_dst_iter_dt_d = rd[1].diff_dst_iter_desc;
// Allocate library-layout buffers in the configured data types.
345 input_dt = new dnn_mem_t(input_dt_d, p->cfg[input].dt);
346 states_dt = new dnn_mem_t(states_dt_d, p->cfg[states].dt);
348 = new dnn_mem_t(weights_input_dt_d, p->cfg[weights_input].dt);
350 = new dnn_mem_t(weights_states_dt_d, p->cfg[weights_states].dt);
351 bias_dt = new dnn_mem_t(bias_dt_d, p->cfg[bias].dt);
353 = new dnn_mem_t(dst_last_layer_dt_d, p->cfg[dst_last_layer].dt);
354 dst_last_iteration_dt = new dnn_mem_t(
355 dst_last_iteration_dt_d, p->cfg[dst_last_iteration].dt);
// Backward weights/diff buffers are always f32.
358 bwd_weights_input_dt = new dnn_mem_t(bwd_weights_input_dt_d, fp);
359 bwd_weights_states_dt = new dnn_mem_t(bwd_weights_states_dt_d, fp);
360 dst_diff_input_dt = new dnn_mem_t(diff_src_layer_dt_d, fp);
361 dst_diff_states_dt = new dnn_mem_t(diff_src_iter_dt_d, fp);
362 dst_diff_weights_input_dt = new dnn_mem_t(diff_weights_layer_dt_d, fp);
363 dst_diff_weights_states_dt = new dnn_mem_t(diff_weights_iter_dt_d, fp);
364 dst_diff_bias_dt = new dnn_mem_t(diff_bias_dt_d, fp);
365 diff_last_layer_dt = new dnn_mem_t(diff_dst_layer_dt_d, fp);
366 diff_last_iteration_dt = new dnn_mem_t(diff_dst_iter_dt_d, fp);
// Reference buffers in plain formats (tnc/ldsnc/ldigo/ldgo) for the
// reference computation and the comparisons.
369 input_fp = new dnn_mem_t(input_dt_d, fp, mkldnn_tnc);
370 states_fp = new dnn_mem_t(states_dt_d, fp, mkldnn_ldsnc);
371 weights_input_fp = new dnn_mem_t(weights_input_dt_d, fp, mkldnn_ldigo);
372 weights_states_fp = new dnn_mem_t(weights_states_dt_d, fp, mkldnn_ldigo);
373 bias_fp = new dnn_mem_t(bias_dt_d, fp, mkldnn_ldgo);
374 dst_last_layer_fp = new dnn_mem_t(dst_last_layer_dt_d, fp, mkldnn_tnc);
375 dst_last_iteration_fp
376 = new dnn_mem_t(dst_last_iteration_dt_d, fp, mkldnn_ldsnc);
379 dst_diff_input_fp = new dnn_mem_t(diff_src_layer_dt_d, fp, mkldnn_tnc);
381 = new dnn_mem_t(diff_src_iter_dt_d, fp, mkldnn_ldsnc);
382 dst_diff_weights_input_fp
383 = new dnn_mem_t(diff_weights_layer_dt_d, fp, mkldnn_ldigo);
384 dst_diff_weights_states_fp
385 = new dnn_mem_t(diff_weights_iter_dt_d, fp, mkldnn_ldigo);
386 dst_diff_bias_fp = new dnn_mem_t(diff_bias_dt_d, fp, mkldnn_ldgo);
387 diff_last_layer_fp = new dnn_mem_t(diff_dst_layer_dt_d, fp, mkldnn_tnc);
388 diff_last_iteration_fp
389 = new dnn_mem_t(diff_dst_iter_dt_d, fp, mkldnn_ldsnc);
// The workspace produced by forward training is required to run backward.
391 const auto ws_pd = mkldnn_primitive_desc_query_pd(
392 rpd[0], mkldnn_query_workspace_pd, 0);
393 SAFE(ws_pd != NULL ? OK : FAIL, WARN);
395 = new dnn_mem_t(*mkldnn_primitive_desc_query_memory_d(ws_pd));
// Fill every tensor pair (dt and fp copies) with reproducible data.
398 SAFE(fill_memory(p, input, *input_dt, *input_fp), WARN);
399 SAFE(fill_memory(p, states, *states_dt, *states_fp), WARN);
400 SAFE(fill_memory(p, weights_input, *weights_input_dt, *weights_input_fp),
402 SAFE(fill_memory(p, weights_states, *weights_states_dt, *weights_states_fp),
404 SAFE(fill_memory(p, bias, *bias_dt, *bias_fp), WARN);
405 SAFE(fill_memory(p, dst_last_layer, *dst_last_layer_dt, *dst_last_layer_fp),
407 SAFE(fill_memory(p, dst_last_iteration, *dst_last_iteration_dt,
408 *dst_last_iteration_fp),
// Backward consumes the forward weights, reordered into the layout the
// backward primitive chose.
412 SAFE(bwd_weights_states_dt->reorder(*weights_states_dt), WARN);
413 SAFE(bwd_weights_input_dt->reorder(*weights_input_dt), WARN);
415 p, dst_diff_input, *dst_diff_input_dt, *dst_diff_input_fp),
417 SAFE(fill_memory(p, dst_diff_states, *dst_diff_states_dt,
418 *dst_diff_states_fp),
420 SAFE(fill_memory(p, dst_diff_weights_input, *dst_diff_weights_input_dt,
421 *dst_diff_weights_input_fp),
423 SAFE(fill_memory(p, dst_diff_weights_states,
424 *dst_diff_weights_states_dt, *dst_diff_weights_states_fp),
427 p, dst_diff_bias, *dst_diff_bias_dt, *dst_diff_bias_fp),
429 SAFE(fill_memory(p, diff_last_layer, *diff_last_layer_dt,
430 *diff_last_layer_fp),
432 SAFE(fill_memory(p, diff_last_iteration, *diff_last_iteration_dt,
433 *diff_last_iteration_fp),
437 // Running the forward pass
439 mkldnn_primitive_at_t inputs[] = { { input_dt->p_, 0 },
440 { states_dt->p_, 0 }, { weights_input_dt->p_, 0 },
441 { weights_states_dt->p_, 0 }, { bias_dt->p_, 0 } };
442 const_mkldnn_primitive_t outputs[] = { dst_last_layer_dt->p_,
443 dst_last_iteration_dt->p_, workspace_dt ? workspace_dt->p_ : 0 };
444 #ifdef CALL_MKLDNN_RNN
445 DNN_SAFE(mkldnn_primitive_create(&c, rpd[0], inputs, outputs), WARN);
446 SAFE(execute(c), WARN);
// Forward correctness: run the reference, pull the library outputs back
// into plain layouts, and compare.
448 if ((p->prop == mkldnn_forward) && (bench_mode & CORR)) {
449 compute_ref_fwd(p, *input_fp, *states_fp, *weights_input_fp,
450 *weights_states_fp, *bias_fp, *dst_last_layer_fp,
451 *dst_last_iteration_fp, p->direction);
452 dnn_mem_t dst_last_layer(*dst_last_layer_dt, fp, mkldnn_tnc);
453 dnn_mem_t dst_last_iteration(
454 *dst_last_iteration_dt, fp, mkldnn_ldsnc);
455 SAFE(compare_dst_last_layer(
456 p, dst_last_layer, *dst_last_layer_fp, r, true),
458 SAFE(compare_dst_last_iteration(p, dst_last_iteration,
459 *dst_last_iteration_fp, r, true),
// Backward pass: note the forward outputs and workspace become inputs.
465 mkldnn_primitive_at_t inputs[] = {
466 { input_dt->p_, 0 }, { states_dt->p_, 0 },
467 { bwd_weights_input_dt->p_, 0 }, { bwd_weights_states_dt->p_, 0 },
468 { bias_dt->p_, 0 }, { dst_last_layer_dt->p_, 0 },
469 { dst_last_iteration_dt->p_, 0 }, { diff_last_layer_dt->p_, 0 },
470 { diff_last_iteration_dt->p_, 0 }, { workspace_dt->p_, 0 },
472 const_mkldnn_primitive_t outputs[] = { dst_diff_input_dt->p_,
473 dst_diff_states_dt->p_, dst_diff_weights_input_dt->p_,
474 dst_diff_weights_states_dt->p_, dst_diff_bias_dt->p_ };
476 #ifdef CALL_MKLDNN_RNN
477 DNN_SAFE(mkldnn_primitive_create(&c, rpd[1], inputs, outputs), WARN);
478 SAFE(execute(c), WARN);
// Backward correctness: reference bwd, then compare every diff tensor
// (inputs, states, weights, bias) plus the forward outputs.
481 if (bench_mode & CORR) {
482 compute_ref_bwd(p, *input_fp, *states_fp, *diff_last_layer_fp,
483 *diff_last_iteration_fp, *weights_input_fp,
484 *weights_states_fp, *bias_fp, *dst_last_layer_fp,
485 *dst_last_iteration_fp, *dst_diff_input_fp,
486 *dst_diff_states_fp, *dst_diff_weights_input_fp,
487 *dst_diff_weights_states_fp, *dst_diff_bias_fp,
490 dnn_mem_t dst_last_layer(*dst_last_layer_dt, fp, mkldnn_tnc);
491 dnn_mem_t dst_last_iteration(
492 *dst_last_iteration_dt, fp, mkldnn_ldsnc);
493 SAFE(compare_dst_last_layer(
494 p, dst_last_layer, *dst_last_layer_fp, r, true),
496 SAFE(compare_dst_last_iteration(p, dst_last_iteration,
497 *dst_last_iteration_fp, r, true),
500 dnn_mem_t diff_input(*dst_diff_input_dt, fp, mkldnn_tnc);
501 dnn_mem_t diff_states(*dst_diff_states_dt, fp, mkldnn_ldsnc);
502 SAFE(compare_input(p, diff_input, *dst_diff_input_fp, r, true),
504 SAFE(compare_states(p, diff_states, *dst_diff_states_fp, r, true),
507 dnn_mem_t diff_weights_input(
508 *dst_diff_weights_input_dt, fp, mkldnn_ldigo);
509 dnn_mem_t diff_weights_states(
510 *dst_diff_weights_states_dt, fp, mkldnn_ldigo);
511 SAFE(compare_weights_input(p, diff_weights_input,
512 *dst_diff_weights_input_fp, r, true),
514 SAFE(compare_weights_states(p, diff_weights_states,
515 *dst_diff_weights_states_fp, r, true),
518 dnn_mem_t diff_bias(*dst_diff_bias_dt, fp, mkldnn_ldgo);
519 SAFE(compare_bias(p, diff_bias, *dst_diff_bias_fp, r, true), WARN);
// Performance mode: re-execute until the benchdnn timing policy says to
// stop (fixed repetition count, or time+min-reps budget reached).
523 if (bench_mode & PERF) {
527 #ifdef CALL_MKLDNN_RNN
528 SAFE(execute(c), WARN);
531 const bool stop = false
532 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
533 || (!fix_times_per_prb && t.total_ms() >= max_ms_per_prb
534 && t.times() >= min_times_per_prb);
// Cleanup: release reference (fp) buffers...
543 delete weights_input_fp;
544 delete weights_states_fp;
546 delete dst_last_layer_fp;
547 delete dst_last_iteration_fp;
550 delete bwd_weights_input_dt;
551 delete bwd_weights_states_dt;
552 delete dst_diff_input_fp;
553 delete dst_diff_states_fp;
554 delete dst_diff_weights_input_fp;
555 delete dst_diff_weights_states_fp;
556 delete dst_diff_bias_fp;
557 delete diff_last_layer_fp;
558 delete diff_last_iteration_fp;
// ...then the library-layout (dt) buffers...
563 delete weights_input_dt;
564 delete weights_states_dt;
566 delete dst_last_layer_dt;
567 delete dst_last_iteration_dt;
570 delete dst_diff_input_dt;
571 delete dst_diff_states_dt;
572 delete dst_diff_weights_input_dt;
573 delete dst_diff_weights_states_dt;
574 delete dst_diff_bias_dt;
575 delete diff_last_layer_dt;
576 delete diff_last_iteration_dt;
// ...and finally the primitive descriptors and the primitive itself.
581 DNN_SAFE(mkldnn_primitive_desc_destroy(rpd[0]), CRIT);
582 DNN_SAFE(mkldnn_primitive_desc_destroy(rpd[1]), CRIT);
583 DNN_SAFE(mkldnn_primitive_destroy(c), CRIT);