/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "math_utils.hpp"
19 #include "mkldnn_thread.hpp"
21 #include "ref_rnn.hpp"
22 #include "rnn_utils.hpp"
23 #include "type_helpers.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::utils;
using namespace rnn_utils;
using namespace memory_format;
using namespace rnn_packed_format;
using namespace data_type;

void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &src_layer_d,
        const memory_desc_wrapper &src_iter_d,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &dst_layer_d) {
    rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training,
            prop_kind::forward_inference);
    rnn.is_training = utils::one_of(
            rd.prop_kind, prop_kind::forward_training, prop_kind::backward);
    rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset;

    switch (rd.direction) {
    case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break;
    case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break;
    case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break;
    case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break;
    default: break;
    }

    if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(),
                weights_layer_d.data_type()))
        rnn.dt_conf = all_f32;
    else if (dst_layer_d.data_type() == u8) {
        if (IMPLICATION(src_iter_d._md, src_iter_d.data_type() == u8))
            rnn.dt_conf = u8u8u8u8;
        else
            rnn.dt_conf = f32u8f32u8;
    } else {
        if (IMPLICATION(src_iter_d._md, src_iter_d.data_type() == u8))
            rnn.dt_conf = u8u8u8f32;
        else
            rnn.dt_conf = f32u8f32f32;
    }
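
    // The dt_conf names can be read as the data types of
    // (src_iter, src_layer, dst_iter, dst_layer); e.g. u8u8u8f32 is an int8
    // configuration that still produces an f32 dst_layer.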

    rnn.n_layer = weights_layer_d.dims()[0];
    rnn.n_iter = src_layer_d.dims()[0];
    rnn.n_dir = weights_layer_d.dims()[1];
    rnn.n_gates = weights_layer_d.dims()[3];
    rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc);
    rnn.n_bias = rnn.n_gates + rnn.is_lbr;
    rnn.mb = src_layer_d.dims()[1];
    rnn.sic = weights_iter_d.dims()[2];
    rnn.slc = weights_layer_d.dims()[2];
    rnn.dic = weights_layer_d.dims()[4];
    rnn.dlc = dst_layer_d.dims()[2];

    rnn.gates_ld = rnn.dic * rnn.n_gates;
    rnn.gates_nld = rnn.mb;
    rnn.states_nld = rnn.mb;
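
    // Note: *_ld fields are gemm leading dimensions (row strides) and *_nld
    // the size of the other, non-leading dimension; e.g. the gates form an
    // (mb x n_gates * dic) array with row stride gates_ld.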

    /* Set the correct number of weights parts */
    bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru;
    rnn.n_parts_weights_layer = 1;
    rnn.parts_weights_layer[0] = rnn.n_gates;
    rnn.parts_weights_layer[1] = 0;

    rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1;
    rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates;
    rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0;

    rnn.n_parts_bias = 1;
    rnn.parts_bias[0] = rnn.n_bias;
    rnn.parts_bias[1] = 0;
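
    // Vanilla GRU needs its iteration weights split in two because the gemm
    // for the last gate may only run after the reset gate has been applied
    // to the previous-iteration states: the parts are {2 gates, 1 gate}
    // instead of a single part covering all n_gates.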

    /* Decide which gemm implementation to use: packed/nonpacked jit/cblas
     * and whether to merge gemm across iterations */
    bool is_int8 = rnn.dt_conf != all_f32;
    rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd)
            || is_int8;
    bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru,
            alg_kind::gru_linear_before_reset);
    rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8;
    bool is_inference = !rnn.is_training;

    rnn.use_jit_gemm = !mayiuse(avx512_mic)
            && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100))
                    || (rnn.is_training && rnn.dic < 500));

    /* Decide to copy bias */
    rnn.copy_bias = rnn.dt_conf != all_f32;

#if USE_MKL_PACKED_GEMM
    rnn.use_layer_packed_gemm
            = (weights_layer_d.format() == any && rnn.slc > 760 && rnn.dic > 760
                      && is_inference)
            || is_int8; // packed gemm is the only supported option for int8
    rnn.use_iter_packed_gemm = (weights_iter_d.format() == any && rnn.sic > 760
                                       && rnn.dic > 760 && is_inference)
            || is_int8;
#else
    rnn.use_layer_packed_gemm = false;
    rnn.use_iter_packed_gemm = false;
#endif
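
    // For context: MKL's packed gemm amortizes the internal reshuffling of
    // the weights matrix across iterations. A sketch of the f32 call
    // sequence (illustrative names and sizes, not the code below):
    //     size_t sz = cblas_sgemm_pack_get_size(CblasAMatrix, m, n, k);
    //     float *a_pack = (float *)malloc(sz);
    //     cblas_sgemm_pack(CblasColMajor, CblasAMatrix, CblasNoTrans,
    //             m, n, k, 1.0f, a, lda, a_pack);        // once per weights
    //     cblas_sgemm_compute(CblasColMajor, CblasPacked, CblasNoTrans,
    //             m, n, k, a_pack, lda, b, ldb, 0.0f, c, ldc); // per gemm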

    /* Set packed gemm sizes */
    if (rnn.use_layer_packed_gemm) {
        rnn.weights_layer_pack_size = 0;
        for (int p = 0; p < rnn.n_parts_weights_layer; p++) {
            int m_p = rnn.is_fwd
                    ? (rnn.parts_weights_layer[p] * rnn.dic)
                    : rnn.slc;
            int k_p = rnn.is_fwd
                    ? rnn.slc
                    : (rnn.parts_weights_layer[p] * rnn.dic);
            int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb;

#if USE_MKL_PACKED_GEMM
            if (rnn.dt_conf == all_f32)
                rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size(
                        CblasAMatrix, m_p, n_p, k_p);
            else
                rnn.part_weights_layer_pack_size[p]
                        = cblas_gemm_s8u8s32_pack_get_size(
                                CblasAMatrix, m_p, n_p, k_p);
#else
            UNUSED(m_p);
            UNUSED(k_p);
            UNUSED(n_p);
            rnn.part_weights_layer_pack_size[p] = 0;
#endif
            rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir
                    * rnn.part_weights_layer_pack_size[p];
        }
        rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size;
        rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer
                        * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float);
    }
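
    // Note: for int8, the packed buffer sized above is followed by the
    // per-output-channel compensation values needed by the u8*s8 gemm
    // (n_gates * dlc floats per layer and direction here, n_gates * dic for
    // the iteration weights below); weights_*_comp_offset records where that
    // region starts.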

    if (rnn.use_iter_packed_gemm) {
        rnn.weights_iter_pack_size = 0;
        for (int p = 0; p < rnn.n_parts_weights_iter; p++) {
            int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) :
                                   rnn.sic;
            int k_p = rnn.is_fwd ? rnn.sic :
                                   (rnn.parts_weights_iter[p] * rnn.dic);
            int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb;

#if USE_MKL_PACKED_GEMM
            if (rnn.dt_conf == all_f32)
                rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size(
                        CblasAMatrix, m_p, n_p, k_p);
            else
                rnn.part_weights_iter_pack_size[p]
                        = cblas_gemm_s8u8s32_pack_get_size(
                                CblasAMatrix, m_p, n_p, k_p);
#else
            UNUSED(m_p);
            UNUSED(k_p);
            UNUSED(n_p);
            rnn.part_weights_iter_pack_size[p] = 0;
#endif
            rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir
                    * rnn.part_weights_iter_pack_size[p];
        }
        rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size;
        rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer
                        * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float);
    }
}

void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &diff_weights_layer_d,
        const memory_desc_wrapper &diff_weights_iter_d) {

    /* Set leading dimensions for input weights arrays depending on input
     * format */
    rnn.weights_layer_fmt = weights_layer_d.format();
    rnn.weights_iter_fmt = weights_iter_d.format();
    rnn.weights_layer_is_packed = rnn.weights_layer_fmt == rnn_packed;
    rnn.weights_iter_is_packed = rnn.weights_iter_fmt == rnn_packed;
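
    /* Read the gemm leading dimension (ld) and the size of the remaining,
     * non-leading dimension (nld) out of a weights descriptor. Only the
     * plain ldigo/ldgoi formats expose strides we can use directly. */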
    auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) {
        switch (md.format()) {
        case ldigo:
            ld = (int)md.blocking_desc().strides[0][2];
            nld = md.dims()[2];
            break;
        case ldgoi:
            ld = (int)md.blocking_desc().strides[0][4];
            nld = md.dims()[3] * md.dims()[4];
            break;
        default: ld = 0; nld = 0;
        }
    };
    set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld);
    set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld);
    if (!rnn.is_fwd) {
        set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld,
                rnn.diff_weights_layer_nld);
        set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld,
                rnn.diff_weights_iter_nld);
    }

    int sizeof_states_dt
            = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t);
    rnn.states_ws_ld
            = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)),
                    sizeof_states_dt);
    rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float));

    /* Set workspace sizes to store:
     * states to compute a pass
     * diff states to compute bwd pass (training only)
     * intermediate results from the gates
     */
    rnn.use_workspace = rnn.is_training;
    rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
            * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt;
    bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm;
    rnn.ws_c_states_size = is_lstm
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.states_ws_ld * sizeof(float)
            : 0;
    rnn.ws_diff_states_size = rnn.is_training
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1)
                    * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld
                    * sizeof(float)
            : 0;
    rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb
            * rnn.gates_ws_ld * sizeof(float);

    /* set other sizes */
    rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float);
    rnn.ws_cell_comp_size
            = rnn.is_lbr || rnn.dt_conf != all_f32
            ? (size_t)rnn.gates_nld * rnn.gates_ws_ld * sizeof(float)
            : 0;
    rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer
            * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float);
    rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic
            * sizeof(float);
}

int rnn_utils::get_good_ld(int dim, int sizeof_dt) {
    // we want matrices leading dimensions to be 64-byte aligned,
    // and not divisible by 256 to avoid 4K aliasing effects
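    // e.g. for f32, sizeof_dt == 4 and the alignment unit is 16 elements:
    // dim = 500 rounds up to 512, which is divisible by 256, so we return
    // 512 + 16 = 528 instead (still 64-byte aligned, but no longer a
    // power-of-two stride in bytes)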
    int ld = rnd_up(dim, 64 / sizeof_dt);
    return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;
}

void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
        size_t &ws_states_offset, size_t &ws_c_states_offset,
        size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
        size_t &ws_cell_comp_offset, size_t &ws_bias_offset,
        size_t &scratchpad_size, size_t &workspace_size) {

    const size_t page_size = 4096; // 2097152;
    size_t current_offset;
    /* Mandatory workspaces: go to workspace if use_workspace, scratchpad
     * otherwise */
    current_offset = 0; // assumes the workspace base pointer is page aligned
    ws_gates_offset = current_offset;
    current_offset += rnn.ws_gates_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_states_offset = current_offset;
    current_offset += rnn.ws_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_c_states_offset = current_offset;
    current_offset += rnn.ws_c_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_diff_states_offset = current_offset;
    current_offset += rnn.ws_diff_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_grid_comp_offset = current_offset;
    current_offset += rnn.ws_grid_comp_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_cell_comp_offset = current_offset;
    current_offset += rnn.ws_cell_comp_size;

    workspace_size = rnn.use_workspace ? current_offset : 0;

    /* Optional scratchpads */
    // Assumes the scratchpad base pointer is page aligned.
    // If use_workspace, the following goes to scratchpad alone,
    // otherwise everything goes to scratchpad and we keep incrementing
    // the offset
    current_offset = rnn.use_workspace ? 0 : current_offset;

    if (rnn.copy_bias) {
        current_offset = utils::rnd_up(current_offset, page_size);
        ws_bias_offset = current_offset;
        current_offset += rnn.ws_bias_size;
    }

    scratchpad_size = current_offset;
}

void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn,
        size_t &scratchpad_size, size_t &workspace_size) {
    size_t ws_gates_offset, ws_states_offset, ws_c_states_offset,
            ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset,
            ws_bias_offset;
    set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_c_states_offset,
            ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset,
            ws_bias_offset, scratchpad_size, workspace_size);
}

status_t rnn_utils::set_good_strides(memory_desc_t &weights_md) {
    auto &strides = weights_md.layout_desc.blocking.strides[0];
    auto dims = weights_md.dims;

    if (weights_md.format == ldigo) {
        strides[2] = rnn_utils::get_good_ld((int)strides[2],
                (int)types::data_type_size(weights_md.data_type));
        strides[1] = dims[2] * strides[2];
        strides[0] = dims[1] * strides[1];
    } else if (weights_md.format == ldgoi) {
        strides[4] = rnn_utils::get_good_ld((int)strides[4],
                (int)types::data_type_size(weights_md.data_type));
        strides[3] = dims[4] * strides[4];
        strides[1] = dims[3] * strides[3];
        strides[0] = dims[1] * strides[1];
    } else
        return status::unimplemented;

    return status::success;
}
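
// A worked example of set_good_strides (illustrative sizes): for f32 ldigo
// weights with dims i = 128, g = 4, o = 512, the dense input-channel stride
// would be g * o = 2048 elements, a multiple of 256, so get_good_ld bumps
// strides[2] to 2064 and the outer strides are recomputed from it.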

status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
        memory_desc_t &weights_md, bool is_iter) {
    bool use_packed_gemm = is_iter
            ? rnn.use_iter_packed_gemm
            : rnn.use_layer_packed_gemm;
    if (use_packed_gemm) {
        weights_md.format = rnn_packed;
        rnn_packed_data_t &rnn_pdata = weights_md.layout_desc.rnn_packed_desc;
        rnn_pdata.format = rnn.is_fwd ? mkldnn_ldigo_p : mkldnn_ldgoi_p;
        if (is_iter) {
            rnn_pdata.n = rnn.mb;
            rnn_pdata.n_parts = rnn.n_parts_weights_iter;
            array_copy(rnn_pdata.parts, rnn.parts_weights_iter,
                    MKLDNN_RNN_MAX_N_PARTS);
            array_copy(rnn_pdata.part_pack_size,
                    rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS);
            rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset;
            rnn_pdata.size = rnn.weights_iter_pack_size;
        } else {
            rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb;
            rnn_pdata.n_parts = rnn.n_parts_weights_layer;
            array_copy(rnn_pdata.parts, rnn.parts_weights_layer,
                    MKLDNN_RNN_MAX_N_PARTS);
            array_copy(rnn_pdata.part_pack_size,
                    rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS);
            rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset;
            rnn_pdata.size = rnn.weights_layer_pack_size;
        }
    } else {
        weights_md.format = rnn.is_fwd ? ldigo : ldgoi;
        CHECK(memory_desc_wrapper::compute_blocking(weights_md));
        // Adjust strides for good leading dimension in GEMM
        CHECK(set_good_strides(weights_md));
    }
    return status::success;
}

}
}
}