1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_JIT_AVX2_CONVOLUTION_HPP
18 #define CPU_JIT_AVX2_CONVOLUTION_HPP
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22 #include "mkldnn_thread.hpp"
25 #include "cpu_convolution_pd.hpp"
26 #include "cpu_reducer.hpp"
28 #include "jit_avx2_conv_kernel_f32.hpp"
29 #include "jit_uni_depthwise.hpp"
// JIT forward convolution primitive for AVX2, f32 data, with an optional
// fused depthwise-convolution tail (enabled via jcp_.with_dw_conv and run
// by a jit_uni_dw_conv_row_f32<avx2> kernel).
// NOTE(review): the line numbers embedded at the start of each line are
// non-contiguous -- this excerpt is missing interleaved source lines
// (e.g. the `bool ok = true` initializer before line 52, the tail of the
// init_conf() call after line 71, closing braces). Comments describe only
// what is visible here; confirm against the full file before editing.
35 struct jit_avx2_convolution_fwd_t: public cpu_primitive_t {
// Primitive descriptor: validates applicability of this implementation,
// fills the kernel configs (jcp_, jcp_dw_) and registers scratchpad needs.
36 struct pd_t: public cpu_convolution_fwd_pd_t {
37 pd_t(engine_t *engine,
38 const convolution_desc_t *adesc,
39 const primitive_attr_t *attr,
40 const typename pd_t::base_class *hint_fwd_pd)
41 : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
42 , jcp_(), jcp_dw_() {}
// Arguments of the DECLARE_COMMON_PD_T macro (macro head line is
// missing from this excerpt): implementation name and primitive type.
45 JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
46 jit_avx2_convolution_fwd_t);
// Applicability: forward propagation, direct/auto convolution, all
// tensors f32 (bias too, when present), no zero-dimension memory.
// On success, init_conf() fills jcp_ (and jcp_dw_ for a fused
// depthwise conv) and the scratchpad is booked.
48 virtual status_t init() override {
49 using namespace prop_kind;
50 assert(this->engine()->kind() == engine_kind::cpu);
52 && this->set_default_params() == status::success
53 && utils::one_of(this->desc()->prop_kind, forward_training,
55 && utils::one_of(this->desc()->alg_kind,
56 alg_kind::convolution_auto,
57 alg_kind::convolution_direct)
58 && !this->has_zero_dim_memory()
59 && utils::everyone_is(data_type::f32,
60 this->desc()->src_desc.data_type,
61 this->desc()->weights_desc.data_type,
62 this->desc()->dst_desc.data_type)
63 && IMPLICATION(this->with_bias(),
64 data_type::f32 == this->desc()->bias_desc.data_type);
65 if (!ok) return status::unimplemented;
// Derive the main conv kernel configuration from descriptor +
// chosen memory formats (call tail is missing in this excerpt).
69 status_t sts = jit_avx2_conv_fwd_kernel_f32::init_conf(jcp_,
70 *this->desc(), *this->src_pd_.desc(),
71 *this->weights_pd_.desc(), *this->dst_pd_.desc(),
73 if (sts != status::success) return sts;
// Configure the fused depthwise row kernel when requested.
75 if (jcp_.with_dw_conv) {
76 status_t sts_dw = jit_uni_dw_conv_row_f32<avx2>::init_conf(jcp_, jcp_dw_, *this->attr());
77 if (sts_dw != status::success) return sts_dw;
80 auto scratchpad = scratchpad_registry().registrar();
81 jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_, jcp_dw_);
83 return status::success;
// Kernel configs. The jcp_ member declaration (original line 86)
// is missing from this excerpt; jcp_dw_ holds the depthwise config.
87 jit_conv_conf_t jcp_dw_;
// Pick concrete memory formats where the user left `any`:
// plain ncw/nchw/ncdhw src for "flat" input (IC < simd_w),
// otherwise 8c-blocked layouts; weights blocked 8i8o (grouped or
// not, flat variants Owi8o/gOwi8o...); x-format bias. Also
// resolves convolution_auto to convolution_direct.
90 virtual status_t set_default_params() override {
91 using namespace memory_format;
94 const bool flat = this->IC() < simd_w;
95 if (this->src_pd_.desc()->format == any)
96 CHECK(this->src_pd_.set_format(flat
97 ? utils::pick(this->ndims() - 3, ncw, nchw, ncdhw)
98 : utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
99 if (this->dst_pd_.desc()->format == any)
100 CHECK(this->dst_pd_.set_format(
101 utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
102 if (this->weights_pd_.desc()->format == any)
103 CHECK(this->weights_pd_.set_format(this->with_groups()
104 ? utils::pick(2 * this->ndims() - 6 + flat, gOIw8i8o,
105 gOwi8o, gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
106 : utils::pick(2 * this->ndims() - 6 + flat, OIw8i8o, Owi8o,
107 OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o)));
109 if (this->bias_pd_.desc()->format == any)
110 CHECK(this->bias_pd_.set_format(x));
111 if (this->desc()->alg_kind == alg_kind::convolution_auto)
112 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
113 return status::success;
// Builds the JIT kernels from the configs prepared by pd_t::init().
// kernel_dw_ is created only for the fused-depthwise path; the dtor
// body (lines after 131) that releases the kernels is truncated here.
117 jit_avx2_convolution_fwd_t(const pd_t *apd, const input_vector &inputs,
118 const output_vector &outputs)
119 : cpu_primitive_t(apd, inputs, outputs)
121 kernel_ = new jit_avx2_conv_fwd_kernel_f32(pd()->jcp_, pd()->jcp_dw_, *pd()->attr());
123 if (pd()->jcp_.with_dw_conv) {
124 kernel_dw_ = new jit_uni_dw_conv_row_f32<avx2>(pd()->jcp_dw_, *pd()->attr(), pd()->jcp_dw_.ch_block);
128 ~jit_avx2_convolution_fwd_t() {
131 if (pd()->jcp_.with_dw_conv) {
136 typedef typename prec_traits<data_type::f32>::type data_t;
// Dispatch: depthwise-fused path vs. plain forward (the plain-path
// call, original lines 141-143, is missing from this excerpt).
138 virtual void execute(event_t *e) const {
139 if (pd()->jcp_.with_dw_conv)
140 execute_forward_with_dw_conv();
144 e->set_state(event_t::ready);
148 void execute_forward() const;
149 void execute_forward_with_dw_conv() const;
150 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Raw owning pointers (pre-C++11 codebase style); created in the
// ctor, presumably deleted in the truncated dtor -- verify.
152 jit_avx2_conv_fwd_kernel_f32 *kernel_;
153 jit_uni_dw_conv_row_f32<avx2> *kernel_dw_;
// JIT backward-data convolution for AVX2, f32: computes diff_src from
// diff_dst and weights. No depthwise fusion on this path.
// NOTE(review): embedded line numbers are non-contiguous -- several
// source lines (e.g. `bool ok = true` before line 174, the tail of the
// init_scratchpad() call after line 191, closing braces) are missing
// from this excerpt.
156 struct jit_avx2_convolution_bwd_data_t: public cpu_primitive_t {
// Primitive descriptor: applicability checks + jcp_ configuration.
157 struct pd_t: public cpu_convolution_bwd_data_pd_t {
158 pd_t(engine_t *engine,
159 const convolution_desc_t *adesc,
160 const primitive_attr_t *attr,
161 const convolution_fwd_pd_t *hint_fwd_pd)
162 : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
// DECLARE_COMMON_PD_T arguments (macro head line missing here).
167 JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
168 jit_avx2_convolution_bwd_data_t);
// Applicability: backward_data only, direct/auto convolution,
// all-f32 diff_src / weights / diff_dst, no zero-dim memory.
170 virtual status_t init() override {
171 using namespace prop_kind;
172 assert(this->engine()->kind() == engine_kind::cpu);
174 && this->set_default_params() == status::success
175 && utils::one_of(this->desc()->prop_kind, backward_data)
176 && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
177 alg_kind::convolution_direct)
178 && !this->has_zero_dim_memory()
179 && utils::everyone_is(data_type::f32,
180 this->desc()->diff_src_desc.data_type,
181 this->desc()->weights_desc.data_type,
182 this->desc()->diff_dst_desc.data_type);
183 if (!ok) return status::unimplemented;
185 status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf(
186 jcp_, *this->desc(), *this->diff_src_pd_.desc(),
187 *this->weights_pd_.desc(), *this->diff_dst_pd_.desc());
188 if (status != status::success) return status;
190 auto scratchpad = scratchpad_registry().registrar();
191 jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(scratchpad,
194 return status::success;
// Kernel configuration filled by init().
197 jit_conv_conf_t jcp_;
// Default formats where `any`: 8c-blocked activations; weights in
// 8o8i blocking (the 3D/trailing format arguments on lines 212/214
// are missing from this excerpt). Resolves auto -> direct.
200 virtual status_t set_default_params() override {
201 using namespace memory_format;
203 if (this->diff_src_pd_.desc()->format == any)
204 CHECK(this->diff_src_pd_.set_format(
205 utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
206 if (this->diff_dst_pd_.desc()->format == any)
207 CHECK(this->diff_dst_pd_.set_format(
208 utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
209 if (this->weights_pd_.desc()->format == any)
210 CHECK(this->weights_pd_.set_format(this->with_groups()
211 ? utils::pick(this->ndims() - 3, gOIw8o8i, gOIhw8o8i,
213 : utils::pick(this->ndims() - 3, OIw8o8i, OIhw8o8i,
215 if (this->desc()->alg_kind == alg_kind::convolution_auto)
216 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
217 return status::success;
// Owns a single JIT kernel; created from pd()->jcp_, freed in dtor.
221 jit_avx2_convolution_bwd_data_t(const pd_t *apd, const input_vector &inputs,
222 const output_vector &outputs)
223 : cpu_primitive_t(apd, inputs, outputs)
224 { kernel_ = new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_); }
225 ~jit_avx2_convolution_bwd_data_t() { delete kernel_; }
227 typedef typename prec_traits<data_type::f32>::type data_t;
// Only backward_data is valid; anything else asserts.
229 virtual void execute(event_t *e) const {
230 switch (pd()->desc()->prop_kind) {
231 case prop_kind::backward_data:
232 execute_backward_data();
235 assert(!"invalid prop_kind");
237 e->set_state(event_t::ready);
241 void execute_backward_data() const;
242 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
244 jit_avx2_conv_bwd_data_kernel_f32 *kernel_;
// JIT backward-weights convolution for AVX2, f32: computes diff_weights
// (and diff_bias) from src and diff_dst, using cpu_reducer_t instances
// to combine per-thread partial results.
// NOTE(review): embedded line numbers are non-contiguous -- lines such
// as `bool ok = true` before line 262, the init_balancers() call site,
// the tail of init_scratchpad() after line 282, and several closing
// braces are missing from this excerpt.
247 struct jit_avx2_convolution_bwd_weights_t: public cpu_primitive_t {
// Primitive descriptor: applicability checks, jcp_ config, and
// reducer (bias/weights) configuration + scratchpad registration.
248 struct pd_t: public cpu_convolution_bwd_weights_pd_t {
249 pd_t(engine_t *engine, const convolution_desc_t *adesc,
250 const primitive_attr_t *attr,
251 const convolution_fwd_pd_t *hint_fwd_pd)
252 : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
// DECLARE_COMMON_PD_T arguments (macro head line missing here).
256 JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
257 jit_avx2_convolution_bwd_weights_t);
// Applicability: backward_weights only, direct/auto convolution,
// all-f32 src / diff_dst / diff_weights, no zero-dim memory.
259 virtual status_t init() override {
260 assert(this->engine()->kind() == engine_kind::cpu);
262 && this->set_default_params() == status::success
263 && this->desc()->prop_kind == prop_kind::backward_weights
264 && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
265 alg_kind::convolution_direct)
266 && !this->has_zero_dim_memory()
267 && utils::everyone_is(data_type::f32,
268 this->desc()->src_desc.data_type,
269 this->desc()->diff_dst_desc.data_type,
270 this->desc()->diff_weights_desc.data_type);
271 if (!ok) return status::unimplemented;
273 status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf(
274 jcp_, *this->desc(), *this->src_pd_.desc(),
275 *this->diff_weights_pd_.desc(),
276 *this->diff_dst_pd_.desc());
277 if (status != status::success) return status;
// Book scratchpad for the kernel and for both reducers (their
// balancer init, presumably via init_balancers(), happens on a
// line missing from this excerpt -- confirm in the full file).
281 auto scratchpad = scratchpad_registry().registrar();
282 jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(scratchpad,
285 auto reducer_bia_scratchpad = memory_tracking::registrar_t(
286 scratchpad, memory_tracking::names::prefix_reducer_bia);
287 reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
289 auto reducer_wei_scratchpad = memory_tracking::registrar_t(
290 scratchpad, memory_tracking::names::prefix_reducer_wei);
291 reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);
293 return status::success;
// Kernel config plus reducer configurations for the bias and
// weights partial-sum reductions.
296 jit_conv_conf_t jcp_;
297 cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
298 cpu_reducer_t<data_type::f32>::conf_t reducer_wei_conf_;
// Default formats where `any`. Note the "flat" predicate here is
// IC() == 3, unlike the forward pd which uses IC() < simd_w.
301 virtual status_t set_default_params() override {
302 using namespace memory_format;
303 const bool flat = this->IC() == 3;
305 if (this->src_pd_.desc()->format == any)
306 CHECK(this->src_pd_.set_format(flat
307 ? utils::pick(this->ndims() - 3, ncw, nchw, ncdhw)
308 : utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
309 if (this->diff_dst_pd_.desc()->format == any)
310 CHECK(this->diff_dst_pd_.set_format(
311 utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
312 if (this->diff_weights_pd_.desc()->format == any)
313 CHECK(this->diff_weights_pd_.set_format(this->with_groups()
314 ? utils::pick(2 * this->ndims() - 6 + flat, gOIw8i8o,
315 gOwi8o, gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
316 : utils::pick(2 * this->ndims() - 6 + flat, OIw8i8o, Owi8o,
317 OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o)));
318 if (this->diff_bias_pd_.desc()->format == any)
319 CHECK(this->diff_bias_pd_.set_format(x));
320 if (this->desc()->alg_kind == alg_kind::convolution_auto)
321 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
322 return status::success;
// Configure the thread-balancing of both reductions; job sizes
// are derived from jcp_ blocking. max_buffer_size caps per-thread
// scratch at 2 MiB (stated heuristic).
326 void init_balancers() {
327 const int max_threads = mkldnn_get_max_threads();
328 const size_t max_buffer_size = 1<<21; /* just a heuristic */
331 reducer_bia_conf_.init(reduce_balancer_t(max_threads,
332 jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
336 reducer_wei_conf_.init(reduce_balancer_t(max_threads,
337 jcp_.kd * jcp_.kh * jcp_.kw
338 * jcp_.ic_block * jcp_.oc_block,
339 jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc,
340 jcp_.mb * jcp_.od, max_buffer_size));
// Builds the kernel and (on missing lines 350/352, to judge by the
// surviving `new` expressions) the two reducers; dtor frees them
// (the `delete kernel_;` line is missing from this excerpt).
344 jit_avx2_convolution_bwd_weights_t(const pd_t *apd,
345 const input_vector &inputs, const output_vector &outputs)
346 : cpu_primitive_t(apd, inputs, outputs)
347 , kernel_(nullptr), reducer_weights_(nullptr), reducer_bias_(nullptr)
349 kernel_ = new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_);
351 new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
353 new cpu_reducer_t<data_type::f32>(pd()->reducer_wei_conf_);
356 ~jit_avx2_convolution_bwd_weights_t() {
358 delete reducer_weights_;
359 delete reducer_bias_;
362 typedef typename prec_traits<data_type::f32>::type data_t;
364 virtual void execute(event_t *e) const {
365 execute_backward_weights();
366 e->set_state(event_t::ready);
370 void execute_backward_weights() const;
371 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Raw owning pointers, released in the destructor above.
373 jit_avx2_conv_bwd_weights_kernel_f32 *kernel_;
374 cpu_reducer_t<data_type::f32> *reducer_weights_, *reducer_bias_;
383 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s