Add a section on how to link IE with a CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_rnn_pd.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_RNN_PD_HPP
18 #define CPU_RNN_PD_HPP
19
20 #include "c_types_map.hpp"
21 #include "cpu_engine.hpp"
22 #include "cpu_memory.hpp"
23 #include "cpu_primitive.hpp"
24 #include "nstl.hpp"
25 #include "rnn_pd.hpp"
26 #include "type_helpers.hpp"
27 #include "utils.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t {
34     using cpu_memory_pd_t = cpu_memory_t::pd_t;
35
36     cpu_rnn_fwd_pd_t(engine_t *engine, const rnn_desc_t *adesc,
37             const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd)
38         : rnn_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
39         , src_layer_pd_(engine, &desc_.src_layer_desc)
40         , src_iter_pd_(engine, &desc_.src_iter_desc)
41         , weights_layer_pd_(engine, &desc_.weights_layer_desc)
42         , weights_iter_pd_(engine, &desc_.weights_iter_desc)
43         , bias_pd_(engine, &desc_.bias_desc)
44         , dst_layer_pd_(engine, &desc_.dst_layer_desc)
45         , dst_iter_pd_(engine, &desc_.dst_iter_desc)
46         , ws_pd_(engine_) {}
47     virtual ~cpu_rnn_fwd_pd_t() {}
48
49     virtual const cpu_memory_pd_t *src_pd(int index = 0) const override {
50         if (index == 0)
51             return &src_layer_pd_;
52         if (index == 1 && this->with_src_iter())
53             return &src_iter_pd_;
54         return nullptr;
55     }
56     virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
57         if (index == 0)
58             return &weights_layer_pd_;
59         if (index == 1)
60             return &weights_iter_pd_;
61         if (index == 2 && this->with_bias())
62             return &bias_pd_;
63         return nullptr;
64     }
65     virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override {
66         if (index == 0)
67             return &dst_layer_pd_;
68         if (index == 1 && this->with_dst_iter())
69             return &dst_iter_pd_;
70         return nullptr;
71     }
72     virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override {
73         return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr;
74     }
75
76 protected:
77     cpu_memory_pd_t src_layer_pd_;
78     cpu_memory_pd_t src_iter_pd_;
79     cpu_memory_pd_t weights_layer_pd_;
80     cpu_memory_pd_t weights_iter_pd_;
81     cpu_memory_pd_t bias_pd_;
82     cpu_memory_pd_t dst_layer_pd_;
83     cpu_memory_pd_t dst_iter_pd_;
84     cpu_memory_pd_t ws_pd_;
85
86     virtual status_t set_default_params() {
87         using namespace memory_format;
88         if (src_layer_pd_.desc()->format == any)
89             CHECK(src_layer_pd_.set_format(tnc));
90         if (weights_layer_pd_.desc()->format == any)
91             CHECK(weights_layer_pd_.set_format(ldigo));
92         if (weights_iter_pd_.desc()->format == any)
93             CHECK(weights_iter_pd_.set_format(ldigo));
94         if (dst_layer_pd_.desc()->format == any)
95             CHECK(dst_layer_pd_.set_format(tnc));
96
97         // Optional parameters
98         if ((!src_iter_pd_.is_zero()) && (src_iter_pd_.desc()->format == any))
99             CHECK(src_iter_pd_.set_format(ldsnc));
100         if ((!bias_pd_.is_zero()) && (bias_pd_.desc()->format == any))
101             CHECK(bias_pd_.set_format(ldgo));
102         if ((!dst_iter_pd_.is_zero()) && (dst_iter_pd_.desc()->format == any))
103             CHECK(dst_iter_pd_.set_format(ldsnc));
104
105         return status::success;
106     }
107 };
108
109 struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t {
110     using cpu_memory_pd_t = cpu_memory_t::pd_t;
111
112     cpu_rnn_bwd_pd_t(engine_t *engine, const rnn_desc_t *adesc,
113             const primitive_attr_t *attr, const rnn_bwd_pd_t *hint_bwd_pd)
114         : rnn_bwd_pd_t(engine, adesc, attr, hint_bwd_pd)
115         , src_layer_pd_(engine, &desc_.src_layer_desc)
116         , src_iter_pd_(engine, &desc_.src_iter_desc)
117         , weights_layer_pd_(engine, &desc_.weights_layer_desc)
118         , weights_iter_pd_(engine, &desc_.weights_iter_desc)
119         , bias_pd_(engine, &desc_.bias_desc)
120         , dst_layer_pd_(engine, &desc_.dst_layer_desc)
121         , dst_iter_pd_(engine, &desc_.dst_iter_desc)
122         , diff_src_layer_pd_(engine, &desc_.diff_src_layer_desc)
123         , diff_states_pd_(engine, &desc_.diff_src_iter_desc)
124         , diff_weights_layer_pd_(engine, &desc_.diff_weights_layer_desc)
125         , diff_weights_iter_pd_(engine, &desc_.diff_weights_iter_desc)
126         , diff_bias_pd_(engine, &desc_.diff_bias_desc)
127         , diff_dst_layer_pd_(engine, &desc_.diff_dst_layer_desc)
128         , diff_dst_iter_pd_(engine, &desc_.diff_dst_iter_desc)
129         , ws_pd_(engine_) {}
130     virtual ~cpu_rnn_bwd_pd_t() {}
131
132     virtual const cpu_memory_pd_t *src_pd(int index = 0) const override {
133         if (index == 0)
134             return &src_layer_pd_;
135         if (index == 1 && this->with_src_iter())
136             return &src_iter_pd_;
137         return nullptr;
138     }
139     virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
140         if (index == 0)
141             return &weights_layer_pd_;
142         if (index == 1)
143             return &weights_iter_pd_;
144         if (index == 2 && this->with_bias())
145             return &bias_pd_;
146         return nullptr;
147     }
148     virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override {
149         if (index == 0)
150             return &dst_layer_pd_;
151         if (index == 1 && this->with_dst_iter())
152             return &dst_iter_pd_;
153         return nullptr;
154     }
155     virtual const cpu_memory_pd_t *diff_src_pd(int index = 0) const override {
156         if (index == 0)
157             return &diff_src_layer_pd_;
158         if (index == 1 && this->with_src_iter())
159             return &diff_states_pd_;
160         return nullptr;
161     }
162     virtual const cpu_memory_pd_t *diff_weights_pd(
163             int index = 0) const override {
164         if (index == 0)
165             return &diff_weights_layer_pd_;
166         if (index == 1)
167             return &diff_weights_iter_pd_;
168         if (index == 2 && this->with_bias())
169             return &diff_bias_pd_;
170         return nullptr;
171     }
172     virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override {
173         if (index == 0)
174             return &diff_dst_layer_pd_;
175         if (index == 1 && this->with_dst_iter())
176             return &diff_dst_iter_pd_;
177         return nullptr;
178     }
179     virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override {
180         return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr;
181     }
182
183 protected:
184     cpu_memory_pd_t src_layer_pd_;
185     cpu_memory_pd_t src_iter_pd_;
186     cpu_memory_pd_t weights_layer_pd_;
187     cpu_memory_pd_t weights_iter_pd_;
188     cpu_memory_pd_t bias_pd_;
189     cpu_memory_pd_t dst_layer_pd_;
190     cpu_memory_pd_t dst_iter_pd_;
191     cpu_memory_pd_t diff_src_layer_pd_;
192     cpu_memory_pd_t diff_states_pd_;
193     cpu_memory_pd_t diff_weights_layer_pd_;
194     cpu_memory_pd_t diff_weights_iter_pd_;
195     cpu_memory_pd_t diff_bias_pd_;
196     cpu_memory_pd_t diff_dst_layer_pd_;
197     cpu_memory_pd_t diff_dst_iter_pd_;
198     cpu_memory_pd_t ws_pd_;
199
200     virtual status_t set_default_params() {
201         using namespace memory_format;
202         if (src_layer_pd_.desc()->format == any)
203             CHECK(src_layer_pd_.set_format(tnc));
204         if (diff_src_layer_pd_.desc()->format == any)
205             CHECK(diff_src_layer_pd_.set_format(tnc));
206         if (weights_layer_pd_.desc()->format == any)
207             CHECK(weights_layer_pd_.set_format(ldgoi));
208         if (diff_weights_layer_pd_.desc()->format == any)
209             CHECK(diff_weights_layer_pd_.set_format(ldigo));
210         if (weights_iter_pd_.desc()->format == any)
211             CHECK(weights_iter_pd_.set_format(ldgoi));
212         if (diff_weights_iter_pd_.desc()->format == any)
213             CHECK(diff_weights_iter_pd_.set_format(ldigo));
214         if (dst_layer_pd_.desc()->format == any)
215             CHECK(dst_layer_pd_.set_format(tnc));
216         if (diff_dst_layer_pd_.desc()->format == any)
217             CHECK(diff_dst_layer_pd_.set_format(tnc));
218
219         // Optional parameters
220         if ((!src_iter_pd_.is_zero()) && (src_iter_pd_.desc()->format == any))
221             CHECK(src_iter_pd_.set_format(ldsnc));
222         if ((!diff_states_pd_.is_zero())
223                 && (diff_states_pd_.desc()->format == any))
224             CHECK(diff_states_pd_.set_format(ldsnc));
225         if ((!bias_pd_.is_zero()) && (bias_pd_.desc()->format == any))
226             CHECK(bias_pd_.set_format(ldgo));
227         if ((!diff_bias_pd_.is_zero()) && (diff_bias_pd_.desc()->format == any))
228             CHECK(diff_bias_pd_.set_format(ldgo));
229         if ((!dst_iter_pd_.is_zero()) && (dst_iter_pd_.desc()->format == any))
230             CHECK(dst_iter_pd_.set_format(ldsnc));
231         if ((!diff_dst_iter_pd_.is_zero())
232                 && (diff_dst_iter_pd_.desc()->format == any))
233             CHECK(diff_dst_iter_pd_.set_format(ldsnc));
234
235         return status::success;
236     }
237 };
238 }
239 }
240 }
241
242 #endif