/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #ifndef CPU_CONVOLUTION_FWD_PD_HPP
18 #define CPU_CONVOLUTION_FWD_PD_HPP
22 #include "c_types_map.hpp"
23 #include "convolution_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "cpu_memory.hpp"
26 #include "cpu_primitive.hpp"
27 #include "type_helpers.hpp"
34 struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t {
35 using cpu_memory_pd_t = cpu_memory_t::pd_t;
37 cpu_convolution_fwd_pd_t(engine_t *engine,
38 const convolution_desc_t *adesc,
39 const primitive_attr_t *attr,
40 const typename cpu_convolution_fwd_pd_t::base_class *hint_fwd_pd)
41 : convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
42 , src_pd_(this->engine_, &this->desc()->src_desc)
43 , dst_pd_(this->engine_, &this->desc()->dst_desc)
44 , weights_pd_(this->engine_, &this->desc()->weights_desc)
45 , bias_pd_(this->engine_, &this->desc()->bias_desc) {}
46 virtual ~cpu_convolution_fwd_pd_t() {}
48 virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
49 { return index == 0 ? &src_pd_ : nullptr; }
50 virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override
51 { return index == 0 ? &dst_pd_ : nullptr; }
52 virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
53 if (index == 0) return &weights_pd_;
54 if (index == 1 && this->with_bias()) return &bias_pd_;
58 bool has_padded_dst() const {
59 memory_desc_wrapper dst_d(&dst_pd_);
60 if (!dst_d.is_blocking_desc()) return false;
61 return this->OC() != dst_d.blocking_desc().padding_dims[1];
64 bool wants_padded_bias() const {
65 if (!this->with_bias()) return false;
66 return has_padded_dst();
69 bool wants_zero_pad_dst(bool jit_impl = true) const {
70 if (!has_padded_dst()) return false;
71 const auto &po = this->attr()->post_ops_;
73 if ((idx = po.find(primitive_kind::eltwise)) == -1) return false;
74 return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg,
79 cpu_memory_pd_t src_pd_, dst_pd_;
80 cpu_memory_pd_t weights_pd_, bias_pd_;
82 inline memory_format_t src_format()
84 using namespace memory_format;
85 return utils::pick(this->desc()->src_desc.ndims - 3, ncw, nchw, ncdhw);
87 inline memory_format_t wei_format()
89 using namespace memory_format;
90 return this->with_groups()
91 ? utils::pick(this->desc()->src_desc.ndims - 3, goiw, goihw, goidhw)
92 : utils::pick(this->desc()->src_desc.ndims - 3, oiw, oihw, oidhw);
95 virtual status_t set_default_params() {
96 using namespace memory_format;
97 if (src_pd_.desc()->format == any)
98 CHECK(src_pd_.set_format(src_format()));
99 if (dst_pd_.desc()->format == any)
100 CHECK(dst_pd_.set_format(src_pd_.desc()->format));
101 if (weights_pd_.desc()->format == any)
102 CHECK(weights_pd_.set_format(wei_format()));
103 if (bias_pd_.desc()->format == any)
104 CHECK(bias_pd_.set_format(x));
105 if (this->desc()->alg_kind == alg_kind::convolution_auto)
106 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
107 return status::success;
111 struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t {
112 using cpu_memory_pd_t = cpu_memory_t::pd_t;
114 cpu_convolution_bwd_data_pd_t(engine_t *engine,
115 const convolution_desc_t *adesc,
116 const primitive_attr_t *attr,
117 const convolution_fwd_pd_t *hint_fwd_pd)
118 : convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
119 , diff_src_pd_(this->engine_, &this->desc_.diff_src_desc)
120 , diff_dst_pd_(this->engine_, &this->desc_.diff_dst_desc)
121 , weights_pd_(this->engine_, &this->desc_.weights_desc)
122 , bias_pd_(this->engine_, &this->desc_.bias_desc) {}
123 virtual ~cpu_convolution_bwd_data_pd_t() {}
125 virtual const cpu_memory_pd_t *diff_src_pd(int index = 0) const override
126 { return index == 0 ? &diff_src_pd_ : nullptr; }
127 virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override
128 { return index == 0 ? &diff_dst_pd_ : nullptr; }
129 virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
130 if (index == 0) return &weights_pd_;
131 if (index == 1 && this->with_bias()) return &bias_pd_;
136 cpu_memory_pd_t diff_src_pd_, diff_dst_pd_;
137 cpu_memory_pd_t weights_pd_, bias_pd_;
139 inline memory_format_t src_format()
141 using namespace memory_format;
142 return utils::pick(this->desc_.diff_src_desc.ndims - 3, ncw, nchw, ncdhw);
144 inline memory_format_t wei_format()
146 using namespace memory_format;
147 return this->with_groups()
148 ? utils::pick(this->desc_.diff_src_desc.ndims - 3, goiw, goihw, goidhw)
149 : utils::pick(this->desc_.diff_src_desc.ndims - 3, oiw, oihw, oidhw);
152 virtual status_t set_default_params() {
153 using namespace memory_format;
154 if (diff_src_pd_.desc()->format == any)
155 CHECK(diff_src_pd_.set_format(src_format()));
156 if (diff_dst_pd_.desc()->format == any)
157 CHECK(diff_dst_pd_.set_format(diff_src_pd_.desc()->format));
158 if (weights_pd_.desc()->format == any)
159 CHECK(weights_pd_.set_format(wei_format()));
160 if (bias_pd_.desc()->format == any)
161 CHECK(bias_pd_.set_format(x));
162 if (this->desc()->alg_kind == alg_kind::convolution_auto)
163 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
164 return status::success;
168 struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t {
169 using cpu_memory_pd_t = cpu_memory_t::pd_t;
171 cpu_convolution_bwd_weights_pd_t(engine_t *engine,
172 const convolution_desc_t *adesc,
173 const primitive_attr_t *attr,
174 const convolution_fwd_pd_t *hint_fwd_pd)
175 : convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
176 , src_pd_(this->engine_, &this->desc_.src_desc)
177 , diff_dst_pd_(this->engine_, &this->desc_.diff_dst_desc)
178 , diff_weights_pd_(this->engine_, &this->desc_.diff_weights_desc)
179 , diff_bias_pd_(this->engine_, &this->desc_.diff_bias_desc) {}
180 virtual ~cpu_convolution_bwd_weights_pd_t() {}
182 virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
183 { return index == 0 ? &src_pd_ : nullptr; }
184 virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override
185 { return index == 0 ? &diff_dst_pd_ : nullptr; }
186 virtual const cpu_memory_pd_t *diff_weights_pd(int index = 0) const
188 if (index == 0) return &diff_weights_pd_;
189 if (index == 1 && this->with_bias()) return &diff_bias_pd_;
193 bool wants_padded_bias() const {
194 if (!this->with_bias()) return false;
195 memory_desc_wrapper diff_dst_d(&diff_dst_pd_);
196 if (!diff_dst_d.is_blocking_desc()) return false;
197 return OC() != diff_dst_d.blocking_desc().padding_dims[1];
201 cpu_memory_pd_t src_pd_;
202 cpu_memory_pd_t diff_dst_pd_;
203 cpu_memory_pd_t diff_weights_pd_, diff_bias_pd_;
205 inline memory_format_t src_format()
207 using namespace memory_format;
208 return utils::pick(this->desc_.src_desc.ndims - 3, ncw, nchw, ncdhw);
210 inline memory_format_t wei_format()
212 using namespace memory_format;
213 return this->with_groups()
214 ? utils::pick(this->desc_.src_desc.ndims - 3, goiw, goihw, goidhw)
215 : utils::pick(this->desc_.src_desc.ndims - 3, oiw, oihw, oidhw);
218 virtual status_t set_default_params() {
219 using namespace memory_format;
220 if (src_pd_.desc()->format == any)
221 CHECK(src_pd_.set_format(src_format()));
222 if (diff_dst_pd_.desc()->format == any)
223 CHECK(diff_dst_pd_.set_format(src_format()));
224 if (diff_weights_pd_.desc()->format == any)
225 CHECK(diff_weights_pd_.set_format(wei_format()));
226 if (diff_bias_pd_.desc()->format == any)
227 CHECK(diff_bias_pd_.set_format(x));
228 if (this->desc()->alg_kind == alg_kind::convolution_auto)
229 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
230 return status::success;
240 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s