#define CALL_MKLDNN_RNN 1
+mkldnn_primitive_attr_t create_mkldnn_rnn_attr(const rnn_prb_t *p) {
+ mkldnn_primitive_attr_t mkldnn_attr = NULL;
+
+ DNN_SAFE_V(mkldnn_primitive_attr_create(&mkldnn_attr));
+ if (p->attr.irmode != attr_t::round_mode_t::NEAREST)
+ DNN_SAFE_V(mkldnn_primitive_attr_set_int_output_round_mode(
+ mkldnn_attr, (mkldnn_round_mode_t)p->attr.irmode));
+
+ if (p->scale_policy == PER_OC) {
+ DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_weights_qparams(
+ mkldnn_attr, p->dic * p->n_gates(), 0x3, p->wei_oc_scales));
+ } else if (p->scale_policy == COMMON && p->wei_scale != 1.) {
+ DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_weights_qparams(
+ mkldnn_attr, 1, 0, &p->wei_scale));
+ }
+
+ if (p->data_scale != 1.0 || p->data_shift != 0.0) {
+ DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_data_qparams(
+ mkldnn_attr, p->data_scale, p->data_shift));
+ }
+
+ return mkldnn_attr;
+}
+
int fill_memory(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem1,
dnn_mem_t &mem2) {
#ifdef CALL_MKLDNN_RNN
#else
const size_t nelems = mem2.nelems();
#endif
- size_t nchunks = mkldnn_get_max_threads();
- size_t chunk_size = (nelems + nchunks - 1) / nchunks;
+ dt_conf_t c = p->cfg[kind];
+ float mean = c.f_mean, var = c.f_var, min = c.f_min, max = c.f_max;
mkldnn::impl::parallel(0, [&](int ithr, int nthr) {
+ size_t chunk_size = (nelems + nthr - 1) / nthr;
size_t idx_start = ithr * chunk_size;
size_t idx_end = MIN2(idx_start + chunk_size, nelems);
-
std::minstd_rand msr;
- std::normal_distribution<float> gen(.0f, .001f);
+ msr.seed((unsigned long int)kind);
+ std::normal_distribution<float> gen(mean, var);
msr.discard(idx_start);
-
- for (size_t idx = idx_start; idx < idx_end; ++idx){
- auto val = gen(msr);
- mem2.set_elem(idx, MAX2(MIN2(val, 1.0f), -1.0f));
+ for (size_t idx = idx_start; idx < idx_end; ++idx) {
+ auto val = (c.dt == mkldnn_f32) ? gen(msr) : round(gen(msr));
+ mem2.set_elem(idx, MAX2(MIN2(val, max), min));
}
});
mkldnn_dims_t bias_dims
= { p->n_layer, p->n_directions(), p->n_gates() + is_gru_lbr, p->dic };
// mkldnn_tnc
- int lastlay_dlc = (p->direction == mkldnn_bidirectional_concat) ?
- 2 * p->dlc :
- p->dlc;
+ int lastlay_dlc = (p->direction == mkldnn_bidirectional_concat)
+ ? 2 * p->dlc
+ : p->dlc;
mkldnn_dims_t dst_last_layer_dims = { p->n_iter, p->mb, lastlay_dlc };
DNN_SAFE(mkldnn_memory_desc_init(
- &input_d, 3, input_dims, p->cfg[SRC].dt, mkldnn_tnc),
+ &input_d, 3, input_dims, p->cfg[input].dt, mkldnn_tnc),
WARN);
input_d.layout_desc.blocking.strides[0][0] += the_stride;
- DNN_SAFE(mkldnn_memory_desc_init(
- &diff_input_d, 3, input_dims, p->cfg[SRC].dt, mkldnn_any),
- WARN);
mkldnn_dims_t states_dims
= { p->n_layer, p->n_directions(), p->n_states(), p->mb, p->sic };
- DNN_SAFE(mkldnn_memory_desc_init(
- &states_d, 5, states_dims, p->cfg[SRC].dt, mkldnn_ldsnc),
+ DNN_SAFE(mkldnn_memory_desc_init(&states_d, 5, states_dims,
+ p->cfg[states].dt, mkldnn_ldsnc),
WARN);
states_d.layout_desc.blocking.strides[0][3] = p->sic + the_stride;
= states_d.layout_desc.blocking.strides[0][d + 1]
* states_d.dims[d + 1];
- DNN_SAFE(mkldnn_memory_desc_init(&diff_states_d, 5, states_dims,
- p->cfg[SRC].dt, mkldnn_any),
- WARN);
-
DNN_SAFE(mkldnn_memory_desc_init(&weights_input_d, 5, weights_input_dims,
- p->cfg[SRC].dt, mkldnn_any),
- WARN);
- DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_input_d, 5,
- weights_input_dims, p->cfg[SRC].dt, mkldnn_any),
+ p->cfg[weights_input].dt, mkldnn_any),
WARN);
DNN_SAFE(mkldnn_memory_desc_init(&weights_states_d, 5, weights_states_dims,
- p->cfg[SRC].dt, mkldnn_any),
- WARN);
- DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_states_d, 5,
- weights_states_dims, p->cfg[SRC].dt, mkldnn_any),
+ p->cfg[weights_states].dt, mkldnn_any),
WARN);
DNN_SAFE(mkldnn_memory_desc_init(
- &bias_d, 4, bias_dims, p->cfg[SRC].dt, mkldnn_any),
- WARN);
- DNN_SAFE(mkldnn_memory_desc_init(
- &diff_bias_d, 4, bias_dims, p->cfg[SRC].dt, mkldnn_any),
+ &bias_d, 4, bias_dims, p->cfg[bias].dt, mkldnn_any),
WARN);
DNN_SAFE(mkldnn_memory_desc_init(&dst_last_layer_d, 3, dst_last_layer_dims,
- p->cfg[SRC].dt, mkldnn_tnc),
+ p->cfg[dst_last_layer].dt, mkldnn_tnc),
WARN);
dst_last_layer_d.layout_desc.blocking.strides[0][0] += the_stride;
- DNN_SAFE(mkldnn_memory_desc_init(&diff_last_layer_d, 3, dst_last_layer_dims,
- p->cfg[SRC].dt, mkldnn_any),
- WARN);
mkldnn_dims_t dst_last_iteration_dims
= { p->n_layer, p->n_directions(), p->n_states(), p->mb, p->dic };
DNN_SAFE(mkldnn_memory_desc_init(&dst_last_iteration_d, 5,
- dst_last_iteration_dims, p->cfg[SRC].dt, mkldnn_ldsnc),
+ dst_last_iteration_dims, p->cfg[dst_last_iteration].dt,
+ mkldnn_ldsnc),
WARN);
dst_last_iteration_d.layout_desc.blocking.strides[0][3]
= dst_last_iteration_d.layout_desc.blocking.strides[0][d + 1]
* dst_last_iteration_d.dims[d + 1];
- DNN_SAFE(mkldnn_memory_desc_init(&diff_last_iteration_d, 5,
- dst_last_iteration_dims, p->cfg[SRC].dt, mkldnn_any),
- WARN);
-
mkldnn_alg_kind_t kind = alg2kind(p->alg);
mkldnn_alg_kind_t f = activation2kind(p->activation);
// When inference, we use forward_inference
// When training, we use forward_training
{
- DNN_SAFE(mkldnn_rnn_forward_desc_init(&rd[0], fwd_prop, &rcd,
+ mkldnn_status_t init_status = mkldnn_success;
+ init_status = mkldnn_rnn_forward_desc_init(&rd[0], fwd_prop, &rcd,
p->direction, &input_d, &states_d, &weights_input_d,
&weights_states_d, &bias_d, &dst_last_layer_d,
- &dst_last_iteration_d),
- WARN);
+ &dst_last_iteration_d);
+ if (init_status == mkldnn_unimplemented)
+ return r->state = UNIMPLEMENTED, OK;
+ else
+ SAFE(init_status, WARN);
}
if (is_bwd) {
+ DNN_SAFE(mkldnn_memory_desc_init(&diff_input_d, 3, input_dims,
+ p->cfg[dst_diff_input].dt, mkldnn_any),
+ WARN);
+ DNN_SAFE(mkldnn_memory_desc_init(&diff_states_d, 5, states_dims,
+ p->cfg[dst_diff_states].dt, mkldnn_any),
+ WARN);
+ DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_input_d, 5,
+ weights_input_dims, p->cfg[dst_diff_weights_input].dt,
+ mkldnn_any),
+ WARN);
+ DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_states_d, 5,
+ weights_states_dims,
+ p->cfg[dst_diff_weights_states].dt, mkldnn_any),
+ WARN);
+ DNN_SAFE(mkldnn_memory_desc_init(&diff_bias_d, 4, bias_dims,
+ p->cfg[dst_diff_bias].dt, mkldnn_any),
+ WARN);
+ DNN_SAFE(mkldnn_memory_desc_init(&diff_last_layer_d, 3,
+ dst_last_layer_dims, p->cfg[diff_last_layer].dt,
+ mkldnn_any),
+ WARN);
+ DNN_SAFE(mkldnn_memory_desc_init(&diff_last_iteration_d, 5,
+ dst_last_iteration_dims,
+ p->cfg[diff_last_iteration].dt, mkldnn_any),
+ WARN);
DNN_SAFE(mkldnn_rnn_backward_desc_init(&rd[1], p->prop, &rcd,
p->direction, &input_d, &states_d, &weights_input_d,
&weights_states_d, &bias_d, &dst_last_layer_d,
&diff_last_iteration_d),
WARN);
}
+ auto mkldnn_attr = create_mkldnn_rnn_attr(p);
mkldnn_status_t init_status = mkldnn_success;
for (int i = 0; i < 1 + (int)is_bwd; i++) {
- init_status = mkldnn_primitive_desc_create(
- &(rpd[i]), &(rd[i]), engine, NULL);
+ init_status = mkldnn_primitive_desc_create_v2(
+ &(rpd[i]), &(rd[i]), mkldnn_attr, engine, NULL);
if (init_status == mkldnn_unimplemented)
return r->state = UNIMPLEMENTED, OK;
else
SAFE(init_status, WARN);
}
-
- // const char *impl_str = query_impl_info(rpd);
+ mkldnn_primitive_attr_destroy(mkldnn_attr);
auto q = [=](mkldnn_query_t query, int rpd_idx, int index = 0) {
return *mkldnn_primitive_desc_query_memory_d(
auto &diff_dst_layer_dt_d = rd[1].diff_dst_layer_desc;
auto &diff_dst_iter_dt_d = rd[1].diff_dst_iter_desc;
- input_dt = new dnn_mem_t(input_dt_d, fp);
- states_dt = new dnn_mem_t(states_dt_d, fp);
- weights_input_dt = new dnn_mem_t(weights_input_dt_d, fp);
- weights_states_dt = new dnn_mem_t(weights_states_dt_d, fp);
- bias_dt = new dnn_mem_t(bias_dt_d, fp);
- dst_last_layer_dt = new dnn_mem_t(dst_last_layer_dt_d, fp);
- dst_last_iteration_dt = new dnn_mem_t(dst_last_iteration_dt_d, fp);
+ input_dt = new dnn_mem_t(input_dt_d, p->cfg[input].dt);
+ states_dt = new dnn_mem_t(states_dt_d, p->cfg[states].dt);
+ weights_input_dt
+ = new dnn_mem_t(weights_input_dt_d, p->cfg[weights_input].dt);
+ weights_states_dt
+ = new dnn_mem_t(weights_states_dt_d, p->cfg[weights_states].dt);
+ bias_dt = new dnn_mem_t(bias_dt_d, p->cfg[bias].dt);
+ dst_last_layer_dt
+ = new dnn_mem_t(dst_last_layer_dt_d, p->cfg[dst_last_layer].dt);
+ dst_last_iteration_dt = new dnn_mem_t(
+ dst_last_iteration_dt_d, p->cfg[dst_last_iteration].dt);
if (is_bwd) {
bwd_weights_input_dt = new dnn_mem_t(bwd_weights_input_dt_d, fp);
dnn_mem_t dst_last_layer(*dst_last_layer_dt, fp, mkldnn_tnc);
dnn_mem_t dst_last_iteration(
*dst_last_iteration_dt, fp, mkldnn_ldsnc);
- SAFE(dst_last_layer.reorder(*dst_last_layer_dt), WARN);
- SAFE(dst_last_iteration.reorder(*dst_last_iteration_dt), WARN);
SAFE(compare_dst_last_layer(
p, dst_last_layer, *dst_last_layer_fp, r, true),
WARN);
dnn_mem_t dst_last_layer(*dst_last_layer_dt, fp, mkldnn_tnc);
dnn_mem_t dst_last_iteration(
*dst_last_iteration_dt, fp, mkldnn_ldsnc);
- SAFE(dst_last_layer.reorder(*dst_last_layer_dt), WARN);
- SAFE(dst_last_iteration.reorder(*dst_last_iteration_dt), WARN);
SAFE(compare_dst_last_layer(
p, dst_last_layer, *dst_last_layer_fp, r, true),
WARN);
dnn_mem_t diff_input(*dst_diff_input_dt, fp, mkldnn_tnc);
dnn_mem_t diff_states(*dst_diff_states_dt, fp, mkldnn_ldsnc);
- SAFE(diff_input.reorder(*dst_diff_input_dt), WARN);
- SAFE(diff_states.reorder(*dst_diff_states_dt), WARN);
SAFE(compare_input(p, diff_input, *dst_diff_input_fp, r, true),
WARN);
SAFE(compare_states(p, diff_states, *dst_diff_states_fp, r, true),
*dst_diff_weights_input_dt, fp, mkldnn_ldigo);
dnn_mem_t diff_weights_states(
*dst_diff_weights_states_dt, fp, mkldnn_ldigo);
- SAFE(diff_weights_input.reorder(*dst_diff_weights_input_dt), WARN);
- SAFE(diff_weights_states.reorder(*dst_diff_weights_states_dt),
- WARN);
SAFE(compare_weights_input(p, diff_weights_input,
*dst_diff_weights_input_fp, r, true),
WARN);
WARN);
dnn_mem_t diff_bias(*dst_diff_bias_dt, fp, mkldnn_ldgo);
- SAFE(diff_bias.reorder(*dst_diff_bias_dt), WARN);
SAFE(compare_bias(p, diff_bias, *dst_diff_bias_fp, r, true), WARN);
}
}