/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_JIT_AVX2_CONVOLUTION_HPP
#define CPU_JIT_AVX2_CONVOLUTION_HPP

#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "mkldnn_thread.hpp"
#include "utils.hpp"

#include "cpu_convolution_pd.hpp"
#include "cpu_reducer.hpp"

#include "jit_avx2_conv_kernel_f32.hpp"
#include "jit_uni_depthwise.hpp"

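/* Declarations of the AVX2 JIT f32 convolution primitives: forward
 * (optionally fused with a following depthwise convolution configured in
 * jcp_dw_), backward by data, and backward by weights. Each primitive's
 * pd_t validates the operation descriptor, fills the jit_conv_conf_t
 * blobs, and registers the scratchpad the kernels need. */
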
namespace mkldnn {
namespace impl {
namespace cpu {

struct jit_avx2_convolution_fwd_t: public cpu_primitive_t {
    struct pd_t: public cpu_convolution_fwd_pd_t {
        pd_t(engine_t *engine,
                const convolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
            , jcp_(), jcp_dw_() {}

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
                jit_avx2_convolution_fwd_t);

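        /* init() accepts only f32 direct convolution (convolution_auto is
         * later resolved to direct) for forward training/inference with
         * non-empty memory. The JIT kernel then fills jcp_ from the
         * descriptor, the chosen memory formats, and the attributes; if the
         * attributes request a fused depthwise convolution, jcp_dw_ is
         * configured as well, and finally the scratchpad is registered. */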
        virtual status_t init() override {
            using namespace prop_kind;
            assert(this->engine()->kind() == engine_kind::cpu);
            bool ok = true
                && this->set_default_params() == status::success
                && utils::one_of(this->desc()->prop_kind, forward_training,
                        forward_inference)
                && utils::one_of(this->desc()->alg_kind,
                        alg_kind::convolution_auto,
                        alg_kind::convolution_direct)
                && !this->has_zero_dim_memory()
                && utils::everyone_is(data_type::f32,
                        this->desc()->src_desc.data_type,
                        this->desc()->weights_desc.data_type,
                        this->desc()->dst_desc.data_type)
                && IMPLICATION(this->with_bias(),
                        data_type::f32 == this->desc()->bias_desc.data_type);
            if (!ok) return status::unimplemented;

            status_t sts = jit_avx2_conv_fwd_kernel_f32::init_conf(jcp_,
                    *this->desc(), *this->src_pd_.desc(),
                    *this->weights_pd_.desc(), *this->dst_pd_.desc(),
                    *this->attr());
            if (sts != status::success) return sts;

            if (jcp_.with_dw_conv) {
                status_t sts_dw = jit_uni_dw_conv_row_f32<avx2>::init_conf(
                        jcp_, jcp_dw_, *this->attr());
                if (sts_dw != status::success) return sts_dw;
            }

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_,
                    jcp_dw_);

            return status::success;
        }

        jit_conv_conf_t jcp_;
        jit_conv_conf_t jcp_dw_;

    protected:
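        /* When the user left a memory format as `any`, pick the layouts the
         * JIT kernel expects: plain ncw/nchw/ncdhw sources for fewer than
         * simd_w (8) input channels, otherwise the 8-channel blocked
         * nCw8c/nChw8c/nCdhw8c layout matching the eight f32 lanes of an
         * AVX2 register, with correspondingly 8i8o-blocked weights. */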
        virtual status_t set_default_params() override {
            using namespace memory_format;

            const int simd_w = 8;
            const bool flat = this->IC() < simd_w;
            if (this->src_pd_.desc()->format == any)
                CHECK(this->src_pd_.set_format(flat
                    ? utils::pick(this->ndims() - 3, ncw, nchw, ncdhw)
                    : utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
            if (this->dst_pd_.desc()->format == any)
                CHECK(this->dst_pd_.set_format(
                    utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
            if (this->weights_pd_.desc()->format == any)
                CHECK(this->weights_pd_.set_format(this->with_groups()
                    ? utils::pick(2 * this->ndims() - 6 + flat, gOIw8i8o,
                        gOwi8o, gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
                    : utils::pick(2 * this->ndims() - 6 + flat, OIw8i8o, Owi8o,
                        OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o)));

            if (this->bias_pd_.desc()->format == any)
                CHECK(this->bias_pd_.set_format(x));
            if (this->desc()->alg_kind == alg_kind::convolution_auto)
                CHECK(this->set_alg_kind(alg_kind::convolution_direct));
            return status::success;
        }
    };

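    /* The JIT kernels are generated up front from the pd's jcp_ blobs;
     * kernel_dw_ is created only when a fused depthwise convolution was
     * requested (jcp_.with_dw_conv), which is why its delete is guarded
     * in the destructor. */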
    jit_avx2_convolution_fwd_t(const pd_t *apd, const input_vector &inputs,
            const output_vector &outputs)
        : cpu_primitive_t(apd, inputs, outputs)
        , kernel_(nullptr), kernel_dw_(nullptr)
    {
        kernel_ = new jit_avx2_conv_fwd_kernel_f32(pd()->jcp_, pd()->jcp_dw_,
                *pd()->attr());

        if (pd()->jcp_.with_dw_conv) {
            kernel_dw_ = new jit_uni_dw_conv_row_f32<avx2>(pd()->jcp_dw_,
                    *pd()->attr(), pd()->jcp_dw_.ch_block);
        }
    }

    ~jit_avx2_convolution_fwd_t() {
        delete kernel_;

        if (pd()->jcp_.with_dw_conv) {
            delete kernel_dw_;
        }
    }

    typedef typename prec_traits<data_type::f32>::type data_t;

    virtual void execute(event_t *e) const {
        if (pd()->jcp_.with_dw_conv)
            execute_forward_with_dw_conv();
        else
            execute_forward();

        e->set_state(event_t::ready);
    }

private:
    void execute_forward() const;
    void execute_forward_with_dw_conv() const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }

    jit_avx2_conv_fwd_kernel_f32 *kernel_;
    jit_uni_dw_conv_row_f32<avx2> *kernel_dw_;
};

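/* A rough usage sketch (hypothetical; in practice creation normally goes
 * through the primitive descriptor iterator rather than direct
 * instantiation). The pd_t constructor plus init() validate the op and
 * build jcp_, then the primitive generates its kernels:
 *
 *     jit_avx2_convolution_fwd_t::pd_t pd(engine, &conv_desc, &attr,
 *             nullptr);
 *     if (pd.init() == status::success) {
 *         auto *conv = new jit_avx2_convolution_fwd_t(&pd, inputs, outputs);
 *         conv->execute(&event); // dispatches to the dw-fused path when
 *                                // jcp_.with_dw_conv is set
 *     }
 */
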
struct jit_avx2_convolution_bwd_data_t: public cpu_primitive_t {
    struct pd_t: public cpu_convolution_bwd_data_pd_t {
        pd_t(engine_t *engine,
                const convolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
            , jcp_()
        {}

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
                jit_avx2_convolution_bwd_data_t);

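        /* Mirrors the forward pd_t: accept only f32 direct backward-data
         * convolution, let the JIT kernel derive jcp_ from the descriptor
         * and the diff_src/weights/diff_dst formats, then register the
         * scratchpad. */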
        virtual status_t init() override {
            using namespace prop_kind;
            assert(this->engine()->kind() == engine_kind::cpu);
            bool ok = true
                && this->set_default_params() == status::success
                && utils::one_of(this->desc()->prop_kind, backward_data)
                && utils::one_of(this->desc()->alg_kind,
                        alg_kind::convolution_auto,
                        alg_kind::convolution_direct)
                && !this->has_zero_dim_memory()
                && utils::everyone_is(data_type::f32,
                        this->desc()->diff_src_desc.data_type,
                        this->desc()->weights_desc.data_type,
                        this->desc()->diff_dst_desc.data_type);
            if (!ok) return status::unimplemented;

            status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf(
                    jcp_, *this->desc(), *this->diff_src_pd_.desc(),
                    *this->weights_pd_.desc(), *this->diff_dst_pd_.desc());
            if (status != status::success) return status;

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(scratchpad,
                    jcp_);

            return status::success;
        }

        jit_conv_conf_t jcp_;

    protected:
        virtual status_t set_default_params() override {
            using namespace memory_format;

            if (this->diff_src_pd_.desc()->format == any)
                CHECK(this->diff_src_pd_.set_format(
                    utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
            if (this->diff_dst_pd_.desc()->format == any)
                CHECK(this->diff_dst_pd_.set_format(
                    utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
            if (this->weights_pd_.desc()->format == any)
                CHECK(this->weights_pd_.set_format(this->with_groups()
                    ? utils::pick(this->ndims() - 3, gOIw8o8i, gOIhw8o8i,
                        gOIdhw8o8i)
                    : utils::pick(this->ndims() - 3, OIw8o8i, OIhw8o8i,
                        OIdhw8o8i)));
            if (this->desc()->alg_kind == alg_kind::convolution_auto)
                CHECK(this->set_alg_kind(alg_kind::convolution_direct));
            return status::success;
        }
    };

    jit_avx2_convolution_bwd_data_t(const pd_t *apd, const input_vector &inputs,
            const output_vector &outputs)
        : cpu_primitive_t(apd, inputs, outputs)
    { kernel_ = new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_); }
    ~jit_avx2_convolution_bwd_data_t() { delete kernel_; }

    typedef typename prec_traits<data_type::f32>::type data_t;

    virtual void execute(event_t *e) const {
        switch (pd()->desc()->prop_kind) {
        case prop_kind::backward_data:
            execute_backward_data();
            break;
        default:
            assert(!"invalid prop_kind");
        }
        e->set_state(event_t::ready);
    }

private:
    void execute_backward_data() const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }

    jit_avx2_conv_bwd_data_kernel_f32 *kernel_;
};

struct jit_avx2_convolution_bwd_weights_t: public cpu_primitive_t {
    struct pd_t: public cpu_convolution_bwd_weights_pd_t {
        pd_t(engine_t *engine, const convolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
            , jcp_() {}

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
                jit_avx2_convolution_bwd_weights_t);

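        /* Besides jcp_, backward-weights prepares two cpu_reducer_t
         * configurations, one for the bias gradient and one for the weights
         * gradient; init_balancers() decides how their reduction work is
         * split across threads before the reducer scratchpads are
         * registered. */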
        virtual status_t init() override {
            assert(this->engine()->kind() == engine_kind::cpu);
            bool ok = true
                && this->set_default_params() == status::success
                && this->desc()->prop_kind == prop_kind::backward_weights
                && utils::one_of(this->desc()->alg_kind,
                        alg_kind::convolution_auto,
                        alg_kind::convolution_direct)
                && !this->has_zero_dim_memory()
                && utils::everyone_is(data_type::f32,
                        this->desc()->src_desc.data_type,
                        this->desc()->diff_dst_desc.data_type,
                        this->desc()->diff_weights_desc.data_type);
            if (!ok) return status::unimplemented;

            status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf(
                    jcp_, *this->desc(), *this->src_pd_.desc(),
                    *this->diff_weights_pd_.desc(),
                    *this->diff_dst_pd_.desc());
            if (status != status::success) return status;

            init_balancers();

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(scratchpad,
                    jcp_);

            auto reducer_bia_scratchpad = memory_tracking::registrar_t(
                    scratchpad, memory_tracking::names::prefix_reducer_bia);
            reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);

            auto reducer_wei_scratchpad = memory_tracking::registrar_t(
                    scratchpad, memory_tracking::names::prefix_reducer_wei);
            reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);

            return status::success;
        }

        jit_conv_conf_t jcp_;
        cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
        cpu_reducer_t<data_type::f32>::conf_t reducer_wei_conf_;

    protected:
        virtual status_t set_default_params() override {
            using namespace memory_format;
            const bool flat = this->IC() == 3;

            if (this->src_pd_.desc()->format == any)
                CHECK(this->src_pd_.set_format(flat
                    ? utils::pick(this->ndims() - 3, ncw, nchw, ncdhw)
                    : utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
            if (this->diff_dst_pd_.desc()->format == any)
                CHECK(this->diff_dst_pd_.set_format(
                    utils::pick(this->ndims() - 3, nCw8c, nChw8c, nCdhw8c)));
            if (this->diff_weights_pd_.desc()->format == any)
                CHECK(this->diff_weights_pd_.set_format(this->with_groups()
                    ? utils::pick(2 * this->ndims() - 6 + flat, gOIw8i8o,
                        gOwi8o, gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
                    : utils::pick(2 * this->ndims() - 6 + flat, OIw8i8o, Owi8o,
                        OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o)));
            if (this->diff_bias_pd_.desc()->format == any)
                CHECK(this->diff_bias_pd_.set_format(x));
            if (this->desc()->alg_kind == alg_kind::convolution_auto)
                CHECK(this->set_alg_kind(alg_kind::convolution_direct));
            return status::success;
        }

    private:
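        /* Each reduce_balancer_t spreads independent jobs (one
         * output-channel block for the bias, one weights block for the
         * weights gradient) and the reduction over minibatch images (and
         * output depth, for the weights) across max_threads threads, with
         * max_buffer_size (~2 MB) as a heuristic cap on per-thread
         * intermediate storage. */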
        void init_balancers() {
            const int max_threads = mkldnn_get_max_threads();
            const size_t max_buffer_size = 1<<21; /* just a heuristic */

            if (with_bias()) {
                reducer_bia_conf_.init(reduce_balancer_t(max_threads,
                            jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
                            max_buffer_size));
            }

            reducer_wei_conf_.init(reduce_balancer_t(max_threads,
                        jcp_.kd * jcp_.kh * jcp_.kw
                        * jcp_.ic_block * jcp_.oc_block,
                        jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc,
                        jcp_.mb * jcp_.od, max_buffer_size));
        }
    };

    jit_avx2_convolution_bwd_weights_t(const pd_t *apd,
            const input_vector &inputs, const output_vector &outputs)
        : cpu_primitive_t(apd, inputs, outputs)
        , kernel_(nullptr), reducer_weights_(nullptr), reducer_bias_(nullptr)
    {
        kernel_ = new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_);
        reducer_bias_ =
            new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
        reducer_weights_ =
            new cpu_reducer_t<data_type::f32>(pd()->reducer_wei_conf_);
    }

    ~jit_avx2_convolution_bwd_weights_t() {
        delete kernel_;
        delete reducer_weights_;
        delete reducer_bias_;
    }

    typedef typename prec_traits<data_type::f32>::type data_t;

    virtual void execute(event_t *e) const {
        execute_backward_weights();
        e->set_state(event_t::ready);
    }

private:
    void execute_backward_weights() const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }

    jit_avx2_conv_bwd_weights_kernel_f32 *kernel_;
    cpu_reducer_t<data_type::f32> *reducer_weights_, *reducer_bias_;
};

}
}
}

#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s