1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_RNN_REORDERS_HPP
18 #define CPU_RNN_REORDERS_HPP
22 #include "type_helpers.hpp"
23 #include "mkldnn_thread.hpp"
25 #include "simple_q10n.hpp"
26 #include "cpu_reorder_pd.hpp"
27 #include "../gemm/os_blas.hpp"
// NOTE(review): this chunk is a lossy excerpt — the embedded original line
// numbers jump (e.g. 50 -> 52, 62 -> 67), so several lines (the
// "bool args_ok = true" declaration, closing braces, etc.) are not visible
// here.  Comments below describe only the code that is visible.
//
// Reorder primitive for RNN *data* tensors (activations/states): applies
// the elementwise affine quantization out = q(in * scale + shift), where
// scale/shift come from attr()->rnn_data_qparams_ and q is the a=1,b=0
// quantizer (qz_a1b0) honoring the attr round mode.  Input and output
// must have the same plain format (tnc or ldsnc) and matching template
// data types.
33 template <data_type_t type_i, data_type_t type_o>
34 struct rnn_data_reorder_t : public cpu_primitive_t {
35 struct pd_t : public cpu_reorder_pd_t {
36 pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
37 const primitive_attr_t *attr)
38 : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
40 DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t);
// Factory: validates the reorder request and, on success, allocates and
// initializes a pd_t into *reorder_pd.  Returns invalid_arguments when
// the formats/data types don't match the constraints below.
42 static status_t create(reorder_pd_t **reorder_pd,
43 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
44 const primitive_attr_t *attr) {
45 using namespace memory_format;
46 using namespace data_type;
47 assert(input_pd->engine()->kind() == engine_kind::cpu);
48 assert(output_pd->engine()->kind() == engine_kind::cpu);
50 const memory_desc_wrapper id(input_pd), od(output_pd);
// (the "bool args_ok = true" head of this conjunction is in a line not
// visible in this excerpt)
52 && id.data_type() == type_i
53 && od.data_type() == type_o
54 && utils::one_of(id.format(), tnc, ldsnc)
55 && od.format() == id.format();
56 if (!args_ok) return status::invalid_arguments;
58 auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
59 (const cpu_memory_pd_t *)output_pd, attr);
60 if (_pd == nullptr) return out_of_memory;
61 if (_pd->init() != success) { delete _pd; return unimplemented; }
62 return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
67 typedef typename prec_traits<type_i>::type in_data_t;
68 typedef typename prec_traits<type_o>::type out_data_t;
70 rnn_data_reorder_t(const pd_t *apd, const input_vector &inputs,
71 const output_vector &outputs)
72 : cpu_primitive_t(apd, inputs, outputs) {}
// Execute: quantize every element in parallel.  off_l() maps the logical
// (dense) index i to the physical offset in each tensor, so this works
// for any matching src/dst layout pair accepted by create().
74 virtual void execute(event_t *e) const {
75 auto input = reinterpret_cast<const in_data_t *>(input_memory(0));
76 auto output = reinterpret_cast<out_data_t *>(memory());
77 const memory_desc_wrapper &input_d = pd()->input_pd();
78 const memory_desc_wrapper &output_d = pd()->output_pd();
79 const round_mode_t rmode = pd()->attr()->round_mode_;
80 const size_t nelems = input_d.nelems();
81 const float scale = pd()->attr()->rnn_data_qparams_.scale_;
82 const float shift = pd()->attr()->rnn_data_qparams_.shift_;
84 parallel_nd(nelems, [&](size_t i) {
85 float in = (float)input[input_d.off_l(i)] * scale + shift;
86 output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in, rmode);
89 e->set_state(event_t::ready);
92 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Reorder primitive for RNN *weights* with int8 quantization: quantizes
// f32 weights to int8 (per-tensor mask==0 or per-channel mask==3 scales
// from attr()->rnn_weights_qparams_), computes a compensation term (the
// per-output-channel sum of quantized values, stored as float at
// rnn_packed_desc().offset_compensation inside the output buffer), and
// finally packs the quantized matrices with MKL's
// cblas_gemm_s8u8s32_pack.  Only available when USE_MKL_PACKED_GEMM.
//
// NOTE(review): lossy excerpt — original line numbers jump (107 -> 108 ->
// 110 skips the #else, 121 -> 123, 155 -> 160, 216 -> 217, 278 -> 283,
// etc.), so #endif lines, closing braces, the "bool args_ok" head, the
// k_p definition used at the pack call, and the rmode argument of one
// qz_b0 call are not visible here.
95 template <data_type_t type_i, data_type_t type_o>
96 struct rnn_weights_reorder_t : public cpu_primitive_t {
97 struct pd_t : public cpu_reorder_pd_t {
98 pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
99 const primitive_attr_t *attr)
100 : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
102 DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
// Factory: requires plain ldigo/ldgoi input, rnn_packed output with a
// single part, and a weights-scale mask of 0 (common) or 3 (per gate
// and output channel); anything else is rejected.
104 static status_t create(reorder_pd_t **reorder_pd,
105 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
106 const primitive_attr_t *attr) {
// Without MKL's packed GEMM API there is nothing to pack into.
107 #if !USE_MKL_PACKED_GEMM
108 return status::unimplemented;
110 using namespace memory_format;
111 assert(input_pd->engine()->kind() == engine_kind::cpu);
112 assert(output_pd->engine()->kind() == engine_kind::cpu);
113 const memory_desc_wrapper output_d(output_pd);
115 const memory_desc_wrapper id(input_pd), od(output_pd);
// ("bool args_ok = true" head not visible in this excerpt)
117 && id.data_type() == type_i
118 && od.data_type() == type_o
119 && utils::one_of(id.format(), ldigo, ldgoi)
120 && od.format() == rnn_packed
121 && od.rnn_packed_desc().format
123 && od.rnn_packed_desc().n_parts == 1
125 if (!args_ok) return status::invalid_arguments;
127 const int mask = attr->rnn_weights_qparams_.mask_;
128 if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
130 auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
131 (const cpu_memory_pd_t *)output_pd, attr);
132 if (_pd == nullptr) return out_of_memory;
133 if (_pd->init() != success) { delete _pd; return unimplemented; }
134 return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
137 virtual status_t init() override {
138 status_t status = cpu_reorder_pd_t::init();
139 if (status != status::success) return status;
143 return status::success;
// Book two scratch buffers: one int8 copy of the whole weights tensor
// (quantization_size) and, for the ldigo path only, a per-thread int32
// partial-sum buffer used to reduce the compensation across the I
// dimension (dims are [L][D][I][G][O], hence dims[0]*dims[1]*dims[3]*
// dims[4] per thread).
147 void init_scratchpad() {
148 const memory_desc_wrapper id(input_pd());
149 const size_t nelems = id.nelems();
150 const auto &dims = id.dims();
152 using namespace memory_tracking::names;
153 auto scratchpad = scratchpad_registry().registrar();
154 size_t quantization_size = sizeof(int8_t) * nelems;
155 size_t reduction_size = id.format() == ldigo
156 ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0]
157 * dims[1] * dims[3] * dims[4]
160 key_reorder_rnn_weights_quantization, quantization_size);
161 scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size);
166 typedef typename prec_traits<type_i>::type in_data_t;
167 typedef typename prec_traits<type_o>::type out_data_t;
169 rnn_weights_reorder_t(const pd_t *apd, const input_vector &inputs,
170 const output_vector &outputs)
171 : cpu_primitive_t(apd, inputs, outputs) {}
173 virtual void execute(event_t *e) const {
174 #if USE_MKL_PACKED_GEMM
175 auto input = reinterpret_cast<const in_data_t *>(input_memory(0));
176 auto output = reinterpret_cast<char *>(memory());
177 const memory_desc_wrapper &input_d = pd()->input_pd();
178 const memory_desc_wrapper &output_d = pd()->output_pd();
179 const auto &dims = input_d.dims();
// Weight dims: L layers, D directions, I input channels, G gates,
// O output channels (from the 5D ldigo/ldgoi descriptor).
181 const int L = dims[0];
182 const int D = dims[1];
183 const int I = dims[2];
184 const int G = dims[3];
185 const int O = dims[4];
187 const bool is_igo = input_d.format() == memory_format::ldigo;
189 /* Quantize input & compute compensation */
190 auto quantized = (int8_t * __restrict)scratchpad().template get<void>(
191 memory_tracking::names::key_reorder_rnn_weights_quantization);
192 auto reduction = (int32_t * __restrict)scratchpad().template get<void>(
193 memory_tracking::names::key_reorder_rnn_weights_reduction);
// Compensation lives inside the packed output blob, after the packed
// matrices, at the offset recorded in the packed descriptor.
194 float *comp = reinterpret_cast<float *>(
195 output + output_d.rnn_packed_desc().offset_compensation);
196 const round_mode_t rmode = pd()->attr()->round_mode_;
197 const float *scales = pd()->attr()->rnn_weights_qparams_.scales_;
198 const int mask = pd()->attr()->rnn_weights_qparams_.mask_;
// ldigo path: 2D parallelization over (L*D) x I; each thread quantizes
// its slice and accumulates per-(ld, go) partial sums into its own
// region of `reduction`, which is then reduced into `comp` below.
201 int nthr = mkldnn_get_max_threads();
202 int LD_nthr = nstl::min(L * D, nthr);
203 int I_nthr = nstl::min(I, nthr / LD_nthr);
204 parallel(nthr, [&](const int ithr, const int nthr) {
205 int LD_ithr = -1, LD_s = -1, LD_e = -1;
206 int I_ithr = -1, I_s = -1, I_e = -1;
207 if (ithr < LD_nthr * I_nthr) {
208 LD_ithr = ithr % LD_nthr;
209 I_ithr = ithr / LD_nthr;
210 balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e);
211 balance211(I, I_nthr, I_ithr, I_s, I_e);
213 int32_t *comp_ithr = reduction + I_ithr * L * D * G * O;
214 for (int ld = LD_s; ld < LD_e; ld++) {
215 for (int go = 0; go < G * O; go++)
216 comp_ithr[ld * G * O + go] = 0;
217 for (int i = I_s; i < I_e; i++) {
219 for (int go = 0; go < G * O; go++) {
// mask == 0: single per-tensor scale; otherwise one scale per
// (gate, output-channel) pair.
220 const float s = scales[(mask == 0) ? 0 : go];
221 int8_t q = qz_b0<in_data_t, out_data_t>()(
222 input[ld * I * G * O + i * G * O + go], s,
224 quantized[ld * I * G * O + i * G * O + go]
226 comp_ithr[ld * G * O + go] += (int32_t)q;
// Tree-less reduction: seed comp with thread 0's partial sums, then
// add the remaining I_nthr-1 per-thread buffers.
231 parallel_nd(L * D * G * O,
232 [&](int s) { comp[s] = saturate<float>(reduction[s]); });
233 for (int i = 1; i < I_nthr; i++) {
234 parallel_nd(L * D * G * O, [&](int s) {
235 comp[s] += saturate<float>(
236 reduction[i * L * D * G * O + s]);
// ldgoi path (presumably the #else of the format branch — the
// delimiter lines are not visible here): I is innermost, so each
// (ld, go) pair can quantize and accumulate its own row directly.
240 parallel_nd(L * D, G * O, [&](int ld, int go) {
241 int32_t compensation = 0;
242 const float s = scales[(mask == 0) ? 0 : go];
244 for (int i = 0; i < I; i++) {
245 int8_t q = qz_b0<in_data_t, out_data_t>()(
246 input[ld * G * O * I + go * I + i], s, rmode);
247 compensation += (int32_t)q;
248 quantized[ld * G * O * I + go * I + i] = q;
250 comp[ld * G * O + go] = saturate<float>(compensation);
// Pack the quantized weights, one (layer, direction, part) cell at a
// time, advancing the destination by each cell's packed size.
255 auto off_igo = [&](int l, int d, int i, int g, int o) {
256 return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
258 auto off_goi = [&](int l, int d, int i, int g, int o) {
259 return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
261 int n_parts = output_d.rnn_packed_desc().n_parts;
262 const size_t *size_packed_cell
263 = output_d.rnn_packed_desc().part_pack_size;
264 const int *parts = output_d.rnn_packed_desc().parts;
265 const int n = output_d.rnn_packed_desc().n;
266 char *to_pack = output;
267 for (int l = 0; l < L; l++) {
268 for (int d = 0; d < D; d++) {
269 for (int p = 0; p < n_parts; p++) {
270 int g = (p > 0) ? parts[p - 1] : 0;
271 int m_p = parts[p] * O;
// (k_p is defined on a line not visible in this excerpt)
273 cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix,
274 is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p,
275 &quantized[is_igo ? off_igo(l, d, 0, g, 0) :
276 off_goi(l, d, g, 0, 0)],
277 is_igo ? G * O : I, to_pack);
278 to_pack += size_packed_cell[p];
283 e->set_state(event_t::ready);
286 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// f32 -> f32 specialization of the RNN weights reorder (the `template <>`
// line precedes this one but is not visible in this excerpt): no
// quantization and no compensation — it only repacks plain ldigo/ldgoi
// weights into MKL's packed-GEMM layout via cblas_sgemm_pack,
// transposing when the source format and the requested packed format
// disagree (ldigo -> ldgoi_p or ldgoi -> ldigo_p).
290 struct rnn_weights_reorder_t<data_type::f32, data_type::f32>
291 : public cpu_primitive_t {
292 struct pd_t : public cpu_reorder_pd_t {
293 pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
294 const primitive_attr_t *attr)
295 : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
297 DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
// Factory: f32 in/out, plain ldigo/ldgoi source, rnn_packed destination
// in either packed format, default attributes (no scales/post-ops).
299 static status_t create(reorder_pd_t **reorder_pd,
300 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
301 const primitive_attr_t *attr) {
302 #if !USE_MKL_PACKED_GEMM
303 return status::unimplemented;
305 using namespace memory_format;
306 using namespace data_type;
307 assert(input_pd->engine()->kind() == engine_kind::cpu);
308 assert(output_pd->engine()->kind() == engine_kind::cpu);
309 const memory_desc_wrapper output_d(output_pd);
311 const memory_desc_wrapper id(input_pd), od(output_pd);
// ("bool args_ok = true" head not visible in this excerpt)
313 && id.data_type() == f32
314 && od.data_type() == f32
315 && utils::one_of(id.format(), ldigo, ldgoi)
316 && od.format() == rnn_packed
317 && utils::one_of(od.rnn_packed_desc().format,
318 mkldnn_ldigo_p, mkldnn_ldgoi_p)
319 && attr->has_default_values();
320 if (!args_ok) return status::invalid_arguments;
// NOTE(review): checking the weights-qparams mask here looks redundant
// given has_default_values() above — presumably kept for symmetry with
// the quantizing specialization; confirm against upstream history.
322 const int mask = attr->rnn_weights_qparams_.mask_;
323 if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
325 auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
326 (const cpu_memory_pd_t *)output_pd, attr);
327 if (_pd == nullptr) return out_of_memory;
328 if (_pd->init() != success) { delete _pd; return unimplemented; }
329 return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
334 rnn_weights_reorder_t(const pd_t *apd, const input_vector &inputs,
335 const output_vector &outputs)
336 : cpu_primitive_t(apd, inputs, outputs) {}
338 virtual void execute(event_t *e) const {
339 #if USE_MKL_PACKED_GEMM
340 auto input = reinterpret_cast<const float *>(input_memory(0));
341 auto output = reinterpret_cast<float *>(memory());
342 const memory_desc_wrapper &input_d = pd()->input_pd();
343 const memory_desc_wrapper &output_d = pd()->output_pd();
344 const auto &dims = input_d.dims();
345 const rnn_packed_data_t &rnn_pdata = output_d.rnn_packed_desc();
// Weight dims: L layers, D directions, I input channels, G gates,
// O output channels.
346 const int L = dims[0];
347 const int D = dims[1];
348 const int I = dims[2];
349 const int G = dims[3];
350 const int O = dims[4];
// "Cross" case: source layout and packed layout disagree, so the pack
// routine must transpose each cell while packing.
353 bool cross_case = (input_d.format() == memory_format::ldigo
354 && rnn_pdata.format == mkldnn_ldgoi_p)
355 || (input_d.format() == memory_format::ldgoi
356 && rnn_pdata.format == mkldnn_ldigo_p);
357 auto trans = cross_case ? CblasTrans : CblasNoTrans;
358 int n_parts = rnn_pdata.n_parts;
359 const size_t *size_packed_cell = rnn_pdata.part_pack_size;
360 const int *parts = rnn_pdata.parts;
361 const int n = rnn_pdata.n;
363 const bool is_igo = input_d.format() == memory_format::ldigo;
364 auto off_igo = [&](int l, int d, int i, int g, int o) {
365 return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
367 auto off_goi = [&](int l, int d, int i, int g, int o) {
368 return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
// Pack each (layer, direction, part) cell; m/k and the source leading
// dimension swap roles depending on the source layout.  `output`
// advances by the packed byte size of each cell (converted to floats).
370 for (int l = 0; l < L; l++) {
371 for (int d = 0; d < D; d++) {
372 for (int p = 0; p < n_parts; p++) {
373 int g = (p > 0) ? parts[p - 1] : 0;
374 int m_p = is_igo ? parts[p] * O : I;
375 int k_p = is_igo ? I : parts[p] * O;
376 int ld = is_igo ? G * O : I;
377 cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n,
378 k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) :
379 off_goi(l, d, 0, g, 0)],
// (the trailing ld/output arguments of this call are on a line not
// visible in this excerpt)
381 output += size_packed_cell[p] / sizeof(float);
385 e->set_state(event_t::ready);
389 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
394 } // namespace mkldnn