Add a section on how to link IE with a CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / rnn / cpu_rnn_pd.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_RNN_PD_HPP
18 #define CPU_RNN_PD_HPP
19
20 #include "c_types_map.hpp"
21 #include "../cpu_engine.hpp"
22 #include "../cpu_memory.hpp"
23 #include "../cpu_primitive.hpp"
24 #include "nstl.hpp"
25 #include "rnn_pd.hpp"
26 #include "type_helpers.hpp"
27 #include "utils.hpp"
28 #include "rnn_utils.hpp"
29
30 namespace mkldnn {
31 namespace impl {
32 namespace cpu {
33
34 struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t {
35     using cpu_memory_pd_t = cpu_memory_t::pd_t;
36
37     cpu_rnn_fwd_pd_t(engine_t *engine, const rnn_desc_t *adesc,
38             const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd)
39         : rnn_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
40         , src_layer_pd_(engine, &desc_.src_layer_desc)
41         , src_iter_pd_(engine, &desc_.src_iter_desc)
42         , weights_layer_pd_(engine, &desc_.weights_layer_desc)
43         , weights_iter_pd_(engine, &desc_.weights_iter_desc)
44         , bias_pd_(engine, &desc_.bias_desc)
45         , dst_layer_pd_(engine, &desc_.dst_layer_desc)
46         , dst_iter_pd_(engine, &desc_.dst_iter_desc)
47         , ws_pd_(engine_) {}
48     virtual ~cpu_rnn_fwd_pd_t() {}
49
50     virtual const cpu_memory_pd_t *src_pd(int index = 0) const override {
51         if (index == 0)
52             return &src_layer_pd_;
53         if (index == 1 && this->with_src_iter())
54             return &src_iter_pd_;
55         return nullptr;
56     }
57     virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
58         if (index == 0)
59             return &weights_layer_pd_;
60         if (index == 1)
61             return &weights_iter_pd_;
62         if (index == 2 && this->with_bias())
63             return &bias_pd_;
64         return nullptr;
65     }
66     virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override {
67         if (index == 0)
68             return &dst_layer_pd_;
69         if (index == 1 && this->with_dst_iter())
70             return &dst_iter_pd_;
71         return nullptr;
72     }
73     virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override {
74         return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr;
75     }
76
77 protected:
78     cpu_memory_pd_t src_layer_pd_;
79     cpu_memory_pd_t src_iter_pd_;
80     cpu_memory_pd_t weights_layer_pd_;
81     cpu_memory_pd_t weights_iter_pd_;
82     cpu_memory_pd_t bias_pd_;
83     cpu_memory_pd_t dst_layer_pd_;
84     cpu_memory_pd_t dst_iter_pd_;
85     cpu_memory_pd_t ws_pd_;
86
87     virtual status_t set_default_params() {
88         using namespace memory_format;
89         if (src_layer_pd_.desc()->format == any)
90             CHECK(src_layer_pd_.set_format(tnc));
91         if (dst_layer_pd_.desc()->format == any)
92             CHECK(dst_layer_pd_.set_format(tnc));
93
94         // Optional parameters
95         if ((!src_iter_pd_.is_zero()) && (src_iter_pd_.desc()->format == any))
96             CHECK(src_iter_pd_.set_format(ldsnc));
97         if ((!bias_pd_.is_zero()) && (bias_pd_.desc()->format == any))
98             CHECK(bias_pd_.set_format(ldgo));
99         if ((!dst_iter_pd_.is_zero()) && (dst_iter_pd_.desc()->format == any))
100             CHECK(dst_iter_pd_.set_format(ldsnc));
101
102         return status::success;
103     }
104
105     status_t check_layout_consistency() {
106         using namespace memory_format;
107         using namespace utils;
108         using namespace data_type;
109         bool ok = true;
110         ok = ok && src_layer_pd_.desc()->format == tnc
111                 && dst_layer_pd_.desc()->format == tnc;
112         ok = ok && IMPLICATION(!src_iter_pd_.is_zero(),
113                            src_iter_pd_.desc()->format == ldsnc)
114                 && IMPLICATION(!dst_iter_pd_.is_zero(),
115                            dst_iter_pd_.desc()->format == ldsnc);
116
117         ok = ok && one_of(weights_layer_pd_.desc()->format, ldigo, rnn_packed)
118                 && one_of(weights_iter_pd_.desc()->format, ldigo, rnn_packed);
119         ok = ok && IMPLICATION(weights_iter_pd_.desc()->format == rnn_packed,
120                            weights_iter_pd_.desc()
121                                            ->layout_desc.rnn_packed_desc.format
122                                    == mkldnn_ldigo_p);
123         ok = ok && IMPLICATION(weights_layer_pd_.desc()->format == rnn_packed,
124                            weights_layer_pd_.desc()
125                                            ->layout_desc.rnn_packed_desc.format
126                                    == mkldnn_ldigo_p);
127
128         ok = ok && IMPLICATION(!bias_pd_.is_zero(),
129                            bias_pd_.desc()->format == ldgo);
130
131         /* Int8 is supported only for packed weights */
132         data_type_t weights_iter_dt = weights_iter_pd_.desc()->data_type;
133         data_type_t weights_layer_dt = weights_layer_pd_.desc()->data_type;
134         ok = ok && IMPLICATION(weights_iter_dt == s8,
135                            weights_iter_pd_.desc()->format == rnn_packed);
136         ok = ok && IMPLICATION(weights_layer_dt == s8,
137                            weights_layer_pd_.desc()->format == rnn_packed);
138
139         return ok ? status::success : status::unimplemented;
140     }
141 };
142
143 struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t {
144     using cpu_memory_pd_t = cpu_memory_t::pd_t;
145
146     cpu_rnn_bwd_pd_t(engine_t *engine, const rnn_desc_t *adesc,
147             const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd)
148         : rnn_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
149         , src_layer_pd_(engine, &desc_.src_layer_desc)
150         , src_iter_pd_(engine, &desc_.src_iter_desc)
151         , weights_layer_pd_(engine, &desc_.weights_layer_desc)
152         , weights_iter_pd_(engine, &desc_.weights_iter_desc)
153         , bias_pd_(engine, &desc_.bias_desc)
154         , dst_layer_pd_(engine, &desc_.dst_layer_desc)
155         , dst_iter_pd_(engine, &desc_.dst_iter_desc)
156         , diff_src_layer_pd_(engine, &desc_.diff_src_layer_desc)
157         , diff_states_pd_(engine, &desc_.diff_src_iter_desc)
158         , diff_weights_layer_pd_(engine, &desc_.diff_weights_layer_desc)
159         , diff_weights_iter_pd_(engine, &desc_.diff_weights_iter_desc)
160         , diff_bias_pd_(engine, &desc_.diff_bias_desc)
161         , diff_dst_layer_pd_(engine, &desc_.diff_dst_layer_desc)
162         , diff_dst_iter_pd_(engine, &desc_.diff_dst_iter_desc)
163         , ws_pd_(engine_) {}
164     virtual ~cpu_rnn_bwd_pd_t() {}
165
166     virtual const cpu_memory_pd_t *src_pd(int index = 0) const override {
167         if (index == 0)
168             return &src_layer_pd_;
169         if (index == 1 && this->with_src_iter())
170             return &src_iter_pd_;
171         return nullptr;
172     }
173     virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
174         if (index == 0)
175             return &weights_layer_pd_;
176         if (index == 1)
177             return &weights_iter_pd_;
178         if (index == 2 && this->with_bias())
179             return &bias_pd_;
180         return nullptr;
181     }
182     virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override {
183         if (index == 0)
184             return &dst_layer_pd_;
185         if (index == 1 && this->with_dst_iter())
186             return &dst_iter_pd_;
187         return nullptr;
188     }
189     virtual const cpu_memory_pd_t *diff_src_pd(int index = 0) const override {
190         if (index == 0)
191             return &diff_src_layer_pd_;
192         if (index == 1 && this->with_src_iter())
193             return &diff_states_pd_;
194         return nullptr;
195     }
196     virtual const cpu_memory_pd_t *diff_weights_pd(
197             int index = 0) const override {
198         if (index == 0)
199             return &diff_weights_layer_pd_;
200         if (index == 1)
201             return &diff_weights_iter_pd_;
202         if (index == 2 && this->with_bias())
203             return &diff_bias_pd_;
204         return nullptr;
205     }
206     virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override {
207         if (index == 0)
208             return &diff_dst_layer_pd_;
209         if (index == 1 && this->with_dst_iter())
210             return &diff_dst_iter_pd_;
211         return nullptr;
212     }
213     virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override {
214         return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr;
215     }
216
217 protected:
218     cpu_memory_pd_t src_layer_pd_;
219     cpu_memory_pd_t src_iter_pd_;
220     cpu_memory_pd_t weights_layer_pd_;
221     cpu_memory_pd_t weights_iter_pd_;
222     cpu_memory_pd_t bias_pd_;
223     cpu_memory_pd_t dst_layer_pd_;
224     cpu_memory_pd_t dst_iter_pd_;
225     cpu_memory_pd_t diff_src_layer_pd_;
226     cpu_memory_pd_t diff_states_pd_;
227     cpu_memory_pd_t diff_weights_layer_pd_;
228     cpu_memory_pd_t diff_weights_iter_pd_;
229     cpu_memory_pd_t diff_bias_pd_;
230     cpu_memory_pd_t diff_dst_layer_pd_;
231     cpu_memory_pd_t diff_dst_iter_pd_;
232     cpu_memory_pd_t ws_pd_;
233
234     virtual status_t set_default_params() {
235         using namespace memory_format;
236         if (src_layer_pd_.desc()->format == any)
237             CHECK(src_layer_pd_.set_format(tnc));
238         if (diff_src_layer_pd_.desc()->format == any)
239             CHECK(diff_src_layer_pd_.set_format(tnc));
240         if (diff_weights_layer_pd_.desc()->format == any) {
241             memory_desc_t md = *(diff_weights_layer_pd_.desc());
242             md.format = ldigo;
243             CHECK(memory_desc_wrapper::compute_blocking(md));
244             CHECK(rnn_utils::set_good_strides(md));
245             cpu_memory_t::pd_t new_pd(engine_, &md);
246             diff_weights_layer_pd_ = new_pd;
247         }
248         if (diff_weights_iter_pd_.desc()->format == any) {
249             memory_desc_t md = *(diff_weights_iter_pd_.desc());
250             md.format = ldigo;
251             CHECK(memory_desc_wrapper::compute_blocking(md));
252             CHECK(rnn_utils::set_good_strides(md));
253             cpu_memory_t::pd_t new_pd(engine_, &md);
254             diff_weights_iter_pd_ = new_pd;
255         }
256         if (dst_layer_pd_.desc()->format == any)
257             CHECK(dst_layer_pd_.set_format(tnc));
258         if (diff_dst_layer_pd_.desc()->format == any)
259             CHECK(diff_dst_layer_pd_.set_format(tnc));
260
261         // Optional parameters
262         if ((!src_iter_pd_.is_zero()) && (src_iter_pd_.desc()->format == any))
263             CHECK(src_iter_pd_.set_format(ldsnc));
264         if ((!diff_states_pd_.is_zero())
265                 && (diff_states_pd_.desc()->format == any))
266             CHECK(diff_states_pd_.set_format(ldsnc));
267         if ((!bias_pd_.is_zero()) && (bias_pd_.desc()->format == any))
268             CHECK(bias_pd_.set_format(ldgo));
269         if ((!diff_bias_pd_.is_zero()) && (diff_bias_pd_.desc()->format == any))
270             CHECK(diff_bias_pd_.set_format(ldgo));
271         if ((!dst_iter_pd_.is_zero()) && (dst_iter_pd_.desc()->format == any))
272             CHECK(dst_iter_pd_.set_format(ldsnc));
273         if ((!diff_dst_iter_pd_.is_zero())
274                 && (diff_dst_iter_pd_.desc()->format == any))
275             CHECK(diff_dst_iter_pd_.set_format(ldsnc));
276
277         return status::success;
278     }
279
280     status_t check_layout_consistency() {
281         using namespace memory_format;
282         using namespace utils;
283         bool ok = true;
284         ok = ok && src_layer_pd_.desc()->format == tnc
285                 && dst_layer_pd_.desc()->format == tnc;
286         ok = ok && IMPLICATION(!src_iter_pd_.is_zero(),
287                            src_iter_pd_.desc()->format == ldsnc)
288                 && IMPLICATION(!dst_iter_pd_.is_zero(),
289                            dst_iter_pd_.desc()->format == ldsnc);
290
291         ok = ok && one_of(weights_layer_pd_.desc()->format, ldgoi, rnn_packed)
292                 && one_of(weights_iter_pd_.desc()->format, ldgoi, rnn_packed);
293         ok = ok && IMPLICATION(weights_iter_pd_.desc()->format == rnn_packed,
294                            weights_iter_pd_.desc()
295                                            ->layout_desc.rnn_packed_desc.format
296                                    == mkldnn_ldgoi_p);
297         ok = ok && IMPLICATION(weights_layer_pd_.desc()->format == rnn_packed,
298                            weights_layer_pd_.desc()
299                                            ->layout_desc.rnn_packed_desc.format
300                                    == mkldnn_ldgoi_p);
301
302         ok = ok && IMPLICATION(!bias_pd_.is_zero(),
303                            bias_pd_.desc()->format == ldgo);
304
305         ok = ok && diff_src_layer_pd_.desc()->format == tnc
306                 && diff_dst_layer_pd_.desc()->format == tnc;
307         ok = ok && IMPLICATION(!diff_states_pd_.is_zero(),
308                            diff_states_pd_.desc()->format == ldsnc)
309                 && IMPLICATION(!diff_dst_iter_pd_.is_zero(),
310                            diff_dst_iter_pd_.desc()->format == ldsnc);
311         ok = ok && diff_weights_layer_pd_.desc()->format == ldigo
312                 && diff_weights_iter_pd_.desc()->format == ldigo;
313         ok = ok && IMPLICATION(!diff_bias_pd_.is_zero(),
314                            diff_bias_pd_.desc()->format == ldgo);
315
316         return ok ? status::success : status::unimplemented;
317     }
318 };
319 }
320 }
321 }
322
323 #endif