1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP
18 #define CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22 #include "mkldnn_thread.hpp"
25 #include "cpu_convolution_pd.hpp"
26 #include "cpu_engine.hpp"
27 #include "cpu_reducer.hpp"
29 #include "jit_avx512_common_1x1_conv_kernel.hpp"
30 #include "jit_uni_1x1_conv_utils.hpp"
31 #include "jit_transpose_src_utils.hpp"
// Forward 1x1 convolution primitive for the avx512_common ISA.
// src_type / wei_type / dst_type are the data types this instantiation
// accepts; wei_type and dst_type default to src_type.
// NOTE(review): this listing omits some original source lines (see the gaps
// in the leading line numbers, e.g. the `bool ok` declaration and closing
// braces); the comments below describe only the lines that are present.
37 template <impl::data_type_t src_type,
38 impl::data_type_t wei_type = src_type,
39 impl::data_type_t dst_type = src_type>
40 struct jit_avx512_common_1x1_convolution_fwd_t : public cpu_primitive_t {
41 // TODO: (Roma) Code duplication! Remove with templates
// Primitive descriptor: decides whether this implementation applies, fills
// in default memory formats and computes the JIT kernel configuration (jcp_).
43 struct pd_t: public cpu_convolution_fwd_pd_t {
44 pd_t(engine_t *engine, const convolution_desc_t *adesc,
45 const primitive_attr_t *attr,
46 const typename pd_t::base_class *hint_fwd_pd)
47 : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
51 JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
52 jit_avx512_common_1x1_convolution_fwd_t);
// Asserts a CPU engine, then returns status::unimplemented unless all of
// the following hold:
//   - set_default_params() succeeds;
//   - prop_kind is forward_training (the elided next line presumably adds
//     forward_inference -- confirm);
//   - alg_kind is convolution_auto or convolution_direct;
//   - no memory has a zero dimension;
//   - src/weights/dst data types equal the template parameters;
//   - the bias, when present, uses dst_type.
// Then it prepares the reduce-to-unit-stride (rtus) state, runs the kernel's
// init_conf() with mkldnn_get_max_threads() (a non-success status is
// returned as-is) and registers the kernel and rtus scratchpad space.
54 virtual status_t init() override {
55 using namespace prop_kind;
56 using namespace utils;
57 assert(this->engine()->kind() == engine_kind::cpu);
59 && this->set_default_params() == status::success
60 && utils::one_of(this->desc()->prop_kind, forward_training,
62 && utils::one_of(this->desc()->alg_kind,
63 alg_kind::convolution_auto,
64 alg_kind::convolution_direct)
65 && !this->has_zero_dim_memory()
66 && this->desc()->src_desc.data_type == src_type
67 && this->desc()->weights_desc.data_type == wei_type
68 && this->desc()->dst_desc.data_type == dst_type
69 && IMPLICATION(this->with_bias(),
70 dst_type == this->desc()->bias_desc.data_type);
71 if (!ok) return status::unimplemented;
73 const convolution_desc_t *conv_d = this->desc();
74 const memory_desc_t *src_d = this->src_pd_.desc();
75 rtus_prepare(this, conv_d, src_d, this->dst_pd_.desc());
77 status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
78 jcp_, *conv_d, *src_d, *this->weights_pd_.desc(),
79 *this->dst_pd_.desc(), *this->attr(),
80 mkldnn_get_max_threads(), rtus_.reduce_src_);
81 if (status != status::success) return status;
83 auto scratchpad = scratchpad_registry().registrar();
84 jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
87 rtus_prepare_space_info(this, scratchpad);
89 return status::success;
// JIT kernel configuration produced by init_conf() in init().
92 jit_1x1_conv_conf_t jcp_;
// State filled by rtus_prepare(); rtus_.reduce_src_ is passed to init_conf().
93 reduce_to_unit_stride_t rtus_;
// Replaces every `any` memory format with a concrete one:
//   - src/dst: chosen through pick(ndims() - 3, ...), presumably the w or hw
//     variant by spatial rank (the format arguments are on elided lines);
//   - weights: (g)OI{w,hw}16i16o when all three types are f32, or
//     (g)OI{w,hw}8i16o2i for s16 src/weights with an s32 dst. No branch is
//     visible for other type combinations;
//   - bias: `x`.
// Also resolves convolution_auto to convolution_direct.
96 virtual status_t set_default_params() override {
97 using namespace memory_format;
98 if (this->src_pd_.desc()->format == any)
99 CHECK(this->src_pd_.set_format(pick(this->ndims() - 3,
101 if (this->dst_pd_.desc()->format == any)
102 CHECK(this->dst_pd_.set_format(pick(this->ndims() - 3,
104 if (this->weights_pd_.desc()->format == any) {
105 if (dst_type == data_type::f32 && src_type == data_type::f32
106 && wei_type == data_type::f32)
107 CHECK(this->weights_pd_.set_format(this->with_groups()
108 ? pick(this->ndims() - 3, gOIw16i16o, gOIhw16i16o)
109 : pick(this->ndims() - 3, OIw16i16o, OIhw16i16o)));
110 else if (dst_type == data_type::s32
111 && src_type == data_type::s16
112 && wei_type == data_type::s16)
113 CHECK(this->weights_pd_.set_format(this->with_groups()
114 ? pick(this->ndims() - 3, gOIw8i16o2i, gOIhw8i16o2i)
115 : pick(this->ndims() - 3, OIw8i16o2i, OIhw8i16o2i)));
117 if (this->bias_pd_.desc()->format == any)
118 CHECK(this->bias_pd_.set_format(x));
119 if (this->desc()->alg_kind == alg_kind::convolution_auto)
120 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
121 return status::success;
// Befriends init_rtus_driver, which the constructor invokes for this primitive.
125 template <cpu_isa_t isa, typename conv_t>
126 friend void init_rtus_driver(conv_t *self);
// Starts kernel_ and rtus_driver_ as nullptr, allocates the JIT kernel with
// `new` from pd()->jcp_ and the attributes (the assignment target is on an
// elided line, presumably kernel_), then calls init_rtus_driver<avx512_common>.
// NOTE(review): these are raw owning pointers and no deleted copy operations
// are visible; copying the primitive could double-free. Consider
// std::unique_ptr or explicitly deleting copy construction/assignment.
128 jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd,
129 const input_vector &inputs, const output_vector &outputs)
130 : cpu_primitive_t(apd, inputs, outputs)
131 , kernel_(nullptr), rtus_driver_(nullptr)
134 new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr());
135 init_rtus_driver<avx512_common>(this);
// Destructor body is elided in this listing; presumably it deletes kernel_
// and rtus_driver_ -- confirm.
138 ~jit_avx512_common_1x1_convolution_fwd_t() {
// C++ element types corresponding to the template data types (via prec_traits).
143 typedef typename prec_traits<src_type>::type src_data_t;
144 typedef typename prec_traits<wei_type>::type wei_data_t;
145 typedef typename prec_traits<dst_type>::type dst_data_t;
// Runs the forward pass (the call is on an elided line, presumably
// execute_forward()) and then marks the event ready.
147 virtual void execute(event_t *e) const {
149 e->set_state(event_t::ready);
153 void execute_forward() const;
// Worker taking thread index ithr out of nthr threads, raw tensor pointers
// and the scratchpad grantor; presumably called from execute_forward().
154 void execute_forward_thr(const int ithr, const int nthr,
155 const src_data_t *src, const wei_data_t *weights,
156 const dst_data_t *bias, dst_data_t *dst,
157 const memory_tracking::grantor_t &scratchpad) const;
// Typed access to this primitive's descriptor (C-style cast of primitive_t::pd()).
158 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// JIT kernel and rtus driver; both are nullptr-initialized in the constructor.
160 jit_avx512_common_1x1_conv_kernel *kernel_;
161 rtus_driver_t<avx512_common> *rtus_driver_;
164 using jit_avx512_common_1x1_convolution_fwd_f32_t
165 = jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
166 using jit_avx512_common_1x1_convolution_fwd_s16s16s32_t
167 = jit_avx512_common_1x1_convolution_fwd_t<data_type::s16,
168 data_type::s16, data_type::s32>;
// Backward-data (prop_kind::backward_data) 1x1 convolution primitive for the
// avx512_common ISA. diff_dst_type / wei_type / diff_src_type are the data
// types this instantiation accepts; wei_type and diff_src_type default to
// diff_dst_type.
// NOTE(review): this listing omits some original source lines (see the gaps
// in the leading line numbers); the comments below describe only the lines
// that are present.
170 template <impl::data_type_t diff_dst_type,
171 impl::data_type_t wei_type = diff_dst_type,
172 impl::data_type_t diff_src_type = diff_dst_type>
173 struct jit_avx512_common_1x1_convolution_bwd_data_t : public cpu_primitive_t {
// Primitive descriptor: decides whether this implementation applies, fills
// in default memory formats and computes the JIT kernel configuration (jcp_).
174 struct pd_t : public cpu_convolution_bwd_data_pd_t {
175 pd_t(engine_t *engine,
176 const convolution_desc_t *adesc,
177 const primitive_attr_t *attr,
178 const convolution_fwd_pd_t *hint_fwd_pd)
179 : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
183 JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
184 jit_avx512_common_1x1_convolution_bwd_data_t);
// Asserts a CPU engine, then returns status::unimplemented unless all of
// the following hold:
//   - set_default_params() succeeds;
//   - prop_kind is backward_data;
//   - alg_kind is convolution_auto or convolution_direct;
//   - no memory has a zero dimension;
//   - diff_dst/weights/diff_src data types equal the template parameters.
// The rtus state is prepared on diff_src (the forward variant uses src).
// init_conf() runs with the max thread count and its non-success status is
// returned as-is; kernel and rtus scratchpad space is then registered.
186 virtual status_t init() override {
187 using namespace prop_kind;
188 assert(this->engine()->kind() == engine_kind::cpu);
190 && this->set_default_params() == status::success
191 && this->desc()->prop_kind == backward_data
192 && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
193 alg_kind::convolution_direct)
194 && !this->has_zero_dim_memory()
195 && this->desc()->diff_dst_desc.data_type == diff_dst_type
196 && this->desc()->weights_desc.data_type == wei_type
197 && this->desc()->diff_src_desc.data_type == diff_src_type;
198 if (!ok) return status::unimplemented;
200 const convolution_desc_t *conv_d = this->desc();
201 const memory_desc_t *diff_src_d = this->diff_src_pd_.desc();
202 rtus_prepare(this, conv_d, diff_src_d, this->diff_dst_pd_.desc());
204 status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
205 jcp_, *conv_d, *diff_src_d, *this->weights_pd_.desc(),
206 *this->diff_dst_pd_.desc(), *this->attr(),
207 mkldnn_get_max_threads(), rtus_.reduce_src_);
208 if (status != status::success) return status;
210 auto scratchpad = scratchpad_registry().registrar();
211 jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
214 rtus_prepare_space_info(this, scratchpad);
216 return status::success;
219 // TODO (Roma): structs conf header cleanup
// JIT kernel configuration produced by init_conf() in init().
220 jit_1x1_conv_conf_t jcp_;
// State filled by rtus_prepare(); rtus_.reduce_src_ is passed to init_conf().
221 reduce_to_unit_stride_t rtus_;
// Replaces every `any` memory format with a concrete one:
//   - diff_src/diff_dst: chosen through pick(ndims() - 3, ...) (the format
//     arguments are on elided lines);
//   - weights: the transposed (g)IO{w,hw}16o16i layout when all types are
//     f32, or (g)OI{w,hw}8o16i2o for s16 diff_dst/weights with an s32
//     diff_src. No branch is visible for other type combinations.
// No bias format is set here. convolution_auto is resolved to
// convolution_direct.
224 virtual status_t set_default_params() override {
225 using namespace memory_format;
227 if (this->diff_src_pd_.desc()->format == any)
228 CHECK(this->diff_src_pd_.set_format(pick(this->ndims() - 3,
230 if (this->diff_dst_pd_.desc()->format == any)
231 CHECK(this->diff_dst_pd_.set_format(pick(this->ndims() - 3,
233 if (this->weights_pd_.desc()->format == any) {
234 if (diff_dst_type == data_type::f32
235 && diff_src_type == data_type::f32
236 && wei_type == data_type::f32) {
237 CHECK(this->weights_pd_.set_format(this->with_groups()
238 ? pick(this->ndims() - 3, gIOw16o16i, gIOhw16o16i)
239 : pick(this->ndims() - 3, IOw16o16i, IOhw16o16i)));
241 else if (diff_dst_type == data_type::s16
242 && diff_src_type == data_type::s32
243 && wei_type == data_type::s16)
244 CHECK(this->weights_pd_.set_format(this->with_groups()
245 ? pick(this->ndims() - 3, gOIw8o16i2o, gOIhw8o16i2o)
246 : pick(this->ndims() - 3, OIw8o16i2o, OIhw8o16i2o)));
248 if (this->desc()->alg_kind == alg_kind::convolution_auto)
249 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
251 return status::success;
// Befriends init_rtus_driver, which the constructor invokes for this primitive.
255 template <cpu_isa_t isa, typename conv_t>
256 friend void init_rtus_driver(conv_t *self);
// Starts kernel_ and rtus_driver_ as nullptr, allocates kernel_ with `new`
// from pd()->jcp_ (remaining arguments on an elided line), then calls
// init_rtus_driver<avx512_common>.
// NOTE(review): these are raw owning pointers and no deleted copy operations
// are visible; copying the primitive could double-free. Consider
// std::unique_ptr or explicitly deleting copy construction/assignment.
258 jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd,
259 const input_vector &inputs, const output_vector &outputs)
260 : cpu_primitive_t(apd, inputs, outputs)
261 , kernel_(nullptr), rtus_driver_(nullptr)
263 kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_,
265 init_rtus_driver<avx512_common>(this);
// Destructor body is elided in this listing; presumably it deletes kernel_
// and rtus_driver_ -- confirm.
268 ~jit_avx512_common_1x1_convolution_bwd_data_t() {
// C++ element types corresponding to the template data types (via prec_traits).
273 typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
274 typedef typename prec_traits<wei_type>::type wei_data_t;
275 typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
// For prop_kind::backward_data, runs execute_backward_data(). Any other
// prop_kind hits the "invalid prop_kind" assert (the break/default labels
// are on elided lines). The event is then marked ready.
277 virtual void execute(event_t *e) const {
278 switch (pd()->desc()->prop_kind) {
279 case prop_kind::backward_data:
280 execute_backward_data();
283 assert(!"invalid prop_kind");
285 e->set_state(event_t::ready);
289 void execute_backward_data() const;
// Typed access to this primitive's descriptor (C-style cast of primitive_t::pd()).
290 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// JIT kernel and rtus driver; both are nullptr-initialized in the constructor.
292 jit_avx512_common_1x1_conv_kernel *kernel_;
293 rtus_driver_t<avx512_common> *rtus_driver_;
296 using jit_avx512_common_1x1_convolution_bwd_data_f32_t
297 = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
298 using jit_avx512_common_1x1_convolution_bwd_data_s16s16s32_t
299 = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::s16,
300 data_type::s16, data_type::s32>;
// Backward-weights (prop_kind::backward_weights) 1x1 convolution primitive
// for the avx512_common ISA. f32 only: src, diff_weights, diff_dst and,
// when present, diff_bias must all be f32.
// NOTE(review): this listing omits some original source lines (see the gaps
// in the leading line numbers) and ends before the struct closes; the
// comments below describe only the lines that are present.
302 struct jit_avx512_common_1x1_convolution_bwd_weights_t : public cpu_primitive_t
// Primitive descriptor: decides whether this implementation applies, fills
// in default memory formats, computes jcp_ and the bias-reducer config.
304 struct pd_t : public cpu_convolution_bwd_weights_pd_t {
305 pd_t(engine_t *engine,
306 const convolution_desc_t *adesc,
307 const primitive_attr_t *attr,
308 const convolution_fwd_pd_t *hint_fwd_pd)
309 : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
313 JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
314 jit_avx512_common_1x1_convolution_bwd_weights_t);
// Asserts a CPU engine, then returns status::unimplemented unless all of
// the following hold:
//   - set_default_params() succeeds;
//   - prop_kind is backward_weights;
//   - alg_kind is convolution_auto or convolution_direct;
//   - no memory has a zero dimension;
//   - src, diff_weights and diff_dst are all f32;
//   - diff_bias, when present, is f32.
// The rtus state is prepared on src, and init_conf() receives the
// diff_weights descriptor. Scratchpad is registered for the kernel, for the
// bias reducer (under the prefix_reducer_bia prefix) and for rtus.
// NOTE(review): reducer_bia_conf_ is used by init_scratchpad() below;
// confirm it is initialized first (e.g. by init_balancers()), since the
// intervening source lines are elided.
316 virtual status_t init() override {
317 using namespace prop_kind;
318 assert(this->engine()->kind() == engine_kind::cpu);
320 && this->set_default_params() == status::success
321 && this->desc()->prop_kind == backward_weights
322 && utils::one_of(this->desc()->alg_kind,
323 alg_kind::convolution_auto,
324 alg_kind::convolution_direct)
325 && !this->has_zero_dim_memory()
326 && utils::everyone_is(data_type::f32,
327 this->desc()->src_desc.data_type,
328 this->desc()->diff_weights_desc.data_type,
329 this->desc()->diff_dst_desc.data_type)
330 && IMPLICATION(this->with_bias(),
331 data_type::f32 == desc()->diff_bias_desc.data_type);
332 if (!ok) return status::unimplemented;
334 const convolution_desc_t *conv_d = this->desc();
335 const memory_desc_t *src_d = this->src_pd_.desc();
336 rtus_prepare(this, conv_d, src_d, this->diff_dst_pd_.desc());
338 status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
339 jcp_, *conv_d, *src_d, *this->diff_weights_pd_.desc(),
340 *this->diff_dst_pd_.desc(), *this->attr(),
341 mkldnn_get_max_threads(), rtus_.reduce_src_);
342 if (status != status::success) return status;
346 auto scratchpad = scratchpad_registry().registrar();
347 jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
350 auto reducer_bia_scratchpad = memory_tracking::registrar_t(
351 scratchpad, memory_tracking::names::prefix_reducer_bia);
352 reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
354 rtus_prepare_space_info(this, scratchpad);
356 return status::success;
359 // TODO (Roma): structs conf header cleanup
// JIT kernel configuration produced by init_conf() in init().
360 jit_1x1_conv_conf_t jcp_;
// Configuration of the f32 bias reducer, set in init_balancers().
361 cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
// State filled by rtus_prepare(); rtus_.reduce_src_ is passed to init_conf().
362 reduce_to_unit_stride_t rtus_;
// Replaces every `any` memory format with a concrete one:
//   - src/diff_dst: chosen through pick(ndims() - 3, ...) (the format
//     arguments are on elided lines);
//   - diff_weights: (g)OI{w,hw}16i16o;
//   - diff_bias: `x`.
// Also resolves convolution_auto to convolution_direct.
365 virtual status_t set_default_params() override {
366 using namespace memory_format;
368 if (this->src_pd_.desc()->format == any)
369 CHECK(this->src_pd_.set_format(pick(this->ndims() - 3,
371 if (this->diff_dst_pd_.desc()->format == any)
372 CHECK(this->diff_dst_pd_.set_format(pick(this->ndims() - 3,
374 if (this->diff_weights_pd_.desc()->format == any)
375 CHECK(this->diff_weights_pd_.set_format(this->with_groups()
376 ? pick(this->ndims() - 3, gOIw16i16o, gOIhw16i16o)
377 : pick(this->ndims() - 3, OIw16i16o, OIhw16i16o)));
378 if (this->diff_bias_pd_.desc()->format == any)
379 CHECK(this->diff_bias_pd_.set_format(x));
380 if (this->desc()->alg_kind == alg_kind::convolution_auto)
381 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
382 return status::success;
// Initializes reducer_bia_conf_ from a reduce_balancer_t built with
// jcp_.nthr, jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, jcp_.mb and
// max_buffer_size.
// NOTE(review): the meaning of each reduce_balancer_t argument and of the
// 3 * 5 * 5 * 16 * 16 factor is not visible here -- see cpu_reducer.hpp.
// The product is evaluated in the type of jcp_.nthr (presumably int) before
// it is widened to size_t.
386 void init_balancers() {
387 const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
389 reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
390 jcp_.oc_block, jcp_.ngroups * jcp_.nb_load,
391 jcp_.mb, max_buffer_size));
// Befriends init_rtus_driver, presumably invoked by the out-of-line constructor.
396 template <cpu_isa_t isa, typename conv_t>
397 friend void init_rtus_driver(conv_t *self);
// Constructor is only declared here; its definition is out of line.
399 jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd,
400 const input_vector &inputs, const output_vector &outputs);
// Deletes reducer_bias_ and trans_kernel_. Other deletions (presumably
// kernel_, acc_ker_ and rtus_driver_) are on elided lines -- confirm.
// NOTE(review): raw owning pointers with no visibly deleted copy operations;
// consider std::unique_ptr members.
402 ~jit_avx512_common_1x1_convolution_bwd_weights_t() {
405 delete reducer_bias_;
407 delete trans_kernel_;
// Element type for f32 (via prec_traits).
410 typedef typename prec_traits<data_type::f32>::type data_t;
// For prop_kind::backward_weights, runs execute_backward_weights(). Any
// other prop_kind hits the "invalid prop_kind" assert (the break/default
// labels are on elided lines). The event is then marked ready.
412 virtual void execute(event_t *e) const {
413 switch (pd()->desc()->prop_kind) {
414 case prop_kind::backward_weights:
415 execute_backward_weights();
418 assert(!"invalid prop_kind");
420 e->set_state(event_t::ready);
424 void execute_backward_weights() const;
// Typed access to this primitive's descriptor (C-style cast of primitive_t::pd()).
425 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Helper objects held by raw pointer; reducer_bias_ and trans_kernel_ are
// visibly deleted in the destructor.
427 jit_avx512_common_1x1_conv_kernel *kernel_;
428 cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
429 cpu_reducer_t<data_type::f32> *reducer_bias_;
430 jit_transpose4x16_src *trans_kernel_;
431 rtus_driver_t<avx512_common> *rtus_driver_;