1 /*******************************************************************************
2 * Copyright 2018-2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "math_utils.hpp"
23 #include "ref_deconvolution.hpp"
25 #include "bfloat16_utils.hpp"
31 void ref_deconvolution_fwd_t::compute_fwd_bias() const {
32 auto bias = reinterpret_cast<const f32_data_t *>(this->input_memory(2));
33 auto dst = reinterpret_cast<f32_data_t *>(this->memory());
34 const memory_desc_wrapper dst_d(pd()->dst_pd());
36 const int G = pd()->G();
37 const int MB = pd()->MB();
38 const int OH = pd()->OH();
39 const int OW = pd()->OW();
40 const int OD = pd()->OD();
41 const int OC = pd()->OC() / G;
42 const int ndims = pd()->desc()->src_desc.ndims;
44 parallel_nd(MB, G, OC, OD, OH, OW,
45 [&](int mb, int g, int oc, int od, int oh, int ow) {
46 auto b = bias[g * OC + oc];
48 case 5: dst[dst_d.off(mb, g * OC + oc, od, oh, ow)] += b; break;
49 case 4: dst[dst_d.off(mb, g * OC + oc, oh, ow)] += b; break;
50 case 3: dst[dst_d.off(mb, g * OC + oc, ow)] += b; break;
51 default: assert(!"invalid dimension size");
56 void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw() const {
57 auto bias = reinterpret_cast<const f32_data_t *>(this->input_memory(2));
58 auto dst = reinterpret_cast<f32_data_t *>(this->memory());
60 const memory_desc_wrapper dst_d(pd()->dst_pd());
62 const int MB = pd()->MB();
63 const int OC = pd()->OC();
64 const int SP = pd()->OW()*pd()->OH()*pd()->OD();
66 parallel_nd(MB, OC, [&](int mb, int oc) {
68 for (int sp = 0; sp < SP; ++sp) {
69 auto offset = (size_t)(mb * OC + oc) * SP + sp;
70 dst[offset] += bias[oc];
75 template <int blksize>
76 void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc() const {
77 auto bias = reinterpret_cast<const f32_data_t *>(this->input_memory(2));
78 auto dst = reinterpret_cast<f32_data_t *>(this->memory());
80 const memory_desc_wrapper dst_d(pd()->dst_pd());
82 const int MB = pd()->MB();
83 const int OC = pd()->OC();
84 const int SP = pd()->OW() * pd()->OH() * pd()->OD();
86 const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0][0];
88 parallel_nd(MB, utils::div_up(OC, blksize), SP,
89 [&](int mb, int oc_blk, int sp) {
90 int oc = oc_blk * blksize;
91 auto offset = mb * stride_mb + oc * SP + sp * blksize;
92 const int blk = nstl::min(blksize, OC - oc);
95 for (int i = 0; i < blk; ++i)
96 dst[offset + i] += bias[oc + i];
100 template <int blksize>
101 void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc_bf16() const {
102 auto bias = reinterpret_cast<const f32_data_t *>(this->input_memory(2));
103 auto dst = reinterpret_cast<bf16_data_t *>(this->memory());
105 const memory_desc_wrapper dst_d(pd()->dst_pd());
107 const int MB = pd()->MB();
108 const int OC = pd()->OC();
109 const int SP = pd()->OW() * pd()->OH() * pd()->OD();
111 const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0][0];
112 parallel_nd(MB, utils::div_up(OC, blksize), SP,
113 [&](int mb, int oc_blk, int sp) {
114 int oc = oc_blk * blksize;
115 auto offset = mb * stride_mb + oc * SP + sp * blksize;
116 const int blk = nstl::min(blksize, OC - oc);
118 f32_data_t dst_f32[blksize] = {0.0f};
119 bf16_cvt_utils::cvt_bfloat16_to_float(dst_f32, &dst[offset], blk);
122 for (int i = 0; i < blk; ++i) {
123 dst_f32[i] += bias[oc + i];
126 bf16_cvt_utils::cvt_float_to_bfloat16(&dst[offset], dst_f32, blk);
130 void ref_deconvolution_bwd_weights_t::compute_bwd_bias() const {
131 auto diff_dst = reinterpret_cast<const f32_data_t *>(this->input_memory(1));
132 auto diff_bias = reinterpret_cast<f32_data_t *>(this->memory(1));
133 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
135 const int G = pd()->G();
136 const int MB = pd()->MB();
137 const int OH = pd()->OH();
138 const int OW = pd()->OW();
139 const int OC = pd()->OC() / G;
140 const int OD = pd()->OD();
141 const int ndims = pd()->desc()->src_desc.ndims;
143 parallel_nd(G, OC, [&](int g, int oc) {
145 for (int mb = 0; mb < MB; ++mb) {
146 for (int od = 0; od < OD; ++od) {
147 for (int oh = 0; oh < OH; ++oh) {
148 for (int ow = 0; ow < OW; ++ow) {
151 db += diff_dst[diff_dst_d.off(
152 mb, g * OC + oc, od, oh, ow)];
155 db += diff_dst[diff_dst_d.off(
156 mb, g * OC + oc, oh, ow)];
159 db += diff_dst[diff_dst_d.off(mb, g * OC + oc, ow)];
161 default: assert(!"invalid dimension size");
167 diff_bias[g * OC + oc] = db;
171 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw() const {
172 auto diff_dst = reinterpret_cast<const f32_data_t *>(this->input_memory(1));
173 auto diff_bias = reinterpret_cast<f32_data_t *>(this->memory(1));
175 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
177 const int OC = pd()->OC();
178 const int MB = pd()->MB();
179 const int SP = pd()->OH()*pd()->OW()*pd()->OD();
181 parallel_nd(OC, [&](int oc) {
183 for (int mb = 0; mb < MB; ++mb) {
185 for (int sp = 0; sp < SP; ++sp) {
186 auto offset = (size_t)(mb * OC + oc) * SP + sp;
187 db += diff_dst[offset];
194 template <int blksize>
195 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc() const {
196 auto diff_dst = reinterpret_cast<const f32_data_t *>(this->input_memory(1));
197 auto diff_bias = reinterpret_cast<f32_data_t *>(this->memory(1));
199 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
201 const int OC = pd()->OC();
202 const int MB = pd()->MB();
203 const int SP = pd()->OH() * pd()->OW() * pd()->OD();
205 const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0][0];
207 parallel_nd(utils::div_up(OC, blksize), [&](int ocb) {
208 f32_data_t db[blksize] = {0};
210 for (int mb = 0; mb < MB; ++mb) {
211 for (int sp = 0; sp < SP; ++sp) {
212 auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;
215 for (int i = 0; i < blksize; ++i)
216 db[i] += diff_dst[offset+i];
220 const int blk = nstl::min(blksize, OC - ocb * blksize);
223 for (int i = 0; i < blk; ++i)
224 diff_bias[ocb * blksize + i] = db[i];
228 template <int blksize>
229 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc_bf16() const {
230 auto diff_dst = reinterpret_cast<const bf16_data_t *>(this->input_memory(1));
231 auto diff_bias = reinterpret_cast<f32_data_t *>(this->memory(1));
233 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
235 const int OC = pd()->OC();
236 const int MB = pd()->MB();
237 const int SP = pd()->OH() * pd()->OW() * pd()->OD();
239 const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0][0];
241 parallel_nd(utils::div_up(OC, blksize), [&](int ocb) {
242 f32_data_t db[blksize] = {0};
243 f32_data_t ddst_f32[blksize] = {0};
245 for (int mb = 0; mb < MB; ++mb) {
246 for (int sp = 0; sp < SP; ++sp) {
247 auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;
249 bf16_cvt_utils::cvt_bfloat16_to_float(ddst_f32, &diff_dst[offset], blksize);
251 for (int i = 0; i < blksize; ++i)
252 db[i] += ddst_f32[i];
256 const int blk = nstl::min(blksize, OC - ocb * blksize);
259 for (int i = 0; i < blk; ++i)
260 diff_bias[ocb * blksize + i] = db[i];
264 template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<8>() const;
265 template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<16>() const;
266 template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc_bf16<16>() const;
267 template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<8>() const;
268 template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<16>() const;
269 template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc_bf16<16>() const;
275 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s