updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_dw_convolution.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_UNI_DW_CONVOLUTION_HPP
18 #define CPU_JIT_UNI_DW_CONVOLUTION_HPP
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22
23 #include "cpu_barrier.hpp"
24 #include "cpu_convolution_pd.hpp"
25 #include "cpu_reducer.hpp"
26
27 #include "jit_uni_dw_conv_kernel_utils.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type = src_type>
34 struct _jit_uni_dw_convolution_fwd_t : public cpu_primitive_t {
35     struct pd_t : public cpu_convolution_fwd_pd_t {
36         pd_t(engine_t *engine, const convolution_desc_t *adesc,
37                 const primitive_attr_t *attr,
38                 const typename pd_t::base_class *hint_fwd_pd)
39             : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
40             , jcp_() {}
41
42         DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
43                 _jit_uni_dw_convolution_fwd_t<isa, src_type, dst_type>);
44
45         virtual status_t init() override {
46             using namespace prop_kind;
47             assert(this->engine()->kind() == engine_kind::cpu);
48             bool ok = true && this->set_default_params() == status::success
49                     && utils::one_of(this->desc()->prop_kind, forward_training,
50                                forward_inference)
51                     && utils::one_of(this->desc()->alg_kind,
52                                alg_kind::convolution_auto,
53                                alg_kind::convolution_direct)
54                     && !this->has_zero_dim_memory()
55                     && utils::everyone_is(src_type,
56                                this->desc()->src_desc.data_type,
57                                this->desc()->weights_desc.data_type)
58                     && this->desc()->dst_desc.data_type == dst_type
59                     && IMPLICATION(this->with_bias(), data_type::f32
60                                        == this->desc()->bias_desc.data_type);
61
62             if (!ok)
63                 return status::unimplemented;
64
65             status_t status
66                     = jit_uni_dw_conv_fwd_kernel<isa, src_type>::init_conf(jcp_,
67                             *this->desc(), this->src_pd_.desc(),
68                             *this->weights_pd_.desc(), *this->dst_pd_.desc(),
69                             *this->attr());
70             if (status != status::success)
71                 return status;
72
73             auto scratchpad = scratchpad_registry().registrar();
74             jit_uni_dw_conv_fwd_kernel<isa, src_type>::init_scratchpad(
75                     scratchpad, jcp_);
76
77             return status::success;
78         }
79
80         jit_conv_conf_t jcp_;
81
82     protected:
83         virtual status_t set_default_params() override {
84             using namespace memory_format;
85             auto desired_act_fmt
86                     = utils::one_of(isa, avx512_common, avx512_core) ? nChw16c
87                                                                      : nChw8c;
88             auto desired_wei_fmt
89                     = utils::one_of(isa, avx512_common, avx512_core) ? Goihw16g
90                                                                      : Goihw8g;
91
92             if (this->src_pd_.desc()->format == any)
93                 CHECK(this->src_pd_.set_format(desired_act_fmt));
94             if (this->dst_pd_.desc()->format == any)
95                 CHECK(this->dst_pd_.set_format(desired_act_fmt));
96             if (this->weights_pd_.desc()->format == any)
97                 CHECK(this->weights_pd_.set_format(desired_wei_fmt));
98             if (this->bias_pd_.desc()->format == any)
99                 CHECK(this->bias_pd_.set_format(x));
100             if (this->desc()->alg_kind == alg_kind::convolution_auto)
101                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
102             return status::success;
103         }
104     };
105
106     _jit_uni_dw_convolution_fwd_t(const pd_t *apd, const input_vector &inputs,
107             const output_vector &outputs)
108         : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
109         kernel_ = new jit_uni_dw_conv_fwd_kernel<isa, src_type>(pd()->jcp_, *pd()->attr());
110     }
111
112     ~_jit_uni_dw_convolution_fwd_t() { delete kernel_; }
113
114     typedef typename prec_traits<data_type::f32>::type f32_data_t;
115     typedef typename prec_traits<src_type>::type data_t;
116     typedef typename prec_traits<dst_type>::type dst_data_t;
117
118     virtual void execute(event_t *e) const {
119         execute_forward();
120         e->set_state(event_t::ready);
121     }
122
123 private:
124     void execute_forward() const;
125     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
126
127     jit_uni_dw_conv_fwd_kernel<isa, src_type> *kernel_;
128 };
129
130 using jit_avx512_common_dw_convolution_fwd_t
131         = _jit_uni_dw_convolution_fwd_t<avx512_common, data_type::f32>;
132 using jit_avx2_dw_convolution_fwd_t
133         = _jit_uni_dw_convolution_fwd_t<avx2, data_type::f32>;
134 using jit_sse42_dw_convolution_fwd_t
135         = _jit_uni_dw_convolution_fwd_t<sse42, data_type::f32>;
136
137 template <cpu_isa_t isa, data_type_t diff_dst_type,
138         data_type_t diff_src_type = diff_dst_type>
139 struct _jit_uni_dw_convolution_bwd_data_t : public cpu_primitive_t {
140     struct pd_t : public cpu_convolution_bwd_data_pd_t {
141         pd_t(engine_t *engine, const convolution_desc_t *adesc,
142                 const primitive_attr_t *attr,
143                 const convolution_fwd_pd_t *hint_fwd_pd)
144             : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
145             , jcp_() {}
146
147         DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
148                 _jit_uni_dw_convolution_bwd_data_t<isa, diff_dst_type,
149                                     diff_src_type>);
150
151         virtual status_t init() override {
152             using namespace prop_kind;
153
154             assert(this->engine()->kind() == engine_kind::cpu);
155             bool ok = true && this->set_default_params() == status::success
156                     && utils::one_of(
157                                this->desc()->prop_kind, backward, backward_data)
158                     && utils::one_of(this->desc()->alg_kind,
159                                alg_kind::convolution_auto,
160                                alg_kind::convolution_direct)
161                     && !this->has_zero_dim_memory()
162                     && utils::everyone_is(diff_dst_type,
163                                this->desc()->weights_desc.data_type,
164                                this->desc()->diff_dst_desc.data_type)
165                     && diff_src_type == this->desc()->diff_src_desc.data_type;
166
167             if (!ok)
168                 return status::unimplemented;
169
170             status_t status = jit_uni_dw_conv_bwd_data_kernel<isa,
171                     diff_dst_type>::init_conf(jcp_, *this->desc(),
172                     *this->diff_src_pd_.desc(), *this->weights_pd_.desc(),
173                     *this->diff_dst_pd_.desc());
174             if (status != status::success)
175                 return status;
176
177             auto scratchpad = scratchpad_registry().registrar();
178             jit_uni_dw_conv_bwd_data_kernel<isa,
179                     diff_dst_type>::init_scratchpad(scratchpad, jcp_);
180
181             return status::success;
182         }
183
184         jit_conv_conf_t jcp_;
185
186     protected:
187         virtual status_t set_default_params() override {
188             using namespace memory_format;
189             auto desired_act_fmt
190                     = utils::one_of(isa, avx512_common, avx512_core) ? nChw16c
191                                                                      : nChw8c;
192             auto desired_wei_fmt
193                     = utils::one_of(isa, avx512_common, avx512_core) ? Goihw16g
194                                                                      : Goihw8g;
195
196             if (this->diff_src_pd_.desc()->format == any)
197                 CHECK(this->diff_src_pd_.set_format(desired_act_fmt));
198             if (this->diff_dst_pd_.desc()->format == any)
199                 CHECK(this->diff_dst_pd_.set_format(desired_act_fmt));
200             if (this->weights_pd_.desc()->format == any)
201                 CHECK(this->weights_pd_.set_format(desired_wei_fmt));
202             if (this->desc()->alg_kind == alg_kind::convolution_auto)
203                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
204
205             return status::success;
206         }
207     };
208
209     _jit_uni_dw_convolution_bwd_data_t(const pd_t *apd,
210             const input_vector &inputs, const output_vector &outputs)
211         : cpu_primitive_t(apd, inputs, outputs) {
212         kernel_ = new jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type>(
213                 pd()->jcp_);
214     }
215     ~_jit_uni_dw_convolution_bwd_data_t() { delete kernel_; };
216
217     typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
218     typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
219     typedef typename prec_traits<diff_dst_type>::type diff_wei_data_t;
220
221     virtual void execute(event_t *e) const {
222         switch (pd()->desc()->prop_kind) {
223         case prop_kind::backward_data: execute_backward_data(); break;
224         default: assert(!"invalid prop_kind");
225         }
226         e->set_state(event_t::ready);
227     }
228
229 private:
230     void execute_backward_data() const;
231     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
232
233     jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type> *kernel_;
234 };
235
236 using jit_avx512_common_dw_convolution_bwd_data_t
237         = _jit_uni_dw_convolution_bwd_data_t<avx512_common, data_type::f32>;
238 using jit_avx2_dw_convolution_bwd_data_t
239         = _jit_uni_dw_convolution_bwd_data_t<avx2, data_type::f32>;
240 using jit_sse42_dw_convolution_bwd_data_t
241         = _jit_uni_dw_convolution_bwd_data_t<sse42, data_type::f32>;
242
243 template <cpu_isa_t isa, data_type_t src_type,
244         data_type_t diff_weights_type = src_type>
245 struct _jit_uni_dw_convolution_bwd_weights_t : public cpu_primitive_t {
246     struct pd_t : public cpu_convolution_bwd_weights_pd_t {
247         pd_t(engine_t *engine, const convolution_desc_t *adesc,
248                 const primitive_attr_t *attr,
249                 const convolution_fwd_pd_t *hint_fwd_pd)
250             : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
251             , jcp_() {}
252
253         DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
254                 _jit_uni_dw_convolution_bwd_weights_t<isa, src_type,
255                                     diff_weights_type>);
256
257         virtual status_t init() override {
258             using namespace prop_kind;
259
260             assert(this->engine()->kind() == engine_kind::cpu);
261             bool ok = true && this->set_default_params() == status::success
262                     && this->desc()->prop_kind == prop_kind::backward_weights
263                     && utils::one_of(this->desc()->alg_kind,
264                                alg_kind::convolution_auto,
265                                alg_kind::convolution_direct)
266                     && utils::everyone_is(src_type,
267                                this->desc()->src_desc.data_type,
268                                this->desc()->diff_dst_desc.data_type)
269                     && this->desc()->diff_weights_desc.data_type
270                             == diff_weights_type;
271
272             if (!ok)
273                 return status::unimplemented;
274
275             const int max_threads
276                     = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
277
278             status_t status = jit_uni_dw_conv_bwd_weights_kernel<isa,
279                     src_type>::init_conf(jcp_, *this->desc(),
280                     *this->src_pd_.desc(), *this->diff_weights_pd_.desc(),
281                     *this->diff_dst_pd_.desc(), max_threads);
282             if (status != status::success)
283                 return status;
284
285             auto scratchpad = scratchpad_registry().registrar();
286             jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>::init_scratchpad(
287                     scratchpad, jcp_);
288
289             return status::success;
290         }
291
292         jit_conv_conf_t jcp_;
293
294     protected:
295         virtual status_t set_default_params() override {
296             using namespace memory_format;
297             auto desired_act_fmt
298                     = utils::one_of(isa, avx512_common, avx512_core) ? nChw16c
299                                                                      : nChw8c;
300             auto desired_wei_fmt
301                     = utils::one_of(isa, avx512_common, avx512_core) ? Goihw16g
302                                                                      : Goihw8g;
303
304             if (this->src_pd_.desc()->format == any)
305                 CHECK(this->src_pd_.set_format(desired_act_fmt));
306             if (this->diff_dst_pd_.desc()->format == any)
307                 CHECK(this->diff_dst_pd_.set_format(desired_act_fmt));
308             if (this->diff_weights_pd_.desc()->format == any)
309                 CHECK(this->diff_weights_pd_.set_format(desired_wei_fmt));
310             if (this->diff_bias_pd_.desc()->format == any)
311                 CHECK(this->diff_bias_pd_.set_format(x));
312             if (this->desc()->alg_kind == alg_kind::convolution_auto)
313                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
314
315             return status::success;
316         }
317     };
318
319     _jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd,
320             const input_vector &inputs, const output_vector &outputs)
321         : cpu_primitive_t(apd, inputs, outputs)
322         , acc_ker_(nullptr)
323         , kernel_(nullptr) {
324         kernel_ = new jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>(
325                 pd()->jcp_);
326
327         if (pd()->jcp_.nthr_mb > 1 && isa != sse42)
328             acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
329     }
330
331     ~_jit_uni_dw_convolution_bwd_weights_t() {
332         delete acc_ker_;
333         delete kernel_;
334     };
335
336     typedef typename prec_traits<data_type::f32>::type f32_data_t;
337     typedef typename prec_traits<src_type>::type src_data_t;
338     typedef typename prec_traits<src_type>::type diff_dst_data_t;
339     typedef typename prec_traits<diff_weights_type>::type diff_weights_data_t;
340
341     virtual void execute(event_t *e) const {
342         execute_backward_weights();
343         execute_reduction();
344         e->set_state(event_t::ready);
345     }
346
347 private:
348     void execute_backward_weights() const;
349     void execute_reduction() const;
350     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
351
352     cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
353     jit_uni_dw_conv_bwd_weights_kernel<isa, src_type> *kernel_;
354 };
355
356 using jit_avx512_common_dw_convolution_bwd_weights_t
357         = _jit_uni_dw_convolution_bwd_weights_t<avx512_common, data_type::f32>;
358 using jit_avx2_dw_convolution_bwd_weights_t
359         = _jit_uni_dw_convolution_bwd_weights_t<avx2, data_type::f32>;
360 using jit_sse42_dw_convolution_bwd_weights_t
361         = _jit_uni_dw_convolution_bwd_weights_t<sse42, data_type::f32>;
362
363 }
364 }
365 }
366
367 #endif