Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / rnn.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn.h"
18
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22 #include "cpu/gemm/os_blas.hpp"
23
24 using namespace mkldnn::impl;
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::types;
27 using namespace mkldnn::impl::utils;
28
29 namespace {
30 memory_desc_t copy_maybe_null(const memory_desc_t *md) {
31     return md ? *md : zero_md();
32 }
33
34 rnn_desc_t zero_rnn_desc() {
35     auto rd = rnn_desc_t();
36     rd.src_layer_desc = zero_md();
37     rd.src_iter_desc = zero_md();
38     rd.weights_layer_desc = zero_md();
39     rd.weights_iter_desc = zero_md();
40     rd.bias_desc = zero_md();
41     rd.dst_layer_desc = zero_md();
42     rd.dst_iter_desc = zero_md();
43     rd.diff_src_layer_desc = zero_md();
44     rd.diff_src_iter_desc = zero_md();
45     rd.diff_weights_layer_desc = zero_md();
46     rd.diff_weights_iter_desc = zero_md();
47     rd.diff_bias_desc = zero_md();
48     rd.diff_dst_layer_desc = zero_md();
49     rd.diff_dst_iter_desc = zero_md();
50     return rd;
51 }
52 }
53
54 /* Public C Api */
55
56 status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc,
57         mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f,
58         unsigned int flags, float alpha, float clipping) {
59     using namespace mkldnn::impl::alg_kind;
60
61     bool args_ok = true
62             && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru,
63                     gru_linear_before_reset)
64             && IMPLICATION(cell_kind == vanilla_rnn,
65                     one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic));
66     if (!args_ok)
67         return invalid_arguments;
68
69     auto rcd = mkldnn_rnn_cell_desc_t();
70
71     rcd.cell_kind = cell_kind;
72     rcd.activation_kind = act_f;
73     rcd.flags = flags;
74     rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0;
75     rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0;
76
77     *rnn_cell_desc = rcd;
78
79     return success;
80 }
81
82 int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) {
83     switch (rnn_cell_desc->cell_kind) {
84     case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
85     case mkldnn::impl::alg_kind::vanilla_gru: return 3;
86     case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3;
87     case mkldnn::impl::alg_kind::vanilla_lstm: return 4;
88     default: assert(!"unknown cell kind"); return 0;
89     }
90     return 0;
91 }
92
93 int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) {
94     switch (rnn_cell_desc->cell_kind) {
95     case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
96     case mkldnn::impl::alg_kind::vanilla_gru: return 1;
97     case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1;
98     case mkldnn::impl::alg_kind::vanilla_lstm: return 2;
99     default: assert(!"unknown cell kind"); return 0;
100     }
101     return 0;
102 }
103
104 status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc,
105         prop_kind_t prop_kind, const memory_desc_t *src_layer_desc,
106         const memory_desc_t *src_iter_desc,
107         const memory_desc_t *weights_layer_desc,
108         const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
109         const memory_desc_t *dst_layer_desc,
110         const memory_desc_t *dst_iter_desc) {
111     using namespace data_type;
112     data_type_t src_layer_dt = src_layer_desc->data_type;
113     data_type_t dst_layer_dt = dst_layer_desc->data_type;
114     data_type_t weights_iter_dt = weights_iter_desc->data_type;
115     data_type_t weights_layer_dt = weights_layer_desc->data_type;
116
117     bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt,
118                           weights_layer_dt)
119             && IMPLICATION(!is_zero_md(src_iter_desc),
120                           src_iter_desc->data_type == f32)
121             && IMPLICATION(!is_zero_md(dst_iter_desc),
122                           dst_iter_desc->data_type == f32)
123             && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
124
125 #if USE_MKL_PACKED_GEMM
126     bool is_u8u8u8 = src_layer_dt == u8
127             && IMPLICATION(!is_zero_md(src_iter_desc),
128                              src_iter_desc->data_type == u8)
129             && IMPLICATION(!is_zero_md(dst_iter_desc),
130                              dst_iter_desc->data_type == u8)
131             && one_of(dst_layer_dt, u8, f32)
132             && everyone_is(s8, weights_iter_dt, weights_layer_dt)
133             && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
134
135     bool is_f32u8f32 = src_layer_dt == u8
136             && IMPLICATION(!is_zero_md(src_iter_desc),
137                                src_iter_desc->data_type == f32)
138             && IMPLICATION(!is_zero_md(dst_iter_desc),
139                                dst_iter_desc->data_type == f32)
140             && one_of(dst_layer_dt, u8, f32)
141             && everyone_is(s8, weights_iter_dt, weights_layer_dt)
142             && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
143
144     bool is_inference = prop_kind == prop_kind::forward_inference;
145     bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm;
146
147     return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference))
148             ? success
149             : unimplemented;
150 #else
151     return is_f32 ? success : unimplemented;
152 #endif
153 }
154
155 status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc,
156         rnn_direction_t direction, int L, int D, int T, int N, int S, int G,
157         int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc,
158         const memory_desc_t *src_iter_desc,
159         const memory_desc_t *weights_layer_desc,
160         const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
161         const memory_desc_t *dst_layer_desc,
162         const memory_desc_t *dst_iter_desc) {
163     bool args_ok;
164
165     // * algorithm specific
166     args_ok = true
167         && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru,
168                        DIC == SIC);
169     if (!args_ok) return invalid_arguments;
170     int extra_bias =
171             rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset;
172
173     // * on num layers
174     args_ok = true
175         && L == weights_layer_desc->dims[0]
176         && L == weights_iter_desc->dims[0]
177         && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0])
178         && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0])
179         && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]);
180     if (!args_ok) return invalid_arguments;
181
182     // * on num directions
183     args_ok = true
184         && D == weights_layer_desc->dims[1]
185         && D == weights_iter_desc->dims[1]
186         && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1])
187         && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1])
188         && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]);
189     if (!args_ok) return invalid_arguments;
190
191     // * on num iterations
192     args_ok = true
193         && T == src_layer_desc->dims[0]
194         && T == dst_layer_desc->dims[0];
195     if (!args_ok) return invalid_arguments;
196
197     // * on mb
198     args_ok = true
199         && N == src_layer_desc->dims[1]
200         && N == dst_layer_desc->dims[1]
201         && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3])
202         && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]);
203     if (!args_ok) return invalid_arguments;
204
205     // * on num gates
206     args_ok = true
207         && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc)
208         && G == weights_layer_desc->dims[3]
209         && G == weights_iter_desc->dims[3]
210         && IMPLICATION(!is_zero_md(bias_desc),
211                 G + extra_bias == bias_desc->dims[2]);
212     if (!args_ok) return invalid_arguments;
213
214     // * on num states
215     args_ok = true
216         && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc)
217         && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2])
218         && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]);
219     if (!args_ok) return invalid_arguments;
220
221     // * on slc
222     args_ok = true
223         && SLC == weights_layer_desc->dims[2]
224         && SLC == src_layer_desc->dims[2];
225     if (!args_ok) return invalid_arguments;
226
227     // * on sic
228     args_ok = true
229         && SIC == weights_iter_desc->dims[2]
230         && IMPLICATION(!is_zero_md(src_iter_desc),
231                 SIC == src_iter_desc->dims[4]);
232     if (!args_ok) return invalid_arguments;
233
234     // * on dlc
235     int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1;
236     args_ok = true
237         && DLC == dlc_multiplier * DIC
238         && DLC == dst_layer_desc->dims[2];
239     if (!args_ok) return invalid_arguments;
240
241     // * on dic
242     args_ok = true
243         && DIC == weights_layer_desc->dims[4]
244         && DIC == weights_iter_desc->dims[4]
245         && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3])
246         && IMPLICATION(!is_zero_md(dst_iter_desc),
247                 DIC == dst_iter_desc->dims[4]);
248     if (!args_ok) return invalid_arguments;
249
250     // * unrolling/fusion conditions
251     args_ok = true
252         && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC)
253         && IMPLICATION(T > 1, SIC == DIC);
254     if (!args_ok) return invalid_arguments;
255
256     return success;
257 }
258
259 status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
260         prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
261         const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
262         const memory_desc_t *src_iter_desc,
263         const memory_desc_t *weights_layer_desc,
264         const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
265         const memory_desc_t *dst_layer_desc,
266         const memory_desc_t *dst_iter_desc) {
267     bool args_ok = true && rnn_cell_desc != nullptr
268             && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
269                        dst_layer_desc);
270     if (!args_ok) return invalid_arguments;
271
272     //check dimensions consistency
273     int L = weights_layer_desc->dims[0];
274     int T = src_layer_desc->dims[0];
275     int N = src_layer_desc->dims[1];
276     const int D = one_of(direction, mkldnn_unidirectional_left2right,
277                           mkldnn_unidirectional_right2left) ?
278             1 :
279             2;
280     int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
281     int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
282     int SLC = src_layer_desc->dims[2];
283     int SIC = weights_iter_desc->dims[2];
284     int DLC = dst_layer_desc->dims[2];
285     int DIC = weights_layer_desc->dims[4];
286
287     CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
288             G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
289             weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
290             dst_iter_desc));
291
292     CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind,
293             src_layer_desc, src_iter_desc, weights_layer_desc,
294             weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc));
295
296     // Create the descriptor
297     mkldnn_rnn_desc_t rd = zero_rnn_desc();
298
299     rd.primitive_kind = primitive_kind::rnn;
300     rd.prop_kind = prop_kind;
301     rd.cell_desc = *rnn_cell_desc;
302     rd.direction = direction;
303     rd.src_layer_desc = copy_maybe_null(src_layer_desc);
304     rd.src_iter_desc = copy_maybe_null(src_iter_desc);
305     rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
306     rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
307     rd.bias_desc = copy_maybe_null(bias_desc);
308     rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
309     rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
310
311     *rnn_desc = rd;
312
313     return success;
314 }
315
316 status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
317         prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
318         const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
319         const memory_desc_t *src_iter_desc,
320         const memory_desc_t *weights_layer_desc,
321         const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
322         const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
323         const memory_desc_t *diff_src_layer_desc,
324         const memory_desc_t *diff_src_iter_desc,
325         const memory_desc_t *diff_weights_layer_desc,
326         const memory_desc_t *diff_weights_iter_desc,
327         const memory_desc_t *diff_bias_desc,
328         const memory_desc_t *diff_dst_layer_desc,
329         const memory_desc_t *diff_dst_iter_desc) {
330     bool args_ok = true
331             && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
332                        dst_layer_desc, diff_src_layer_desc,
333                        diff_weights_layer_desc, diff_weights_iter_desc,
334                        diff_dst_layer_desc);
335     if (!args_ok)
336         return invalid_arguments;
337
338     auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) {
339         return is_zero_md(a_md) == is_zero_md(b_md);
340     };
341
342     args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
343             && xnor_md(dst_iter_desc, diff_dst_iter_desc)
344             && xnor_md(src_iter_desc, diff_src_iter_desc);
345     if (!args_ok)
346         return invalid_arguments;
347
348     //check dimensions consistency
349     int L = weights_layer_desc->dims[0];
350     int T = src_layer_desc->dims[0];
351     int N = src_layer_desc->dims[1];
352     const int D = one_of(direction, mkldnn_unidirectional_left2right,
353                           mkldnn_unidirectional_right2left) ?
354             1 :
355             2;
356     int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
357     int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
358     int SLC = src_layer_desc->dims[2];
359     int SIC = weights_iter_desc->dims[2];
360     int DLC = dst_layer_desc->dims[2];
361     int DIC = weights_layer_desc->dims[4];
362
363     status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
364             G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
365             weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
366             dst_iter_desc);
367     if (st != success) return st;
368
369     st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
370             G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc,
371             diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc,
372             diff_dst_layer_desc, diff_dst_iter_desc);
373     if (st != success) return st;
374
375     mkldnn_rnn_desc_t rd = zero_rnn_desc();
376
377     rd.primitive_kind = primitive_kind::rnn;
378     rd.prop_kind = prop_kind;
379     rd.cell_desc = *rnn_cell_desc;
380     rd.direction = direction;
381
382     rd.src_layer_desc = copy_maybe_null(src_layer_desc);
383     rd.src_iter_desc = copy_maybe_null(src_iter_desc);
384     rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
385     rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
386     rd.bias_desc = copy_maybe_null(bias_desc);
387     rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
388     rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
389     rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc);
390     rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc);
391     rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc);
392     rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc);
393     rd.diff_bias_desc = copy_maybe_null(diff_bias_desc);
394     rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc);
395     rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc);
396
397     *rnn_desc = rd;
398
399     return success;
400 }