/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #ifndef CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP
18 #define CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP
20 #include "c_types_map.hpp"
21 #include "cpu_convolution_pd.hpp"
22 #include "cpu_engine.hpp"
23 #include "cpu_reducer.hpp"
24 #include "jit_avx512_common_1x1_conv_kernel.hpp"
25 #include "jit_uni_1x1_conv_utils.hpp"
26 #include "jit_transpose_src_utils.hpp"
27 #include "mkldnn_thread.hpp"
/* Forward 1x1 convolution JIT-ted for avx512_common.  The template covers a
 * plain and a ReLU-fused variant (`with_relu`) over two data-type
 * configurations: f32 everywhere, and s16 src/wei with s32 dst (see the
 * aliases that follow this struct).
 * NOTE(review): this listing is an elided excerpt — interior original lines
 * (e.g. the `bool ok = true` opener of the validity check, macro
 * DECLARE_COMMON_PD_T context around JIT_IMPL_NAME_HELPER, and closing
 * braces) are not visible here; comments describe only what is shown. */
34 template <bool with_relu, impl::data_type_t src_type,
35 impl::data_type_t wei_type = src_type,
36 impl::data_type_t dst_type = src_type>
37 struct _jit_avx512_common_1x1_convolution_fwd_t : public cpu_primitive_t {
38 // TODO: (Roma) Code duplication duplication! Remove with templates
/* Primitive descriptor: validates the convolution descriptor + attributes
 * and precomputes the JIT configuration jcp_. */
40 struct pd_t: public _cpu_convolution_fwd_pd_t<with_relu> {
41 pd_t(engine_t *engine,
42 const typename pd_t::base_desc_t *adesc,
43 const primitive_attr_t *attr,
44 const typename pd_t::base_class *hint_fwd_pd)
45 : _cpu_convolution_fwd_pd_t<with_relu>(engine, adesc, attr,
50 JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
51 _jit_avx512_common_1x1_convolution_fwd_t);
/* Accepts only: forward propagation, direct algorithm, data types matching
 * the template parameters, bias type equal to dst type, and (for the
 * fused-ReLU s16s16s32 case) a zero negative slope.  On success it prepares
 * the reduce-to-unit-stride (rtus) helper and fills jcp_ via the kernel's
 * init_conf. */
53 virtual status_t init() override {
54 using namespace prop_kind;
55 using namespace utils;
56 assert(this->engine()->kind() == engine_kind::cpu);
58 && this->set_default_params() == status::success
59 && utils::one_of(this->cdesc_().prop_kind, forward_training,
61 && this->cdesc_().alg_kind == alg_kind::convolution_direct
62 && this->cdesc_().src_desc.data_type == src_type
63 && this->cdesc_().weights_desc.data_type == wei_type
64 && this->cdesc_().dst_desc.data_type == dst_type
65 && implication(this->with_bias(),
66 dst_type == this->cdesc_().bias_desc.data_type)
67 && implication(with_relu && dst_type == data_type::s32
68 && everyone_is(data_type::s16, src_type, wei_type),
69 this->negative_slope() == 0.)
70 if (!ok) return status::unimplemented;
/* Strided 1x1 convolutions are handled by first gathering the strided
 * source into a dense buffer (rtus_prepare decides whether that copy,
 * rtus_.reduce_src_, is needed). */
72 const convolution_desc_t *conv_d = &this->cdesc_();
73 const memory_desc_t *src_d = this->src_pd_.desc();
74 rtus_prepare(this, conv_d, src_d, this->dst_pd_.desc());
75 return jit_avx512_common_1x1_conv_kernel::init_conf(jcp_,
76 *conv_d, *src_d, *this->weights_pd_.desc(),
77 *this->dst_pd_.desc(), *this->attr(),
78 with_relu, this->negative_slope(),
79 omp_get_max_threads(), rtus_.reduce_src_);
/* JIT kernel configuration computed in init(). */
82 jit_1x1_conv_conf_t jcp_;
83 struct reduce_to_unit_stride_t {
84 convolution_desc_t conv_d_;
/* Chooses blocked 16-channel layouts when the user left formats as `any`:
 * nChw16c activations; weights gOIhw16i16o/OIhw16i16o for f32 or the
 * s16-interleaved gOIhw8i16o2i/OIhw8i16o2i for the s16s16s32 path. */
89 virtual status_t set_default_params() override {
90 using namespace memory_format;
91 if (this->src_pd_.desc()->format == any)
92 CHECK(this->src_pd_.set_format(nChw16c));
93 if (this->dst_pd_.desc()->format == any)
94 CHECK(this->dst_pd_.set_format(nChw16c));
95 if (this->weights_pd_.desc()->format == any) {
96 if (dst_type == data_type::f32 && src_type == data_type::f32
97 && wei_type == data_type::f32)
98 CHECK(this->weights_pd_.set_format(this->with_groups()
99 ? gOIhw16i16o : OIhw16i16o));
100 else if (dst_type == data_type::s32
101 && src_type == data_type::s16
102 && wei_type == data_type::s16)
103 CHECK(this->weights_pd_.set_format(this->with_groups()
104 ? gOIhw8i16o2i : OIhw8i16o2i));
106 if (this->bias_pd_.desc()->format == any)
107 CHECK(this->bias_pd_.set_format(x));
108 return status::success;
/* init_rtus_driver needs access to private members (scratch_, rtus_driver_,
 * ws_per_thread_), hence the friend declaration. */
112 template <cpu_isa_t isa, typename conv_t>
113 friend void init_rtus_driver(conv_t *self);
/* Builds the JIT kernel and the rtus driver; when the pd asks for a padded
 * bias, allocates an oc-sized copy whose padded tail is zero-filled. */
114 _jit_avx512_common_1x1_convolution_fwd_t(const pd_t *pd,
115 const input_vector &inputs,
116 const output_vector &outputs)
117 : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
118 , kernel_(nullptr), rtus_driver_(nullptr), ws_per_thread_(0)
119 , scratch_(nullptr), padded_bias_(nullptr)
121 kernel_ = new jit_avx512_common_1x1_conv_kernel(conf_.jcp_,
124 init_rtus_driver<avx512_common>(this);
126 if (conf_.want_padded_bias()) {
127 const auto &j = conf_.jcp_;
128 assert(j.ngroups == 1);
/* NOTE(review): two-argument malloc — presumably the project's aligned
 * allocator (size, alignment = 64 bytes), not ::malloc; confirm. */
129 padded_bias_ = (dst_data_t *)malloc(sizeof(dst_data_t) * j.oc, 64);
/* Zero only the padding tail; real bias values are copied elsewhere. */
130 for (int oc = j.oc_without_padding; oc < j.oc; ++oc)
131 padded_bias_[oc] = 0;
/* Destructor body elided in this excerpt. */
135 ~_jit_avx512_common_1x1_convolution_fwd_t() {
142 typedef typename prec_traits<src_type>::type src_data_t;
143 typedef typename prec_traits<wei_type>::type wei_data_t;
144 typedef typename prec_traits<dst_type>::type dst_data_t;
/* Runs the forward pass (call to execute_forward() elided here) and marks
 * the event ready. */
146 virtual void execute(event_t *e) {
148 e->set_state(event_t::ready);
152 void execute_forward();
154 jit_avx512_common_1x1_conv_kernel *kernel_;
155 /* reduction to unit stride */
156 rtus_driver_t<avx512_common> *rtus_driver_;
/* Per-thread scratch size used by the rtus copy — TODO confirm units. */
157 size_t ws_per_thread_;
158 src_data_t *scratch_;
159 dst_data_t *padded_bias_;
162 using jit_avx512_common_1x1_convolution_fwd_f32_t
163 = _jit_avx512_common_1x1_convolution_fwd_t<false, data_type::f32>;
164 using jit_avx512_common_1x1_convolution_relu_f32_t
165 = _jit_avx512_common_1x1_convolution_fwd_t<true, data_type::f32>;
166 using jit_avx512_common_1x1_convolution_fwd_s16s16s32_t
167 = _jit_avx512_common_1x1_convolution_fwd_t<false, data_type::s16,
168 data_type::s16, data_type::s32>;
169 using jit_avx512_common_1x1_convolution_relu_s16s16s32_t
170 = _jit_avx512_common_1x1_convolution_fwd_t<true, data_type::s16,
171 data_type::s16, data_type::s32>;
/* Backward-data 1x1 convolution JIT-ted for avx512_common: computes
 * diff_src from diff_dst and weights.  Supports f32 everywhere and
 * s16 diff_dst/wei with s32 diff_src (see the aliases below).
 * NOTE(review): elided excerpt — the `bool ok = true` opener, closing
 * braces, and some constructor/destructor lines are not visible. */
173 template <impl::data_type_t diff_dst_type,
174 impl::data_type_t wei_type = diff_dst_type,
175 impl::data_type_t diff_src_type = diff_dst_type>
176 struct _jit_avx512_common_1x1_convolution_bwd_data_t : public cpu_primitive_t {
/* Primitive descriptor: validates the descriptor and fills jcp_. */
177 struct pd_t : public cpu_convolution_bwd_data_pd_t {
178 pd_t(engine_t *engine,
179 const convolution_desc_t *adesc,
180 const primitive_attr_t *attr,
181 const convolution_fwd_pd_t *hint_fwd_pd)
182 : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
186 JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
187 _jit_avx512_common_1x1_convolution_bwd_data_t);
/* Accepts only backward_data propagation with the direct algorithm and
 * data types matching the template parameters; then prepares the
 * reduce-to-unit-stride helper on diff_src and fills jcp_. */
189 virtual status_t init() override {
190 using namespace prop_kind;
191 assert(this->engine()->kind() == engine_kind::cpu);
193 && this->set_default_params() == status::success
194 && this->desc()->prop_kind == backward_data
195 && this->desc()->alg_kind == alg_kind::convolution_direct
196 && this->desc()->diff_dst_desc.data_type == diff_dst_type
197 && this->desc()->weights_desc.data_type == wei_type
198 && this->desc()->diff_src_desc.data_type == diff_src_type;
199 if (!ok) return status::unimplemented;
201 const convolution_desc_t *conv_d = this->desc();
202 const memory_desc_t *diff_src_d = this->diff_src_pd_.desc();
203 rtus_prepare(this, conv_d, diff_src_d, this->diff_dst_pd_.desc());
204 return jit_avx512_common_1x1_conv_kernel::init_conf(jcp_,
205 *conv_d, *diff_src_d, *this->weights_pd_.desc(),
206 *this->diff_dst_pd_.desc(), *this->attr(),
207 omp_get_max_threads(), rtus_.reduce_src_);
210 // TODO (Roma): structs conf header cleanup
211 jit_1x1_conv_conf_t jcp_;
212 struct reduce_to_unit_stride_t {
213 convolution_desc_t conv_d_;
/* Defaults left as `any`: nChw16c activations; weights use the
 * gIOhw16o16i/IOhw16o16i blocking for f32 and gOIhw8o16i2o/OIhw8o16i2o
 * for the s16 path. */
218 virtual status_t set_default_params() override {
219 using namespace memory_format;
221 if (this->diff_src_pd_.desc()->format == any)
222 CHECK(this->diff_src_pd_.set_format(nChw16c));
223 if (this->diff_dst_pd_.desc()->format == any)
224 CHECK(this->diff_dst_pd_.set_format(nChw16c));
225 if (this->weights_pd_.desc()->format == any) {
226 if (diff_dst_type == data_type::f32
227 && diff_src_type == data_type::f32
228 && wei_type == data_type::f32) {
229 CHECK(this->weights_pd_.set_format(this->with_groups()
230 ? gIOhw16o16i : IOhw16o16i));
232 else if (diff_dst_type == data_type::s16
233 && diff_src_type == data_type::s32
234 && wei_type == data_type::s16)
235 CHECK(this->weights_pd_.set_format(this->with_groups()
236 ? gOIhw8o16i2o : OIhw8o16i2o));
239 return status::success;
/* init_rtus_driver touches private members, hence the friendship. */
243 template <cpu_isa_t isa, typename conv_t>
244 friend void init_rtus_driver(conv_t *self);
/* Builds the JIT kernel and the rtus driver. */
245 _jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *pd,
246 const input_vector &inputs,
247 const output_vector &outputs)
248 : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
249 , kernel_(nullptr), rtus_driver_(nullptr), ws_per_thread_(0)
252 kernel_ = new jit_avx512_common_1x1_conv_kernel(conf_.jcp_,
254 init_rtus_driver<avx512_common>(this);
/* Destructor body elided in this excerpt. */
256 ~_jit_avx512_common_1x1_convolution_bwd_data_t()
263 typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
264 typedef typename prec_traits<wei_type>::type wei_data_t;
265 typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
/* Dispatches on prop_kind (only backward_data is valid) and marks the
 * event ready. */
267 virtual void execute(event_t *e) {
268 switch (conf_.desc()->prop_kind) {
269 case prop_kind::backward_data:
270 execute_backward_data();
273 assert(!"invalid prop_kind");
275 e->set_state(event_t::ready);
279 void execute_backward_data();
281 jit_avx512_common_1x1_conv_kernel *kernel_;
282 /* reduction to unit stride */
283 rtus_driver_t<avx512_common> *rtus_driver_;
284 size_t ws_per_thread_;
285 diff_src_data_t *scratch_;
288 using jit_avx512_common_1x1_convolution_bwd_data_f32_t
289 = _jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
290 using jit_avx512_common_1x1_convolution_bwd_data_s16s16s32_t
291 = _jit_avx512_common_1x1_convolution_bwd_data_t<data_type::s16,
292 data_type::s16, data_type::s32>;
/* Backward-weights 1x1 convolution JIT-ted for avx512_common, f32 only:
 * computes diff_weights (and optionally diff_bias) from src and diff_dst.
 * NOTE(review): elided excerpt — the `bool ok = true` opener, parts of the
 * destructor, and the end of the struct are not visible here. */
294 struct jit_avx512_common_1x1_convolution_bwd_weights_t : public cpu_primitive_t
296 struct pd_t : public cpu_convolution_bwd_weights_pd_t {
297 pd_t(engine_t *engine,
298 const convolution_desc_t *adesc,
299 const primitive_attr_t *attr,
300 const convolution_fwd_pd_t *hint_fwd_pd)
301 : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
305 JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
306 jit_avx512_common_1x1_convolution_bwd_weights_t);
/* Accepts only backward_weights with the direct algorithm, f32 src /
 * diff_weights / diff_dst (and f32 diff_bias when bias is present); then
 * prepares the reduce-to-unit-stride helper on src and fills jcp_. */
308 virtual status_t init() override {
309 using namespace prop_kind;
310 assert(this->engine()->kind() == engine_kind::cpu);
312 && this->set_default_params() == status::success
313 && this->desc()->prop_kind == backward_weights
314 && this->desc()->alg_kind == alg_kind::convolution_direct
315 && utils::everyone_is(data_type::f32,
316 this->desc()->src_desc.data_type,
317 this->desc()->diff_weights_desc.data_type,
318 this->desc()->diff_dst_desc.data_type)
319 && utils::implication(this->with_bias(),
320 data_type::f32 == desc()->diff_bias_desc.data_type);
321 if (!ok) return status::unimplemented;
323 const convolution_desc_t *conv_d = this->desc();
324 const memory_desc_t *src_d = this->src_pd_.desc();
325 rtus_prepare(this, conv_d, src_d, this->diff_dst_pd_.desc());
326 return jit_avx512_common_1x1_conv_kernel::init_conf(jcp_,
327 *conv_d, *src_d, *this->diff_weights_pd_.desc(),
328 *this->diff_dst_pd_.desc(), *this->attr(),
329 omp_get_max_threads(), rtus_.reduce_src_);
332 // TODO (Roma): structs conf header cleanup
333 jit_1x1_conv_conf_t jcp_;
335 struct reduce_to_unit_stride_t {
336 convolution_desc_t conv_d_;
/* Defaults left as `any`: nChw16c activations, gOIhw16i16o/OIhw16i16o
 * diff_weights blocking, flat (x) diff_bias. */
341 virtual status_t set_default_params() override {
342 using namespace memory_format;
344 if (this->src_pd_.desc()->format == any)
345 CHECK(this->src_pd_.set_format(nChw16c));
346 if (this->diff_dst_pd_.desc()->format == any)
347 CHECK(this->diff_dst_pd_.set_format(nChw16c));
348 if (this->diff_weights_pd_.desc()->format == any)
349 CHECK(this->diff_weights_pd_.set_format(this->with_groups()
350 ? gOIhw16i16o : OIhw16i16o));
351 if (this->diff_bias_pd_.desc()->format == any)
352 CHECK(this->diff_bias_pd_.set_format(x));
353 return status::success;
/* init_rtus_driver touches private members, hence the friendship. */
357 template <cpu_isa_t isa, typename conv_t>
358 friend void init_rtus_driver(conv_t *self);
/* Constructor is defined out of line (in the .cpp). */
359 jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *pd,
360 const input_vector &inputs,
361 const output_vector &outputs);
/* Destructor frees the owned helper objects (only reducer_bias_ and
 * trans_kernel_ deletions are visible in this excerpt). */
362 ~jit_avx512_common_1x1_convolution_bwd_weights_t() {
365 delete reducer_bias_;
367 delete trans_kernel_;
375 typedef typename prec_traits<data_type::f32>::type data_t;
/* Dispatches on prop_kind (only backward_weights is valid) and marks the
 * event ready. */
377 virtual void execute(event_t *e) {
378 switch (conf_.desc()->prop_kind) {
379 case prop_kind::backward_weights:
380 execute_backward_weights();
383 assert(!"invalid prop_kind");
385 e->set_state(event_t::ready);
389 void execute_backward_weights();
392 jit_avx512_common_1x1_conv_kernel *kernel_;
/* Helper to accumulate partial diff_weights across threads. */
393 cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
/* Reduces per-thread diff_bias contributions. */
394 cpu_reducer_t<data_type::f32> *reducer_bias_;
/* Transposes the source for the weight-update kernel. */
395 jit_transpose4x16_src *trans_kernel_;
397 /* reduction to unit stride */
398 rtus_driver_t<avx512_common> *rtus_driver_;
399 size_t ws_per_thread_;
401 data_t *padded_bias_;
/* Barrier context for inter-thread synchronization — TODO confirm scope. */
403 simple_barrier::ctx_t *bctx_;
/* Buffer for per-thread diff_weights partial reductions — TODO confirm. */
405 data_t *ws_reduction_;