Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / rnn / rnn_reorders.hpp
1 /*******************************************************************************
2  * Copyright 2018 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #ifndef CPU_RNN_REORDERS_HPP
18 #define CPU_RNN_REORDERS_HPP
19
20 #include <assert.h>
21
22 #include "type_helpers.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "utils.hpp"
25 #include "simple_q10n.hpp"
26 #include "cpu_reorder_pd.hpp"
27 #include "../gemm/os_blas.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 template <data_type_t type_i, data_type_t type_o>
34 struct rnn_data_reorder_t : public cpu_primitive_t {
35     struct pd_t : public cpu_reorder_pd_t {
36         pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
37                 const primitive_attr_t *attr)
38             : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
39
40         DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t);
41
42         static status_t create(reorder_pd_t **reorder_pd,
43                 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
44                 const primitive_attr_t *attr) {
45             using namespace memory_format;
46             using namespace data_type;
47             assert(input_pd->engine()->kind() == engine_kind::cpu);
48             assert(output_pd->engine()->kind() == engine_kind::cpu);
49
50             const memory_desc_wrapper id(input_pd), od(output_pd);
51             bool args_ok = true
52                     && id.data_type() == type_i
53                     && od.data_type() == type_o
54                     && utils::one_of(id.format(), tnc, ldsnc)
55                     && od.format() == id.format();
56             if (!args_ok) return status::invalid_arguments;
57
58             auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
59                     (const cpu_memory_pd_t *)output_pd, attr);
60             if (_pd == nullptr) return out_of_memory;
61             if (_pd->init() != success) { delete _pd; return unimplemented; }
62             return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
63         }
64     };
65
66 private:
67     typedef typename prec_traits<type_i>::type in_data_t;
68     typedef typename prec_traits<type_o>::type out_data_t;
69
70     rnn_data_reorder_t(const pd_t *apd, const input_vector &inputs,
71             const output_vector &outputs)
72         : cpu_primitive_t(apd, inputs, outputs) {}
73
74     virtual void execute(event_t *e) const {
75         auto input = reinterpret_cast<const in_data_t *>(input_memory(0));
76         auto output = reinterpret_cast<out_data_t *>(memory());
77         const memory_desc_wrapper &input_d = pd()->input_pd();
78         const memory_desc_wrapper &output_d = pd()->output_pd();
79         const round_mode_t rmode = pd()->attr()->round_mode_;
80         const size_t nelems = input_d.nelems();
81         const float scale = pd()->attr()->rnn_data_qparams_.scale_;
82         const float shift = pd()->attr()->rnn_data_qparams_.shift_;
83
84         parallel_nd(nelems, [&](size_t i) {
85             float in = (float)input[input_d.off_l(i)] * scale + shift;
86             output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in, rmode);
87         });
88
89         e->set_state(event_t::ready);
90     }
91
92     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
93 };
94
95 template <data_type_t type_i, data_type_t type_o>
96 struct rnn_weights_reorder_t : public cpu_primitive_t {
97     struct pd_t : public cpu_reorder_pd_t {
98         pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
99                 const primitive_attr_t *attr)
100             : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
101
102         DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
103
104         static status_t create(reorder_pd_t **reorder_pd,
105                 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
106                 const primitive_attr_t *attr) {
107 #if !USE_MKL_PACKED_GEMM
108             return status::unimplemented;
109 #endif
110             using namespace memory_format;
111             assert(input_pd->engine()->kind() == engine_kind::cpu);
112             assert(output_pd->engine()->kind() == engine_kind::cpu);
113             const memory_desc_wrapper output_d(output_pd);
114
115             const memory_desc_wrapper id(input_pd), od(output_pd);
116             bool args_ok = true
117                     && id.data_type() == type_i
118                     && od.data_type() == type_o
119                     && utils::one_of(id.format(), ldigo, ldgoi)
120                     && od.format() == rnn_packed
121                     && od.rnn_packed_desc().format
122                             == mkldnn_ldigo_p
123                     && od.rnn_packed_desc().n_parts == 1
124                     && attr != nullptr;
125             if (!args_ok) return status::invalid_arguments;
126
127             const int mask = attr->rnn_weights_qparams_.mask_;
128             if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
129
130             auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
131                     (const cpu_memory_pd_t *)output_pd, attr);
132             if (_pd == nullptr) return out_of_memory;
133             if (_pd->init() != success) { delete _pd; return unimplemented; }
134             return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
135         }
136
137         virtual status_t init() override {
138             status_t status = cpu_reorder_pd_t::init();
139             if (status != status::success) return status;
140
141             init_scratchpad();
142
143             return status::success;
144         }
145
146     private:
147         void init_scratchpad() {
148             const memory_desc_wrapper id(input_pd());
149             const size_t nelems = id.nelems();
150             const auto &dims = id.dims();
151
152             using namespace memory_tracking::names;
153             auto scratchpad = scratchpad_registry().registrar();
154             size_t quantization_size = sizeof(int8_t) * nelems;
155             size_t reduction_size = id.format() == ldigo
156                     ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0]
157                             * dims[1] * dims[3] * dims[4]
158                     : 0;
159             scratchpad.book(
160                     key_reorder_rnn_weights_quantization, quantization_size);
161             scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size);
162         }
163     };
164
165 private:
166     typedef typename prec_traits<type_i>::type in_data_t;
167     typedef typename prec_traits<type_o>::type out_data_t;
168
169     rnn_weights_reorder_t(const pd_t *apd, const input_vector &inputs,
170             const output_vector &outputs)
171         : cpu_primitive_t(apd, inputs, outputs) {}
172
173     virtual void execute(event_t *e) const {
174 #if USE_MKL_PACKED_GEMM
175         auto input = reinterpret_cast<const in_data_t *>(input_memory(0));
176         auto output = reinterpret_cast<char *>(memory());
177         const memory_desc_wrapper &input_d = pd()->input_pd();
178         const memory_desc_wrapper &output_d = pd()->output_pd();
179         const auto &dims = input_d.dims();
180
181         const int L = dims[0];
182         const int D = dims[1];
183         const int I = dims[2];
184         const int G = dims[3];
185         const int O = dims[4];
186
187         const bool is_igo = input_d.format() == memory_format::ldigo;
188
189         /* Quantize input & compute compensation */
190         auto quantized = (int8_t * __restrict)scratchpad().template get<void>(
191                 memory_tracking::names::key_reorder_rnn_weights_quantization);
192         auto reduction = (int32_t * __restrict)scratchpad().template get<void>(
193                 memory_tracking::names::key_reorder_rnn_weights_reduction);
194         float *comp = reinterpret_cast<float *>(
195                 output + output_d.rnn_packed_desc().offset_compensation);
196         const round_mode_t rmode = pd()->attr()->round_mode_;
197         const float *scales = pd()->attr()->rnn_weights_qparams_.scales_;
198         const int mask = pd()->attr()->rnn_weights_qparams_.mask_;
199
200         if (is_igo) {
201             int nthr = mkldnn_get_max_threads();
202             int LD_nthr = nstl::min(L * D, nthr);
203             int I_nthr = nstl::min(I, nthr / LD_nthr);
204             parallel(nthr, [&](const int ithr, const int nthr) {
205                 int LD_ithr = -1, LD_s = -1, LD_e = -1;
206                 int I_ithr = -1, I_s = -1, I_e = -1;
207                 if (ithr < LD_nthr * I_nthr) {
208                     LD_ithr = ithr % LD_nthr;
209                     I_ithr = ithr / LD_nthr;
210                     balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e);
211                     balance211(I, I_nthr, I_ithr, I_s, I_e);
212                 }
213                 int32_t *comp_ithr = reduction + I_ithr * L * D * G * O;
214                 for (int ld = LD_s; ld < LD_e; ld++) {
215                     for (int go = 0; go < G * O; go++)
216                         comp_ithr[ld * G * O + go] = 0;
217                     for (int i = I_s; i < I_e; i++) {
218                         PRAGMA_OMP_SIMD()
219                         for (int go = 0; go < G * O; go++) {
220                             const float s = scales[(mask == 0) ? 0 : go];
221                             int8_t q = qz_b0<in_data_t, out_data_t>()(
222                                     input[ld * I * G * O + i * G * O + go], s,
223                                     rmode);
224                             quantized[ld * I * G * O + i * G * O + go]
225                                     = (int32_t)q;
226                             comp_ithr[ld * G * O + go] += (int32_t)q;
227                         }
228                     }
229                 }
230             });
231             parallel_nd(L * D * G * O,
232                     [&](int s) { comp[s] = saturate<float>(reduction[s]); });
233             for (int i = 1; i < I_nthr; i++) {
234                 parallel_nd(L * D * G * O, [&](int s) {
235                     comp[s] += saturate<float>(
236                             reduction[i * L * D * G * O + s]);
237                 });
238             }
239         } else {
240             parallel_nd(L * D, G * O, [&](int ld, int go) {
241                 int32_t compensation = 0;
242                 const float s = scales[(mask == 0) ? 0 : go];
243                 PRAGMA_OMP_SIMD()
244                 for (int i = 0; i < I; i++) {
245                     int8_t q = qz_b0<in_data_t, out_data_t>()(
246                             input[ld * G * O * I + go * I + i], s, rmode);
247                     compensation += (int32_t)q;
248                     quantized[ld * G * O * I + go * I + i] = q;
249                 }
250                 comp[ld * G * O + go] = saturate<float>(compensation);
251             });
252         }
253
254         /* Pack */
255         auto off_igo = [&](int l, int d, int i, int g, int o) {
256             return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
257         };
258         auto off_goi = [&](int l, int d, int i, int g, int o) {
259             return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
260         };
261         int n_parts = output_d.rnn_packed_desc().n_parts;
262         const size_t *size_packed_cell
263                 = output_d.rnn_packed_desc().part_pack_size;
264         const int *parts = output_d.rnn_packed_desc().parts;
265         const int n = output_d.rnn_packed_desc().n;
266         char *to_pack = output;
267         for (int l = 0; l < L; l++) {
268             for (int d = 0; d < D; d++) {
269                 for (int p = 0; p < n_parts; p++) {
270                     int g = (p > 0) ? parts[p - 1] : 0;
271                     int m_p = parts[p] * O;
272                     int k_p = I;
273                     cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix,
274                             is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p,
275                             &quantized[is_igo ? off_igo(l, d, 0, g, 0) :
276                                                 off_goi(l, d, g, 0, 0)],
277                             is_igo ? G * O : I, to_pack);
278                     to_pack += size_packed_cell[p];
279                 }
280             }
281         }
282 #endif
283         e->set_state(event_t::ready);
284     }
285
286     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
287 };
288
289 template <>
290 struct rnn_weights_reorder_t<data_type::f32, data_type::f32>
291         : public cpu_primitive_t {
292     struct pd_t : public cpu_reorder_pd_t {
293         pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
294                 const primitive_attr_t *attr)
295             : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
296
297         DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
298
299         static status_t create(reorder_pd_t **reorder_pd,
300                 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
301                 const primitive_attr_t *attr) {
302 #if !USE_MKL_PACKED_GEMM
303             return status::unimplemented;
304 #endif
305             using namespace memory_format;
306             using namespace data_type;
307             assert(input_pd->engine()->kind() == engine_kind::cpu);
308             assert(output_pd->engine()->kind() == engine_kind::cpu);
309             const memory_desc_wrapper output_d(output_pd);
310
311             const memory_desc_wrapper id(input_pd), od(output_pd);
312             bool args_ok = true
313                     && id.data_type() == f32
314                     && od.data_type() == f32
315                     && utils::one_of(id.format(), ldigo, ldgoi)
316                     && od.format() == rnn_packed
317                     && utils::one_of(od.rnn_packed_desc().format,
318                         mkldnn_ldigo_p, mkldnn_ldgoi_p)
319                     && attr->has_default_values();
320             if (!args_ok) return status::invalid_arguments;
321
322             const int mask = attr->rnn_weights_qparams_.mask_;
323             if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
324
325             auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
326                     (const cpu_memory_pd_t *)output_pd, attr);
327             if (_pd == nullptr) return out_of_memory;
328             if (_pd->init() != success) { delete _pd; return unimplemented; }
329             return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
330         }
331     };
332
333 private:
334     rnn_weights_reorder_t(const pd_t *apd, const input_vector &inputs,
335             const output_vector &outputs)
336         : cpu_primitive_t(apd, inputs, outputs) {}
337
338     virtual void execute(event_t *e) const {
339 #if USE_MKL_PACKED_GEMM
340         auto input = reinterpret_cast<const float *>(input_memory(0));
341         auto output = reinterpret_cast<float *>(memory());
342         const memory_desc_wrapper &input_d = pd()->input_pd();
343         const memory_desc_wrapper &output_d = pd()->output_pd();
344         const auto &dims = input_d.dims();
345         const rnn_packed_data_t &rnn_pdata = output_d.rnn_packed_desc();
346         const int L = dims[0];
347         const int D = dims[1];
348         const int I = dims[2];
349         const int G = dims[3];
350         const int O = dims[4];
351
352         /* Pack */
353         bool cross_case = (input_d.format() == memory_format::ldigo
354                         && rnn_pdata.format == mkldnn_ldgoi_p)
355                 || (input_d.format() == memory_format::ldgoi
356                         && rnn_pdata.format == mkldnn_ldigo_p);
357         auto trans = cross_case ? CblasTrans : CblasNoTrans;
358         int n_parts = rnn_pdata.n_parts;
359         const size_t *size_packed_cell = rnn_pdata.part_pack_size;
360         const int *parts = rnn_pdata.parts;
361         const int n = rnn_pdata.n;
362
363         const bool is_igo = input_d.format() == memory_format::ldigo;
364         auto off_igo = [&](int l, int d, int i, int g, int o) {
365             return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
366         };
367         auto off_goi = [&](int l, int d, int i, int g, int o) {
368             return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
369         };
370         for (int l = 0; l < L; l++) {
371             for (int d = 0; d < D; d++) {
372                 for (int p = 0; p < n_parts; p++) {
373                     int g = (p > 0) ? parts[p - 1] : 0;
374                     int m_p = is_igo ? parts[p] * O : I;
375                     int k_p = is_igo ? I : parts[p] * O;
376                     int ld = is_igo ? G * O : I;
377                     cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n,
378                             k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) :
379                                                        off_goi(l, d, 0, g, 0)],
380                             ld, output);
381                     output += size_packed_cell[p] / sizeof(float);
382                 }
383             }
384         }
385         e->set_state(event_t::ready);
386 #endif
387     }
388
389     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
390 };
391
392 } // namespace cpu
393 } // namespace impl
394 } // namespace mkldnn
395
396 #endif