1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_RNN_PD_HPP
18 #define CPU_RNN_PD_HPP
20 #include "c_types_map.hpp"
21 #include "cpu_engine.hpp"
22 #include "cpu_memory.hpp"
23 #include "cpu_primitive.hpp"
26 #include "type_helpers.hpp"
33 struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t {
34 using cpu_memory_pd_t = cpu_memory_t::pd_t;
36 cpu_rnn_fwd_pd_t(engine_t *engine, const rnn_desc_t *adesc,
37 const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd)
38 : rnn_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
39 , src_layer_pd_(engine, &desc_.src_layer_desc)
40 , src_iter_pd_(engine, &desc_.src_iter_desc)
41 , weights_layer_pd_(engine, &desc_.weights_layer_desc)
42 , weights_iter_pd_(engine, &desc_.weights_iter_desc)
43 , bias_pd_(engine, &desc_.bias_desc)
44 , dst_layer_pd_(engine, &desc_.dst_layer_desc)
45 , dst_iter_pd_(engine, &desc_.dst_iter_desc)
47 virtual ~cpu_rnn_fwd_pd_t() {}
49 virtual const cpu_memory_pd_t *src_pd(int index = 0) const override {
51 return &src_layer_pd_;
52 if (index == 1 && this->with_src_iter())
56 virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
58 return &weights_layer_pd_;
60 return &weights_iter_pd_;
61 if (index == 2 && this->with_bias())
65 virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override {
67 return &dst_layer_pd_;
68 if (index == 1 && this->with_dst_iter())
72 virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override {
73 return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr;
77 cpu_memory_pd_t src_layer_pd_;
78 cpu_memory_pd_t src_iter_pd_;
79 cpu_memory_pd_t weights_layer_pd_;
80 cpu_memory_pd_t weights_iter_pd_;
81 cpu_memory_pd_t bias_pd_;
82 cpu_memory_pd_t dst_layer_pd_;
83 cpu_memory_pd_t dst_iter_pd_;
84 cpu_memory_pd_t ws_pd_;
86 virtual status_t set_default_params() {
87 using namespace memory_format;
88 if (src_layer_pd_.desc()->format == any)
89 CHECK(src_layer_pd_.set_format(tnc));
90 if (weights_layer_pd_.desc()->format == any)
91 CHECK(weights_layer_pd_.set_format(ldigo));
92 if (weights_iter_pd_.desc()->format == any)
93 CHECK(weights_iter_pd_.set_format(ldigo));
94 if (dst_layer_pd_.desc()->format == any)
95 CHECK(dst_layer_pd_.set_format(tnc));
97 // Optional parameters
98 if ((!src_iter_pd_.is_zero()) && (src_iter_pd_.desc()->format == any))
99 CHECK(src_iter_pd_.set_format(ldsnc));
100 if ((!bias_pd_.is_zero()) && (bias_pd_.desc()->format == any))
101 CHECK(bias_pd_.set_format(ldgo));
102 if ((!dst_iter_pd_.is_zero()) && (dst_iter_pd_.desc()->format == any))
103 CHECK(dst_iter_pd_.set_format(ldsnc));
105 return status::success;
109 struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t {
110 using cpu_memory_pd_t = cpu_memory_t::pd_t;
112 cpu_rnn_bwd_pd_t(engine_t *engine, const rnn_desc_t *adesc,
113 const primitive_attr_t *attr, const rnn_bwd_pd_t *hint_bwd_pd)
114 : rnn_bwd_pd_t(engine, adesc, attr, hint_bwd_pd)
115 , src_layer_pd_(engine, &desc_.src_layer_desc)
116 , src_iter_pd_(engine, &desc_.src_iter_desc)
117 , weights_layer_pd_(engine, &desc_.weights_layer_desc)
118 , weights_iter_pd_(engine, &desc_.weights_iter_desc)
119 , bias_pd_(engine, &desc_.bias_desc)
120 , dst_layer_pd_(engine, &desc_.dst_layer_desc)
121 , dst_iter_pd_(engine, &desc_.dst_iter_desc)
122 , diff_src_layer_pd_(engine, &desc_.diff_src_layer_desc)
123 , diff_states_pd_(engine, &desc_.diff_src_iter_desc)
124 , diff_weights_layer_pd_(engine, &desc_.diff_weights_layer_desc)
125 , diff_weights_iter_pd_(engine, &desc_.diff_weights_iter_desc)
126 , diff_bias_pd_(engine, &desc_.diff_bias_desc)
127 , diff_dst_layer_pd_(engine, &desc_.diff_dst_layer_desc)
128 , diff_dst_iter_pd_(engine, &desc_.diff_dst_iter_desc)
130 virtual ~cpu_rnn_bwd_pd_t() {}
132 virtual const cpu_memory_pd_t *src_pd(int index = 0) const override {
134 return &src_layer_pd_;
135 if (index == 1 && this->with_src_iter())
136 return &src_iter_pd_;
139 virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
141 return &weights_layer_pd_;
143 return &weights_iter_pd_;
144 if (index == 2 && this->with_bias())
148 virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override {
150 return &dst_layer_pd_;
151 if (index == 1 && this->with_dst_iter())
152 return &dst_iter_pd_;
155 virtual const cpu_memory_pd_t *diff_src_pd(int index = 0) const override {
157 return &diff_src_layer_pd_;
158 if (index == 1 && this->with_src_iter())
159 return &diff_states_pd_;
162 virtual const cpu_memory_pd_t *diff_weights_pd(
163 int index = 0) const override {
165 return &diff_weights_layer_pd_;
167 return &diff_weights_iter_pd_;
168 if (index == 2 && this->with_bias())
169 return &diff_bias_pd_;
172 virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override {
174 return &diff_dst_layer_pd_;
175 if (index == 1 && this->with_dst_iter())
176 return &diff_dst_iter_pd_;
179 virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override {
180 return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr;
// Memory primitive descriptors for the forward-pass tensors used by the
// backward primitive. The constructor binds each one to the matching memory
// descriptor inside desc_ (e.g. src_layer_pd_ <- desc_.src_layer_desc).
184 cpu_memory_pd_t src_layer_pd_;
185 cpu_memory_pd_t src_iter_pd_;
186 cpu_memory_pd_t weights_layer_pd_;
187 cpu_memory_pd_t weights_iter_pd_;
188 cpu_memory_pd_t bias_pd_;
189 cpu_memory_pd_t dst_layer_pd_;
190 cpu_memory_pd_t dst_iter_pd_;
// Gradient descriptors. diff_states_pd_ is bound to desc_.diff_src_iter_desc
// and is returned by diff_src_pd(1) when with_src_iter() holds.
191 cpu_memory_pd_t diff_src_layer_pd_;
192 cpu_memory_pd_t diff_states_pd_;
193 cpu_memory_pd_t diff_weights_layer_pd_;
194 cpu_memory_pd_t diff_weights_iter_pd_;
195 cpu_memory_pd_t diff_bias_pd_;
196 cpu_memory_pd_t diff_dst_layer_pd_;
197 cpu_memory_pd_t diff_dst_iter_pd_;
// Workspace descriptor; workspace_pd() exposes it only while it is non-zero.
198 cpu_memory_pd_t ws_pd_;
200 virtual status_t set_default_params() {
201 using namespace memory_format;
202 if (src_layer_pd_.desc()->format == any)
203 CHECK(src_layer_pd_.set_format(tnc));
204 if (diff_src_layer_pd_.desc()->format == any)
205 CHECK(diff_src_layer_pd_.set_format(tnc));
206 if (weights_layer_pd_.desc()->format == any)
207 CHECK(weights_layer_pd_.set_format(ldgoi));
208 if (diff_weights_layer_pd_.desc()->format == any)
209 CHECK(diff_weights_layer_pd_.set_format(ldigo));
210 if (weights_iter_pd_.desc()->format == any)
211 CHECK(weights_iter_pd_.set_format(ldgoi));
212 if (diff_weights_iter_pd_.desc()->format == any)
213 CHECK(diff_weights_iter_pd_.set_format(ldigo));
214 if (dst_layer_pd_.desc()->format == any)
215 CHECK(dst_layer_pd_.set_format(tnc));
216 if (diff_dst_layer_pd_.desc()->format == any)
217 CHECK(diff_dst_layer_pd_.set_format(tnc));
219 // Optional parameters
220 if ((!src_iter_pd_.is_zero()) && (src_iter_pd_.desc()->format == any))
221 CHECK(src_iter_pd_.set_format(ldsnc));
222 if ((!diff_states_pd_.is_zero())
223 && (diff_states_pd_.desc()->format == any))
224 CHECK(diff_states_pd_.set_format(ldsnc));
225 if ((!bias_pd_.is_zero()) && (bias_pd_.desc()->format == any))
226 CHECK(bias_pd_.set_format(ldgo));
227 if ((!diff_bias_pd_.is_zero()) && (diff_bias_pd_.desc()->format == any))
228 CHECK(diff_bias_pd_.set_format(ldgo));
229 if ((!dst_iter_pd_.is_zero()) && (dst_iter_pd_.desc()->format == any))
230 CHECK(dst_iter_pd_.set_format(ldsnc));
231 if ((!diff_dst_iter_pd_.is_zero())
232 && (diff_dst_iter_pd_.desc()->format == any))
233 CHECK(diff_dst_iter_pd_.set_format(ldsnc));
235 return status::success;