Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_convolution_pd.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_CONVOLUTION_FWD_PD_HPP
18 #define CPU_CONVOLUTION_FWD_PD_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "convolution_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "cpu_memory.hpp"
26 #include "cpu_primitive.hpp"
27 #include "type_helpers.hpp"
28 #include "utils.hpp"
29
30 namespace mkldnn {
31 namespace impl {
32 namespace cpu {
33
34 struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t {
35     using cpu_memory_pd_t = cpu_memory_t::pd_t;
36
37     cpu_convolution_fwd_pd_t(engine_t *engine,
38             const convolution_desc_t *adesc,
39             const primitive_attr_t *attr,
40             const typename cpu_convolution_fwd_pd_t::base_class *hint_fwd_pd)
41         : convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
42         , src_pd_(this->engine_, &this->desc()->src_desc)
43         , dst_pd_(this->engine_, &this->desc()->dst_desc)
44         , weights_pd_(this->engine_, &this->desc()->weights_desc)
45         , bias_pd_(this->engine_, &this->desc()->bias_desc) {}
46     virtual ~cpu_convolution_fwd_pd_t() {}
47
48     virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
49     { return index == 0 ? &src_pd_ : nullptr; }
50     virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override
51     { return index == 0 ? &dst_pd_ : nullptr; }
52     virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
53         if (index == 0) return &weights_pd_;
54         if (index == 1 && this->with_bias()) return &bias_pd_;
55         return nullptr;
56     }
57
58     bool has_padded_dst() const {
59         memory_desc_wrapper dst_d(&dst_pd_);
60         if (!dst_d.is_blocking_desc()) return false;
61         return this->OC() != dst_d.blocking_desc().padding_dims[1];
62     }
63
64     bool wants_padded_bias() const {
65         if (!this->with_bias()) return false;
66         return has_padded_dst();
67     }
68
69     bool wants_zero_pad_dst(bool jit_impl = true) const {
70         if (!has_padded_dst()) return false;
71         const auto &po = this->attr()->post_ops_;
72         int idx;
73         if ((idx = po.find(primitive_kind::eltwise)) == -1) return false;
74         return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg,
75                 jit_impl);
76     }
77
78 protected:
79     cpu_memory_pd_t src_pd_, dst_pd_;
80     cpu_memory_pd_t weights_pd_, bias_pd_;
81
82     inline memory_format_t src_format()
83     {
84         using namespace memory_format;
85         return utils::pick(this->desc()->src_desc.ndims - 3, ncw, nchw, ncdhw);
86     }
87     inline memory_format_t wei_format()
88     {
89         using namespace memory_format;
90         return this->with_groups()
91             ? utils::pick(this->desc()->src_desc.ndims - 3, goiw, goihw, goidhw)
92             : utils::pick(this->desc()->src_desc.ndims - 3, oiw, oihw, oidhw);
93     }
94
95     virtual status_t set_default_params() {
96         using namespace memory_format;
97         if (src_pd_.desc()->format == any)
98             CHECK(src_pd_.set_format(src_format()));
99         if (dst_pd_.desc()->format == any)
100             CHECK(dst_pd_.set_format(src_pd_.desc()->format));
101         if (weights_pd_.desc()->format == any)
102             CHECK(weights_pd_.set_format(wei_format()));
103         if (bias_pd_.desc()->format == any)
104             CHECK(bias_pd_.set_format(x));
105         if (this->desc()->alg_kind == alg_kind::convolution_auto)
106             CHECK(this->set_alg_kind(alg_kind::convolution_direct));
107         return status::success;
108     }
109 };
110
111 struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t {
112     using cpu_memory_pd_t = cpu_memory_t::pd_t;
113
114     cpu_convolution_bwd_data_pd_t(engine_t *engine,
115             const convolution_desc_t *adesc,
116             const primitive_attr_t *attr,
117             const convolution_fwd_pd_t *hint_fwd_pd)
118         : convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
119         , diff_src_pd_(this->engine_, &this->desc_.diff_src_desc)
120         , diff_dst_pd_(this->engine_, &this->desc_.diff_dst_desc)
121         , weights_pd_(this->engine_, &this->desc_.weights_desc)
122         , bias_pd_(this->engine_, &this->desc_.bias_desc) {}
123     virtual ~cpu_convolution_bwd_data_pd_t() {}
124
125     virtual const cpu_memory_pd_t *diff_src_pd(int index = 0) const override
126     { return index == 0 ? &diff_src_pd_ : nullptr; }
127     virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override
128     { return index == 0 ? &diff_dst_pd_ : nullptr; }
129     virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
130         if (index == 0) return &weights_pd_;
131         if (index == 1 && this->with_bias()) return &bias_pd_;
132         return nullptr;
133     }
134
135 protected:
136     cpu_memory_pd_t diff_src_pd_, diff_dst_pd_;
137     cpu_memory_pd_t weights_pd_, bias_pd_;
138
139     inline memory_format_t src_format()
140     {
141         using namespace memory_format;
142         return utils::pick(this->desc_.diff_src_desc.ndims - 3, ncw, nchw, ncdhw);
143     }
144     inline memory_format_t wei_format()
145     {
146         using namespace memory_format;
147         return this->with_groups()
148             ? utils::pick(this->desc_.diff_src_desc.ndims - 3, goiw, goihw, goidhw)
149             : utils::pick(this->desc_.diff_src_desc.ndims - 3, oiw, oihw, oidhw);
150     }
151
152     virtual status_t set_default_params() {
153         using namespace memory_format;
154         if (diff_src_pd_.desc()->format == any)
155             CHECK(diff_src_pd_.set_format(src_format()));
156         if (diff_dst_pd_.desc()->format == any)
157             CHECK(diff_dst_pd_.set_format(diff_src_pd_.desc()->format));
158         if (weights_pd_.desc()->format == any)
159            CHECK(weights_pd_.set_format(wei_format()));
160         if (bias_pd_.desc()->format == any)
161             CHECK(bias_pd_.set_format(x));
162         if (this->desc()->alg_kind == alg_kind::convolution_auto)
163             CHECK(this->set_alg_kind(alg_kind::convolution_direct));
164         return status::success;
165     }
166 };
167
168 struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t {
169     using cpu_memory_pd_t = cpu_memory_t::pd_t;
170
171     cpu_convolution_bwd_weights_pd_t(engine_t *engine,
172             const convolution_desc_t *adesc,
173             const primitive_attr_t *attr,
174             const convolution_fwd_pd_t *hint_fwd_pd)
175         : convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
176         , src_pd_(this->engine_, &this->desc_.src_desc)
177         , diff_dst_pd_(this->engine_, &this->desc_.diff_dst_desc)
178         , diff_weights_pd_(this->engine_, &this->desc_.diff_weights_desc)
179         , diff_bias_pd_(this->engine_, &this->desc_.diff_bias_desc) {}
180     virtual ~cpu_convolution_bwd_weights_pd_t() {}
181
182     virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
183     { return index == 0 ? &src_pd_ : nullptr; }
184     virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override
185     { return index == 0 ? &diff_dst_pd_ : nullptr; }
186     virtual const cpu_memory_pd_t *diff_weights_pd(int index = 0) const
187         override {
188             if (index == 0) return &diff_weights_pd_;
189             if (index == 1 && this->with_bias()) return &diff_bias_pd_;
190             return  nullptr;
191         }
192
193     bool wants_padded_bias() const {
194         if (!this->with_bias()) return false;
195         memory_desc_wrapper diff_dst_d(&diff_dst_pd_);
196         if (!diff_dst_d.is_blocking_desc()) return false;
197         return OC() != diff_dst_d.blocking_desc().padding_dims[1];
198     }
199
200 protected:
201     cpu_memory_pd_t src_pd_;
202     cpu_memory_pd_t diff_dst_pd_;
203     cpu_memory_pd_t diff_weights_pd_, diff_bias_pd_;
204
205     inline memory_format_t src_format()
206     {
207         using namespace memory_format;
208         return utils::pick(this->desc_.src_desc.ndims - 3, ncw, nchw, ncdhw);
209     }
210     inline memory_format_t wei_format()
211     {
212         using namespace memory_format;
213         return this->with_groups()
214             ? utils::pick(this->desc_.src_desc.ndims - 3, goiw, goihw, goidhw)
215             : utils::pick(this->desc_.src_desc.ndims - 3, oiw, oihw, oidhw);
216     }
217
218     virtual status_t set_default_params() {
219         using namespace memory_format;
220         if (src_pd_.desc()->format == any)
221             CHECK(src_pd_.set_format(src_format()));
222         if (diff_dst_pd_.desc()->format == any)
223             CHECK(diff_dst_pd_.set_format(src_format()));
224         if (diff_weights_pd_.desc()->format == any)
225             CHECK(diff_weights_pd_.set_format(wei_format()));
226         if (diff_bias_pd_.desc()->format == any)
227             CHECK(diff_bias_pd_.set_format(x));
228         if (this->desc()->alg_kind == alg_kind::convolution_auto)
229             CHECK(this->set_alg_kind(alg_kind::convolution_direct));
230         return status::success;
231     }
232 };
233
234 }
235 }
236 }
237
238 #endif
239
240 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s