/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #ifndef CPU_RNN_PD_HPP
18 #define CPU_RNN_PD_HPP
20 #include "c_types_map.hpp"
21 #include "../cpu_engine.hpp"
22 #include "../cpu_memory.hpp"
23 #include "../cpu_primitive.hpp"
26 #include "type_helpers.hpp"
28 #include "rnn_utils.hpp"
34 struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t {
35 using cpu_memory_pd_t = cpu_memory_t::pd_t;
37 cpu_rnn_fwd_pd_t(engine_t *engine, const rnn_desc_t *adesc,
38 const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd)
39 : rnn_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
40 , src_layer_pd_(engine, &desc_.src_layer_desc)
41 , src_iter_pd_(engine, &desc_.src_iter_desc)
42 , weights_layer_pd_(engine, &desc_.weights_layer_desc)
43 , weights_iter_pd_(engine, &desc_.weights_iter_desc)
44 , bias_pd_(engine, &desc_.bias_desc)
45 , dst_layer_pd_(engine, &desc_.dst_layer_desc)
46 , dst_iter_pd_(engine, &desc_.dst_iter_desc)
48 virtual ~cpu_rnn_fwd_pd_t() {}
50 virtual const cpu_memory_pd_t *src_pd(int index = 0) const override {
52 return &src_layer_pd_;
53 if (index == 1 && this->with_src_iter())
57 virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
59 return &weights_layer_pd_;
61 return &weights_iter_pd_;
62 if (index == 2 && this->with_bias())
66 virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override {
68 return &dst_layer_pd_;
69 if (index == 1 && this->with_dst_iter())
73 virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override {
74 return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr;
78 cpu_memory_pd_t src_layer_pd_;
79 cpu_memory_pd_t src_iter_pd_;
80 cpu_memory_pd_t weights_layer_pd_;
81 cpu_memory_pd_t weights_iter_pd_;
82 cpu_memory_pd_t bias_pd_;
83 cpu_memory_pd_t dst_layer_pd_;
84 cpu_memory_pd_t dst_iter_pd_;
85 cpu_memory_pd_t ws_pd_;
87 virtual status_t set_default_params() {
88 using namespace memory_format;
89 if (src_layer_pd_.desc()->format == any)
90 CHECK(src_layer_pd_.set_format(tnc));
91 if (dst_layer_pd_.desc()->format == any)
92 CHECK(dst_layer_pd_.set_format(tnc));
94 // Optional parameters
95 if ((!src_iter_pd_.is_zero()) && (src_iter_pd_.desc()->format == any))
96 CHECK(src_iter_pd_.set_format(ldsnc));
97 if ((!bias_pd_.is_zero()) && (bias_pd_.desc()->format == any))
98 CHECK(bias_pd_.set_format(ldgo));
99 if ((!dst_iter_pd_.is_zero()) && (dst_iter_pd_.desc()->format == any))
100 CHECK(dst_iter_pd_.set_format(ldsnc));
102 return status::success;
105 status_t check_layout_consistency() {
106 using namespace memory_format;
107 using namespace utils;
108 using namespace data_type;
110 ok = ok && src_layer_pd_.desc()->format == tnc
111 && dst_layer_pd_.desc()->format == tnc;
112 ok = ok && IMPLICATION(!src_iter_pd_.is_zero(),
113 src_iter_pd_.desc()->format == ldsnc)
114 && IMPLICATION(!dst_iter_pd_.is_zero(),
115 dst_iter_pd_.desc()->format == ldsnc);
117 ok = ok && one_of(weights_layer_pd_.desc()->format, ldigo, rnn_packed)
118 && one_of(weights_iter_pd_.desc()->format, ldigo, rnn_packed);
119 ok = ok && IMPLICATION(weights_iter_pd_.desc()->format == rnn_packed,
120 weights_iter_pd_.desc()
121 ->layout_desc.rnn_packed_desc.format
123 ok = ok && IMPLICATION(weights_layer_pd_.desc()->format == rnn_packed,
124 weights_layer_pd_.desc()
125 ->layout_desc.rnn_packed_desc.format
128 ok = ok && IMPLICATION(!bias_pd_.is_zero(),
129 bias_pd_.desc()->format == ldgo);
131 /* Int8 is supported only for packed weights */
132 data_type_t weights_iter_dt = weights_iter_pd_.desc()->data_type;
133 data_type_t weights_layer_dt = weights_layer_pd_.desc()->data_type;
134 ok = ok && IMPLICATION(weights_iter_dt == s8,
135 weights_iter_pd_.desc()->format == rnn_packed);
136 ok = ok && IMPLICATION(weights_layer_dt == s8,
137 weights_layer_pd_.desc()->format == rnn_packed);
139 return ok ? status::success : status::unimplemented;
143 struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t {
144 using cpu_memory_pd_t = cpu_memory_t::pd_t;
146 cpu_rnn_bwd_pd_t(engine_t *engine, const rnn_desc_t *adesc,
147 const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd)
148 : rnn_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
149 , src_layer_pd_(engine, &desc_.src_layer_desc)
150 , src_iter_pd_(engine, &desc_.src_iter_desc)
151 , weights_layer_pd_(engine, &desc_.weights_layer_desc)
152 , weights_iter_pd_(engine, &desc_.weights_iter_desc)
153 , bias_pd_(engine, &desc_.bias_desc)
154 , dst_layer_pd_(engine, &desc_.dst_layer_desc)
155 , dst_iter_pd_(engine, &desc_.dst_iter_desc)
156 , diff_src_layer_pd_(engine, &desc_.diff_src_layer_desc)
157 , diff_states_pd_(engine, &desc_.diff_src_iter_desc)
158 , diff_weights_layer_pd_(engine, &desc_.diff_weights_layer_desc)
159 , diff_weights_iter_pd_(engine, &desc_.diff_weights_iter_desc)
160 , diff_bias_pd_(engine, &desc_.diff_bias_desc)
161 , diff_dst_layer_pd_(engine, &desc_.diff_dst_layer_desc)
162 , diff_dst_iter_pd_(engine, &desc_.diff_dst_iter_desc)
164 virtual ~cpu_rnn_bwd_pd_t() {}
166 virtual const cpu_memory_pd_t *src_pd(int index = 0) const override {
168 return &src_layer_pd_;
169 if (index == 1 && this->with_src_iter())
170 return &src_iter_pd_;
173 virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
175 return &weights_layer_pd_;
177 return &weights_iter_pd_;
178 if (index == 2 && this->with_bias())
182 virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override {
184 return &dst_layer_pd_;
185 if (index == 1 && this->with_dst_iter())
186 return &dst_iter_pd_;
189 virtual const cpu_memory_pd_t *diff_src_pd(int index = 0) const override {
191 return &diff_src_layer_pd_;
192 if (index == 1 && this->with_src_iter())
193 return &diff_states_pd_;
196 virtual const cpu_memory_pd_t *diff_weights_pd(
197 int index = 0) const override {
199 return &diff_weights_layer_pd_;
201 return &diff_weights_iter_pd_;
202 if (index == 2 && this->with_bias())
203 return &diff_bias_pd_;
206 virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override {
208 return &diff_dst_layer_pd_;
209 if (index == 1 && this->with_dst_iter())
210 return &diff_dst_iter_pd_;
213 virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override {
214 return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr;
218 cpu_memory_pd_t src_layer_pd_;
219 cpu_memory_pd_t src_iter_pd_;
220 cpu_memory_pd_t weights_layer_pd_;
221 cpu_memory_pd_t weights_iter_pd_;
222 cpu_memory_pd_t bias_pd_;
223 cpu_memory_pd_t dst_layer_pd_;
224 cpu_memory_pd_t dst_iter_pd_;
225 cpu_memory_pd_t diff_src_layer_pd_;
226 cpu_memory_pd_t diff_states_pd_;
227 cpu_memory_pd_t diff_weights_layer_pd_;
228 cpu_memory_pd_t diff_weights_iter_pd_;
229 cpu_memory_pd_t diff_bias_pd_;
230 cpu_memory_pd_t diff_dst_layer_pd_;
231 cpu_memory_pd_t diff_dst_iter_pd_;
232 cpu_memory_pd_t ws_pd_;
234 virtual status_t set_default_params() {
235 using namespace memory_format;
236 if (src_layer_pd_.desc()->format == any)
237 CHECK(src_layer_pd_.set_format(tnc));
238 if (diff_src_layer_pd_.desc()->format == any)
239 CHECK(diff_src_layer_pd_.set_format(tnc));
240 if (diff_weights_layer_pd_.desc()->format == any) {
241 memory_desc_t md = *(diff_weights_layer_pd_.desc());
243 CHECK(memory_desc_wrapper::compute_blocking(md));
244 CHECK(rnn_utils::set_good_strides(md));
245 cpu_memory_t::pd_t new_pd(engine_, &md);
246 diff_weights_layer_pd_ = new_pd;
248 if (diff_weights_iter_pd_.desc()->format == any) {
249 memory_desc_t md = *(diff_weights_iter_pd_.desc());
251 CHECK(memory_desc_wrapper::compute_blocking(md));
252 CHECK(rnn_utils::set_good_strides(md));
253 cpu_memory_t::pd_t new_pd(engine_, &md);
254 diff_weights_iter_pd_ = new_pd;
256 if (dst_layer_pd_.desc()->format == any)
257 CHECK(dst_layer_pd_.set_format(tnc));
258 if (diff_dst_layer_pd_.desc()->format == any)
259 CHECK(diff_dst_layer_pd_.set_format(tnc));
261 // Optional parameters
262 if ((!src_iter_pd_.is_zero()) && (src_iter_pd_.desc()->format == any))
263 CHECK(src_iter_pd_.set_format(ldsnc));
264 if ((!diff_states_pd_.is_zero())
265 && (diff_states_pd_.desc()->format == any))
266 CHECK(diff_states_pd_.set_format(ldsnc));
267 if ((!bias_pd_.is_zero()) && (bias_pd_.desc()->format == any))
268 CHECK(bias_pd_.set_format(ldgo));
269 if ((!diff_bias_pd_.is_zero()) && (diff_bias_pd_.desc()->format == any))
270 CHECK(diff_bias_pd_.set_format(ldgo));
271 if ((!dst_iter_pd_.is_zero()) && (dst_iter_pd_.desc()->format == any))
272 CHECK(dst_iter_pd_.set_format(ldsnc));
273 if ((!diff_dst_iter_pd_.is_zero())
274 && (diff_dst_iter_pd_.desc()->format == any))
275 CHECK(diff_dst_iter_pd_.set_format(ldsnc));
277 return status::success;
280 status_t check_layout_consistency() {
281 using namespace memory_format;
282 using namespace utils;
284 ok = ok && src_layer_pd_.desc()->format == tnc
285 && dst_layer_pd_.desc()->format == tnc;
286 ok = ok && IMPLICATION(!src_iter_pd_.is_zero(),
287 src_iter_pd_.desc()->format == ldsnc)
288 && IMPLICATION(!dst_iter_pd_.is_zero(),
289 dst_iter_pd_.desc()->format == ldsnc);
291 ok = ok && one_of(weights_layer_pd_.desc()->format, ldgoi, rnn_packed)
292 && one_of(weights_iter_pd_.desc()->format, ldgoi, rnn_packed);
293 ok = ok && IMPLICATION(weights_iter_pd_.desc()->format == rnn_packed,
294 weights_iter_pd_.desc()
295 ->layout_desc.rnn_packed_desc.format
297 ok = ok && IMPLICATION(weights_layer_pd_.desc()->format == rnn_packed,
298 weights_layer_pd_.desc()
299 ->layout_desc.rnn_packed_desc.format
302 ok = ok && IMPLICATION(!bias_pd_.is_zero(),
303 bias_pd_.desc()->format == ldgo);
305 ok = ok && diff_src_layer_pd_.desc()->format == tnc
306 && diff_dst_layer_pd_.desc()->format == tnc;
307 ok = ok && IMPLICATION(!diff_states_pd_.is_zero(),
308 diff_states_pd_.desc()->format == ldsnc)
309 && IMPLICATION(!diff_dst_iter_pd_.is_zero(),
310 diff_dst_iter_pd_.desc()->format == ldsnc);
311 ok = ok && diff_weights_layer_pd_.desc()->format == ldigo
312 && diff_weights_iter_pd_.desc()->format == ldigo;
313 ok = ok && IMPLICATION(!diff_bias_pd_.is_zero(),
314 diff_bias_pd_.desc()->format == ldgo);
316 return ok ? status::success : status::unimplemented;