updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_deconvolution.cpp
1 /*******************************************************************************
2 * Copyright 2018-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "math_utils.hpp"
22
23 #include "ref_deconvolution.hpp"
24
25 #include "bfloat16_utils.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 void ref_deconvolution_fwd_t::compute_fwd_bias() const {
32     auto bias = reinterpret_cast<const f32_data_t *>(this->input_memory(2));
33     auto dst = reinterpret_cast<f32_data_t *>(this->memory());
34     const memory_desc_wrapper dst_d(pd()->dst_pd());
35
36     const int G = pd()->G();
37     const int MB = pd()->MB();
38     const int OH = pd()->OH();
39     const int OW = pd()->OW();
40     const int OD = pd()->OD();
41     const int OC = pd()->OC() / G;
42     const int ndims = pd()->desc()->src_desc.ndims;
43
44     parallel_nd(MB, G, OC, OD, OH, OW,
45         [&](int mb, int g, int oc, int od, int oh, int ow) {
46             auto b = bias[g * OC + oc];
47             switch (ndims) {
48             case 5: dst[dst_d.off(mb, g * OC + oc, od, oh, ow)] += b; break;
49             case 4: dst[dst_d.off(mb, g * OC + oc, oh, ow)] += b; break;
50             case 3: dst[dst_d.off(mb, g * OC + oc, ow)] += b; break;
51             default: assert(!"invalid dimension size");
52             }
53     });
54 }
55
56 void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw() const {
57     auto bias = reinterpret_cast<const f32_data_t *>(this->input_memory(2));
58     auto dst = reinterpret_cast<f32_data_t *>(this->memory());
59
60     const memory_desc_wrapper dst_d(pd()->dst_pd());
61
62     const int MB = pd()->MB();
63     const int OC = pd()->OC();
64     const int SP = pd()->OW()*pd()->OH()*pd()->OD();
65
66     parallel_nd(MB, OC, [&](int mb, int oc) {
67         PRAGMA_OMP_SIMD()
68         for (int sp = 0; sp < SP; ++sp) {
69             auto offset = (size_t)(mb * OC + oc) * SP + sp;
70             dst[offset] += bias[oc];
71         }
72     });
73 }
74
75 template <int blksize>
76 void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc() const {
77     auto bias = reinterpret_cast<const f32_data_t *>(this->input_memory(2));
78     auto dst = reinterpret_cast<f32_data_t *>(this->memory());
79
80     const memory_desc_wrapper dst_d(pd()->dst_pd());
81
82     const int MB = pd()->MB();
83     const int OC = pd()->OC();
84     const int SP = pd()->OW() * pd()->OH() * pd()->OD();
85
86     const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0][0];
87
88     parallel_nd(MB, utils::div_up(OC, blksize), SP,
89         [&](int mb, int oc_blk, int sp) {
90         int oc = oc_blk * blksize;
91         auto offset = mb * stride_mb + oc * SP + sp * blksize;
92         const int blk = nstl::min(blksize, OC - oc);
93
94         PRAGMA_OMP_SIMD()
95         for (int i = 0; i < blk; ++i)
96             dst[offset + i] += bias[oc + i];
97     });
98 }
99
100 template <int blksize>
101 void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc_bf16() const {
102     auto bias = reinterpret_cast<const f32_data_t *>(this->input_memory(2));
103     auto dst = reinterpret_cast<bf16_data_t *>(this->memory());
104
105     const memory_desc_wrapper dst_d(pd()->dst_pd());
106
107     const int MB = pd()->MB();
108     const int OC = pd()->OC();
109     const int SP = pd()->OW() * pd()->OH() * pd()->OD();
110
111     const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0][0];
112     parallel_nd(MB, utils::div_up(OC, blksize), SP,
113         [&](int mb, int oc_blk, int sp) {
114         int oc = oc_blk * blksize;
115         auto offset = mb * stride_mb + oc * SP + sp * blksize;
116         const int blk = nstl::min(blksize, OC - oc);
117
118         f32_data_t dst_f32[blksize] = {0.0f};
119         bf16_cvt_utils::cvt_bfloat16_to_float(dst_f32, &dst[offset], blk);
120
121         PRAGMA_OMP_SIMD()
122         for (int i = 0; i < blk; ++i) {
123             dst_f32[i] += bias[oc + i];
124         }
125
126         bf16_cvt_utils::cvt_float_to_bfloat16(&dst[offset], dst_f32, blk);
127     });
128 }
129
130 void ref_deconvolution_bwd_weights_t::compute_bwd_bias() const {
131     auto diff_dst = reinterpret_cast<const f32_data_t *>(this->input_memory(1));
132     auto diff_bias = reinterpret_cast<f32_data_t *>(this->memory(1));
133     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
134
135     const int G = pd()->G();
136     const int MB = pd()->MB();
137     const int OH = pd()->OH();
138     const int OW = pd()->OW();
139     const int OC = pd()->OC() / G;
140     const int OD = pd()->OD();
141     const int ndims = pd()->desc()->src_desc.ndims;
142
143     parallel_nd(G, OC, [&](int g, int oc) {
144         f32_data_t db = 0;
145         for (int mb = 0; mb < MB; ++mb) {
146             for (int od = 0; od < OD; ++od) {
147                 for (int oh = 0; oh < OH; ++oh) {
148                     for (int ow = 0; ow < OW; ++ow) {
149                         switch (ndims) {
150                         case 5:
151                             db += diff_dst[diff_dst_d.off(
152                                     mb, g * OC + oc, od, oh, ow)];
153                             break;
154                         case 4:
155                             db += diff_dst[diff_dst_d.off(
156                                     mb, g * OC + oc, oh, ow)];
157                             break;
158                         case 3:
159                             db += diff_dst[diff_dst_d.off(mb, g * OC + oc, ow)];
160                             break;
161                         default: assert(!"invalid dimension size");
162                         }
163                     }
164                 }
165             }
166         }
167         diff_bias[g * OC + oc] = db;
168     });
169 }
170
171 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw() const {
172     auto diff_dst = reinterpret_cast<const f32_data_t *>(this->input_memory(1));
173     auto diff_bias = reinterpret_cast<f32_data_t *>(this->memory(1));
174
175     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
176
177     const int OC = pd()->OC();
178     const int MB = pd()->MB();
179     const int SP = pd()->OH()*pd()->OW()*pd()->OD();
180
181     parallel_nd(OC, [&](int oc) {
182         f32_data_t db = 0;
183         for (int mb = 0; mb < MB; ++mb) {
184             PRAGMA_OMP_SIMD()
185             for (int sp = 0; sp < SP; ++sp) {
186                 auto offset = (size_t)(mb * OC + oc) * SP + sp;
187                 db += diff_dst[offset];
188             }
189         }
190         diff_bias[oc] = db;
191     });
192 }
193
194 template <int blksize>
195 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc() const {
196     auto diff_dst = reinterpret_cast<const f32_data_t *>(this->input_memory(1));
197     auto diff_bias = reinterpret_cast<f32_data_t *>(this->memory(1));
198
199     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
200
201     const int OC = pd()->OC();
202     const int MB = pd()->MB();
203     const int SP = pd()->OH() * pd()->OW() * pd()->OD();
204
205     const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0][0];
206
207     parallel_nd(utils::div_up(OC, blksize), [&](int ocb) {
208         f32_data_t db[blksize] = {0};
209
210         for (int mb = 0; mb < MB; ++mb) {
211             for (int sp = 0; sp < SP; ++sp) {
212                 auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;
213
214                 PRAGMA_OMP_SIMD()
215                 for (int i = 0; i < blksize; ++i)
216                     db[i] += diff_dst[offset+i];
217             }
218         }
219
220         const int blk = nstl::min(blksize, OC - ocb * blksize);
221
222         PRAGMA_OMP_SIMD()
223         for (int i = 0; i < blk; ++i)
224             diff_bias[ocb * blksize + i] = db[i];
225     });
226 }
227
228 template <int blksize>
229 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc_bf16() const {
230     auto diff_dst = reinterpret_cast<const bf16_data_t *>(this->input_memory(1));
231     auto diff_bias = reinterpret_cast<f32_data_t *>(this->memory(1));
232
233     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
234
235     const int OC = pd()->OC();
236     const int MB = pd()->MB();
237     const int SP = pd()->OH() * pd()->OW() * pd()->OD();
238
239     const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0][0];
240
241     parallel_nd(utils::div_up(OC, blksize), [&](int ocb) {
242         f32_data_t db[blksize] = {0};
243         f32_data_t ddst_f32[blksize] = {0};
244
245         for (int mb = 0; mb < MB; ++mb) {
246             for (int sp = 0; sp < SP; ++sp) {
247                 auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;
248
249                 bf16_cvt_utils::cvt_bfloat16_to_float(ddst_f32, &diff_dst[offset], blksize);
250                 PRAGMA_OMP_SIMD()
251                 for (int i = 0; i < blksize; ++i)
252                     db[i] += ddst_f32[i];
253             }
254         }
255
256         const int blk = nstl::min(blksize, OC - ocb * blksize);
257
258         PRAGMA_OMP_SIMD()
259         for (int i = 0; i < blk; ++i)
260             diff_bias[ocb * blksize + i] = db[i];
261     });
262 }
263
264 template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<8>() const;
265 template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<16>() const;
266 template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc_bf16<16>() const;
267 template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<8>() const;
268 template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<16>() const;
269 template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc_bf16<16>() const;
270
271 }
272 }
273 }
274
275 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s