// File: inference-engine/thirdparty/mkl-dnn/src/cpu/rnn/rnn_utils.cpp
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "math_utils.hpp"
#include "mkldnn_thread.hpp"

#include "ref_rnn.hpp"
#include "rnn_utils.hpp"
#include "type_helpers.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::utils;
using namespace rnn_utils;
using namespace memory_format;
using namespace rnn_packed_format;
using namespace data_type;

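/* Fill the parts of the RNN configuration that depend only on the problem
 * shape and data types: execution direction, sizes, the number of weights
 * and bias parts, and the gemm merging/packing decisions. */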
void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &src_layer_d,
        const memory_desc_wrapper &src_iter_d,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &dst_layer_d) {
    rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training,
            prop_kind::forward_inference);
    rnn.is_training = utils::one_of(
            rd.prop_kind, prop_kind::forward_training, prop_kind::backward);
    rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset;

    switch (rd.direction) {
    case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break;
    case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break;
    case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break;
    case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break;
    default: break;
    }

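    /* Pick the data type configuration: all-f32 when src/dst/weights are
     * f32, otherwise one of the int8 variants selected from the dst_layer
     * and src_iter data types. */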
    if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(),
                weights_layer_d.data_type()))
        rnn.dt_conf = all_f32;
    else if (dst_layer_d.data_type() == u8) {
        if (IMPLICATION(src_iter_d._md, src_iter_d.data_type() == u8))
            rnn.dt_conf = u8u8u8u8;
        else
            rnn.dt_conf = f32u8f32u8;
    } else {
        if (IMPLICATION(src_iter_d._md, src_iter_d.data_type() == u8))
            rnn.dt_conf = u8u8u8f32;
        else
            rnn.dt_conf = f32u8f32f32;
    }

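    /* Problem sizes come from the logical dims of the descriptors: weights
     * are 5D {layers, dirs, input channels, gates, output channels} and
     * src_layer is {iters, batch, channels}. */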
    rnn.n_layer = weights_layer_d.dims()[0];
    rnn.n_iter = src_layer_d.dims()[0];
    rnn.n_dir = weights_layer_d.dims()[1];
    rnn.n_gates = weights_layer_d.dims()[3];
    rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc);
    rnn.n_bias = rnn.n_gates + rnn.is_lbr;
    rnn.mb = src_layer_d.dims()[1];
    rnn.sic = weights_iter_d.dims()[2];
    rnn.slc = weights_layer_d.dims()[2];
    rnn.dic = weights_layer_d.dims()[4];
    rnn.dlc = dst_layer_d.dims()[2];

    rnn.gates_ld = rnn.dic * rnn.n_gates;
    rnn.gates_nld = rnn.mb;
    rnn.states_nld = rnn.mb;

    /* Set the correct number of weights parts */
    bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru;
    rnn.n_parts_weights_layer = 1;
    rnn.parts_weights_layer[0] = rnn.n_gates;
    rnn.parts_weights_layer[1] = 0;

    rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1;
    rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates;
    rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0;

    rnn.n_parts_bias = 1;
    rnn.parts_bias[0] = rnn.n_bias;
    rnn.parts_bias[1] = 0;

    /* Decide which gemm implementation to use: packed/non-packed, jit/cblas,
     * and whether to merge gemm across iterations */
    bool is_int8 = rnn.dt_conf != all_f32;
    rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd)
            || is_int8;
    bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru,
            alg_kind::gru_linear_before_reset);
    rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8;
    bool is_inference = !rnn.is_training;

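    /* Heuristic: use the jit gemm on machines without avx512_mic for
     * multi-layer or small-batch inference, and for training with small dic. */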
    rnn.use_jit_gemm = !mayiuse(avx512_mic)
            && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100))
                || (rnn.is_training && rnn.dic < 500));

    /* Decide to copy bias */
    rnn.copy_bias = rnn.dt_conf != all_f32;

#if USE_MKL_PACKED_GEMM
    rnn.use_layer_packed_gemm
            = (weights_layer_d.format() == any && rnn.slc > 760 && rnn.dic > 760
                      && is_inference)
            || is_int8; // packed gemm is the only supported option for int8
    rnn.use_iter_packed_gemm = (weights_iter_d.format() == any && rnn.sic > 760
                                       && rnn.dic > 760 && is_inference)
            || is_int8;
#else
    rnn.use_layer_packed_gemm = false;
    rnn.use_iter_packed_gemm = false;
#endif

    /* Set packed gemm sizes */
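    /* For each weights part, query the size of the packed A matrix: m/k are
     * the weights dimensions (swapped between fwd and bwd) and n is the
     * batch, times the number of iterations when gemm is merged. */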
    if (rnn.use_layer_packed_gemm) {
        rnn.weights_layer_pack_size = 0;
        for (int p = 0; p < rnn.n_parts_weights_layer; p++) {
            int m_p = rnn.is_fwd
                ? (rnn.parts_weights_layer[p] * rnn.dic)
                : rnn.slc;
            int k_p = rnn.is_fwd
                ? rnn.slc
                : (rnn.parts_weights_layer[p] * rnn.dic);
            int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb;

#if USE_MKL_PACKED_GEMM
            if (rnn.dt_conf == all_f32)
                rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size(
                        CblasAMatrix, m_p, n_p, k_p);
            else
                rnn.part_weights_layer_pack_size[p]
                        = cblas_gemm_s8u8s32_pack_get_size(
                                CblasAMatrix, m_p, n_p, k_p);
#else
            UNUSED(m_p);
            UNUSED(k_p);
            UNUSED(n_p);
            rnn.part_weights_layer_pack_size[p] = 0;
#endif
            rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir
                    * rnn.part_weights_layer_pack_size[p];
        }
        rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size;
        rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer
                        * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float);
    }

    if (rnn.use_iter_packed_gemm) {
        rnn.weights_iter_pack_size = 0;
        for (int p = 0; p < rnn.n_parts_weights_iter; p++) {
            int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) :
                                   rnn.sic;
            int k_p = rnn.is_fwd ? rnn.sic :
                                   (rnn.parts_weights_iter[p] * rnn.dic);
            int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb;

#if USE_MKL_PACKED_GEMM
            if (rnn.dt_conf == all_f32)
                rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size(
                        CblasAMatrix, m_p, n_p, k_p);
            else
                rnn.part_weights_iter_pack_size[p]
                        = cblas_gemm_s8u8s32_pack_get_size(
                                CblasAMatrix, m_p, n_p, k_p);
#else
            UNUSED(m_p);
            UNUSED(k_p);
            UNUSED(n_p);
            rnn.part_weights_iter_pack_size[p] = 0;
#endif
            rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir
                    * rnn.part_weights_iter_pack_size[p];
        }
        rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size;
        rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer
                        * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float);
    }

}

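/* Complete the configuration with everything that depends on the actual
 * weights formats: leading dimensions and workspace/scratchpad sizes. */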
void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &diff_weights_layer_d,
        const memory_desc_wrapper &diff_weights_iter_d) {

    /* Set leading dimensions for input weights arrays depending on input
     * format */
    rnn.weights_layer_fmt = weights_layer_d.format();
    rnn.weights_iter_fmt = weights_iter_d.format();
    rnn.weights_layer_is_packed = rnn.weights_layer_fmt == rnn_packed;
    rnn.weights_iter_is_packed = rnn.weights_iter_fmt == rnn_packed;

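    /* Extract the leading dimension and the number of rows of the weights
     * matrices as seen by gemm: slc/sic rows for ldigo, n_gates * dic rows
     * for ldgoi. */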
    auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) {
        switch (md.format()) {
        case ldigo:
            ld = (int)md.blocking_desc().strides[0][2];
            nld = md.dims()[2];
            return;
        case ldgoi:
            ld = (int)md.blocking_desc().strides[0][4];
            nld = md.dims()[3] * md.dims()[4];
            return;
        default: ld = 0; nld = 0;
        }
    };
    set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld);
    set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld);
    if (!rnn.is_fwd) {
        set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld,
                rnn.diff_weights_layer_nld);
        set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld,
                rnn.diff_weights_iter_nld);
    }

    int sizeof_states_dt
            = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t);
    rnn.states_ws_ld
            = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)),
                sizeof_states_dt);
    rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float));

    /* Set workspace sizes to store:
     * states to compute a pass
     * diff states to compute bwd pass (training only)
     * intermediate results from the gates
     */
    rnn.use_workspace = rnn.is_training;
    rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
            * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt;
    bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm;
    rnn.ws_c_states_size = is_lstm
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.states_ws_ld * sizeof(float)
            : 0;
    rnn.ws_diff_states_size = rnn.is_training
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1)
                    * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld
                    * sizeof(float)
            : (size_t)0;
    rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb
            * rnn.gates_ws_ld * sizeof(float);

    /* set other sizes */
    rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float);
    rnn.ws_cell_comp_size
            = rnn.is_lbr || rnn.dt_conf != all_f32
                ? (size_t)rnn.gates_nld * rnn.gates_ws_ld * sizeof(float)
                : 0;
    rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer
            * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float);
    rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic
            * sizeof(float);
}

int rnn_utils::get_good_ld(int dim, int sizeof_dt) {
    // we want the matrices' leading dimensions to be 64-byte aligned,
    // and not divisible by 256 to avoid 4K aliasing effects
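    // e.g. for f32 (sizeof_dt = 4): dim = 500 rounds up to 512, which is
    // divisible by 256, so 528 is returned instead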
    int ld = rnd_up(dim, 64 / sizeof_dt);
    return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;
}

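/* Lay out the workspace/scratchpad as consecutive page-aligned sections
 * (gates, states, c_states, diff_states, grid and cell computation buffers,
 * plus an optional bias copy) and report the resulting total sizes. */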
void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
        size_t &ws_states_offset, size_t &ws_c_states_offset,
        size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
        size_t &ws_cell_comp_offset, size_t &ws_bias_offset,
        size_t &scratchpad_size, size_t &workspace_size) {

    const size_t page_size = 4096; // 2097152;
    size_t current_offset;
    /* Mandatory workspaces: go to workspace if use_workspace, scratchpad
     * otherwise */
    current_offset = 0; // assumes the workspace base pointer is page aligned
    ws_gates_offset = current_offset;
    current_offset += rnn.ws_gates_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_states_offset = current_offset;
    current_offset += rnn.ws_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_c_states_offset = current_offset;
    current_offset += rnn.ws_c_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_diff_states_offset = current_offset;
    current_offset += rnn.ws_diff_states_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_grid_comp_offset = current_offset;
    current_offset += rnn.ws_grid_comp_size;

    current_offset = utils::rnd_up(current_offset, page_size);
    ws_cell_comp_offset = current_offset;
    current_offset += rnn.ws_cell_comp_size;

    workspace_size = rnn.use_workspace ? current_offset : 0;

    /* Optional scratchpads */
    // Assumes the scratchpad base pointer is page aligned.
    // If use_workspace, only the following goes to the scratchpad;
    // otherwise everything goes there and the offset keeps incrementing.
    current_offset = rnn.use_workspace ? 0 : current_offset;

    if (rnn.copy_bias) {
        current_offset = utils::rnd_up(current_offset, page_size);
        ws_bias_offset = current_offset;
        current_offset += rnn.ws_bias_size;
    }

    scratchpad_size = current_offset;
}

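/* Convenience wrapper that computes only the total sizes and discards the
 * individual section offsets. */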
void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn,
        size_t &scratchpad_size, size_t &workspace_size) {
    size_t ws_gates_offset, ws_states_offset, ws_c_states_offset,
            ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset,
            ws_bias_offset;
    set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_c_states_offset,
            ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset,
            ws_bias_offset, scratchpad_size, workspace_size);
}

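/* Pad the leading dimension stride of plain ldigo/ldgoi weights (see
 * get_good_ld) and recompute the outer strides accordingly. */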
status_t rnn_utils::set_good_strides(memory_desc_t &weights_md) {
    auto &strides = weights_md.layout_desc.blocking.strides[0];
    auto dims = weights_md.dims;

    if (weights_md.format == ldigo) {
        strides[2] = rnn_utils::get_good_ld((int)strides[2],
                (int)types::data_type_size(weights_md.data_type));
        strides[1] = dims[2] * strides[2];
        strides[0] = dims[1] * strides[1];
    } else if (weights_md.format == ldgoi) {
        strides[4] = rnn_utils::get_good_ld((int)strides[4],
                (int)types::data_type_size(weights_md.data_type));
        strides[3] = dims[4] * strides[4];
        strides[1] = dims[3] * strides[3];
        strides[0] = dims[1] * strides[1];
    } else
        return unimplemented;

    return success;
}

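/* Set the weights memory format the implementation expects: the packed rnn
 * format when packed gemm is used, otherwise plain ldigo (fwd) or ldgoi
 * (bwd) with padded leading dimensions. */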
status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
        memory_desc_t &weights_md, bool is_iter) {
    bool use_packed_gemm = is_iter
        ? rnn.use_iter_packed_gemm
        : rnn.use_layer_packed_gemm;
    if (use_packed_gemm) {
        weights_md.format = rnn_packed;
        rnn_packed_data_t &rnn_pdata = weights_md.layout_desc.rnn_packed_desc;
        rnn_pdata.format = rnn.is_fwd ? mkldnn_ldigo_p : mkldnn_ldgoi_p;
        if (is_iter) {
            rnn_pdata.n = rnn.mb;
            rnn_pdata.n_parts = rnn.n_parts_weights_iter;
            array_copy(rnn_pdata.parts, rnn.parts_weights_iter,
                    MKLDNN_RNN_MAX_N_PARTS);
            array_copy(rnn_pdata.part_pack_size,
                    rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS);
            rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset;
            rnn_pdata.size = rnn.weights_iter_pack_size;
        } else {
            rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb;
            rnn_pdata.n_parts = rnn.n_parts_weights_layer;
            array_copy(rnn_pdata.parts, rnn.parts_weights_layer,
                    MKLDNN_RNN_MAX_N_PARTS);
            array_copy(rnn_pdata.part_pack_size,
                    rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS);
            rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset;
            rnn_pdata.size = rnn.weights_layer_pack_size;
        }
    } else {
        weights_md.format = rnn.is_fwd ? ldigo : ldgoi;
        CHECK(memory_desc_wrapper::compute_blocking(weights_md));
        // Adjust strides for good leading dimension in GEMM
        CHECK(set_good_strides(weights_md));
    }
    return success;
}

}
}
}