Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_planar_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cstring>
18 #include "mkldnn_types.h"
19
20 #include "c_types_map.hpp"
21 #include "jit_uni_planar_convolution.hpp"
22 #include "utils.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "type_helpers.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::utils;
33
34 #define src_blk_off(f, n, c, d, h, w) \
35     pd()->ndims() == 5 \
36         ? (f).blk_off(n, c, d, h, w) \
37         : (f).blk_off(n, c, h, w)
38
39 #define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
40     pd()->ndims() == 5 \
41         ? pd()->with_groups() \
42             ? (f).blk_off(g, oc, ic, kd, kh, kw) \
43             : (f).blk_off(oc, ic, kd, kh, kw) \
44         : pd()->with_groups() \
45             ? (f).blk_off(g, oc, ic, kh, kw) \
46             : (f).blk_off(oc, ic, kh, kw)
47
48 template <cpu_isa_t isa>
49 void _jit_uni_planar_convolution_fwd_t<isa>::execute_forward() const {
50     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
51     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
52     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
53     auto dst = reinterpret_cast<data_t *>(this->memory());
54
55     const memory_desc_wrapper src_d(pd()->src_pd());
56     const memory_desc_wrapper dst_d(pd()->dst_pd());
57     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
58     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
59
60     const auto &jcp = kernel_->jcp;
61     const int MB = pd()->MB();
62
63     int od_indexes[jcp.od];
64
65     int idx = 0;
66     for (int i = 0; i < (jcp.dilate_d + 1); i++) {
67         for (int ib = 0; ib < jcp.od; ib += (jcp.dilate_d + 1)) {
68             if (ib + i >= jcp.od)
69                 continue;
70
71             od_indexes[idx++] = ib + i;
72             if (idx >= jcp.od)
73                 break;
74         }
75         if (idx >= jcp.od)
76             break;
77     }
78
79     int threads_count = mkldnn_get_max_threads();
80     int odb_size = div_up(jcp.od, threads_count);
81
82     auto kernel_params = [&](int n, int g, int icb, int oc, int od, int oh, int oh_blocks, int id, int wd, int kd_padding) {
83         auto par_conv = jit_conv_call_s();
84
85         const int hj = oh * jcp.stride_h;
86         const int i_t_overflow = nstl::max(0, jcp.t_pad - hj);
87         const int i_b_overflow = nstl::max(jcp.ih, hj + (jcp.kh - 1) * (jcp.dilate_h + 1) - jcp.t_pad + 1) - jcp.ih;
88         const int ih = nstl::max(hj - jcp.t_pad + div_up(i_t_overflow, (jcp.dilate_h + 1)) * (jcp.dilate_h + 1), 0);
89         const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
90         const int kh_padding = jcp.kh - div_up(i_t_overflow, (jcp.dilate_h + 1)) - div_up(i_b_overflow, (jcp.dilate_h + 1));
91
92         const size_t _oc = oc;
93         const size_t _ic = g * jcp.nb_ic + icb;
94
95         par_conv.src = &src[src_blk_off(src_d, n, _ic, id, ih, 0)];
96         par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];
97         par_conv.filt = &weights[wht_blk_off(weights_d, g, _oc, _ic, wd, wh, 0)];
98
99         if (icb == 0) {
100             if (bias)
101                 par_conv.bias = &bias[bias_d.blk_off(_oc)];
102             par_conv.flags |= FLAG_IC_FIRST;
103         }
104
105         if (icb + 1 == jcp.nb_ic) {
106             par_conv.flags |= FLAG_IC_LAST;
107         }
108
109         par_conv.oc_off = _oc * sizeof(float);
110         par_conv.oh_blocks = (size_t)oh_blocks;
111
112         par_conv.kh_padding = (size_t)nstl::max(0, kh_padding);
113         par_conv.kd_padding = (size_t)nstl::max(0, kd_padding);
114
115         return par_conv;
116     };
117
118     auto ker = [&](const int ithr, const int nthr) {
119         int g = 0;
120         int oc = 0;
121
122         for (int n = 0; n < MB; n++) {
123             int icbb = 0;
124             while (icbb < jcp.nb_ic) {
125                 int icb_step = jcp.nb_ic_blocking;
126                 int icb_step_rem = jcp.nb_ic - icbb;
127                 if (icb_step_rem < jcp.nb_ic_blocking_max)
128                     icb_step = icb_step_rem;
129
130                 for (int icb = icbb; icb < icbb + icb_step; ++icb) {
131                     for (int ohb = 0; ohb < (jcp.dilate_h + 1); ohb++) {
132                         for (int oh = ohb; oh < jcp.oh; oh += (jcp.dilate_h + 1)) {
133                             int od_idx_off = ithr * odb_size;
134                             for (int od_idx = 0; od_idx < odb_size; od_idx++) {
135                                 if ((od_idx_off + od_idx) >= jcp.od || od_indexes[od_idx_off + od_idx] >= jcp.od)
136                                     continue;
137                                 int od = od_indexes[od_idx_off + od_idx];
138
139                                 const int dj = od * jcp.stride_d;
140                                 const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
141                                 const int d_b_overflow =
142                                         nstl::max(jcp.id, dj + (jcp.kd - 1) * (jcp.dilate_d + 1) - jcp.f_pad + 1) -
143                                         jcp.id;
144                                 const int id = nstl::max(dj - jcp.f_pad +
145                                                          div_up(d_t_overflow, (jcp.dilate_d + 1)) * (jcp.dilate_d + 1),
146                                                          0);
147                                 const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
148                                 const int kd_padding = jcp.kd - div_up(d_t_overflow, (jcp.dilate_d + 1)) -
149                                                        div_up(d_b_overflow, (jcp.dilate_d + 1));
150
151                                 jit_conv_call_s par_conv = kernel_params(n, g, icb, oc, od, oh, 1, id, wd, kd_padding);
152
153                                 kernel_->jit_ker(&par_conv);
154                             }
155                         }
156                     }
157                 }
158                 icbb += icb_step;
159             }
160         }
161     };
162
163     parallel(0, ker);
164 }
165
166
167 template struct _jit_uni_planar_convolution_fwd_t<avx512_common>;
168 template struct _jit_uni_planar_convolution_fwd_t<avx2>;
169
170 }
171 }
172 }