Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_binary_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <common/utils.hpp>
18 #include <common/primitive_attr.hpp>
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "mkldnn_traits.hpp"
23 #include "math_utils.hpp"
24
25 #include "ref_binary_convolution.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using math::saturate;
32
33 void _ref_binary_convolution_fwd_t::execute_forward() const {
34     auto src = reinterpret_cast<const uint8_t*>(this->input_memory(0));
35     auto weights = reinterpret_cast<const uint8_t*>(this->input_memory(1));
36
37     const memory_desc_wrapper src_d(pd()->src_pd());
38     const memory_desc_wrapper dst_d(pd()->dst_pd());
39     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
40
41     const bool with_groups = pd()->with_groups();
42
43     const int G = pd()->G();
44     const int MB = pd()->MB();
45     const int OD = pd()->OD();
46     const int OH = pd()->OH();
47     const int OW = pd()->OW();
48     const int ID = pd()->ID();
49     const int IH = pd()->IH();
50     const int IW = pd()->IW();
51
52     const int OC = pd()->OC() / G;
53     const int IC = pd()->IC() / G;
54     const int KD = pd()->KD();
55     const int KH = pd()->KH();
56     const int KW = pd()->KW();
57
58     const int KSD = pd()->KSD();
59     const int KSH = pd()->KSH();
60     const int KSW = pd()->KSW();
61
62     const int KDD = pd()->KDD();
63     const int KDH = pd()->KDH();
64     const int KDW = pd()->KDW();
65
66     const int padFront = pd()->padFront();
67     const int padT = pd()->padT();
68     const int padL = pd()->padL();
69
70     const float pad_value = pd()->pad_value();
71
72     const int ndims = pd()->cdesc()->src_desc.ndims;
73
74     const int nbits = 8;
75
76     const auto &p = pd()->attr()->post_ops_;
77     bool with_sum = p.find(primitive_kind::sum) != -1;
78     bool with_binarization = p.find(primitive_kind::binarization) != -1;
79
80     auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
81         return (uint8_t)((val >> bit) & 0x0001);
82     };
83
84     auto ker = [=](int32_t &d, int g, int mb, int oc, int od, int oh, int ow) {
85         for (int ic = 0; ic < IC; ++ic)
86         for (int kd = 0; kd < KD; ++kd)
87         for (int kh = 0; kh < KH; ++kh)
88         for (int kw = 0; kw < KW; ++kw) {
89             const int id = od * KSD - padFront + kd * (1 + KDD);
90             const int ih = oh * KSH - padT + kh * (1 + KDH);
91             const int iw = ow * KSW - padL + kw * (1 + KDW);
92
93             size_t iidx = 0;
94             size_t widx = 0;
95             if (ndims == 5) {
96                 iidx = src_d.off(mb, g * IC + ic, id, ih, iw);
97                 widx = with_groups ? weights_d.off(g, oc, ic, kd, kh, kw)
98                                    : weights_d.off(oc, ic, kd, kh, kw);
99             } else if (ndims == 4) {
100                 iidx = src_d.off(mb, g * IC + ic, ih, iw);
101                 widx = with_groups ? weights_d.off(g, oc, ic, kh, kw)
102                                    : weights_d.off(oc, ic, kh, kw);
103             } else if (ndims == 3) {
104                 iidx = src_d.off(mb, g * IC + ic, iw);
105                 widx = with_groups ? weights_d.off(g, oc, ic, kw)
106                                    : weights_d.off(oc, ic, kw);
107             } else {
108                 assert(false);
109             }
110
111
112             uint8_t s;
113             if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 || iw >= IW) {
114                 if (pad_value == 0)
115                     continue;
116                 else {
117                     s = pad_value == 1.0f ? (uint8_t)1 : (uint8_t)0;
118                 }
119             }  else {
120                 s = extract_bit(src[iidx/nbits], (uint8_t)(iidx % nbits));
121             }
122
123             uint8_t w = extract_bit(weights[widx/nbits], (uint8_t)(widx % nbits));
124
125             d += (int32_t)(s ^ w);
126        }
127     };
128
129     if (with_binarization) {
130         auto dst = reinterpret_cast<uint8_t*>(this->memory());
131
132         int binarization_idx = p.find(primitive_kind::binarization);
133         const float* binarization_weights = p.entry_[binarization_idx].binarization.weights_data;
134
135         parallel_nd(G, MB, utils::div_up(OC, nbits), OD, OH, OW,
136             [&](int g, int mb, int ocb, int od, int oh, int ow) {
137
138             uint8_t bin_val = 0x00;
139             for (int oc = ocb * nbits, shift = 0; oc < std::min(OC, (ocb + 1) * nbits); oc++, shift++) {
140                 int32_t a = 0;
141                 ker(a, g, mb, oc, od, oh, ow);
142
143                 float base_value;
144                 if (pad_value == 0.0f) {
145                     const int i_left_overflow = nstl::max(0, (padL - ow * KSW));
146                     const int i_right_overflow = nstl::max(IW, (ow * KSW + (KW - 1) * (KDW + 1) - padL + 1)) - IW;
147                     const int kw_padding =
148                             KW - utils::div_up(i_left_overflow, (KDW + 1)) - utils::div_up(i_right_overflow, (KDW + 1));
149
150                     const int i_top_overflow = nstl::max(0, (padT - oh * KSH));
151                     const int i_bottom_overflow = nstl::max(IH, (oh * KSH + (KH - 1) * (KDH + 1) - padT + 1)) - IH;
152                     const int kh_padding =
153                             KH - utils::div_up(i_top_overflow, (KDH + 1)) - utils::div_up(i_bottom_overflow, (KDH + 1));
154
155                     const int i_front_overflow = nstl::max(0, (padFront - od * KSD));
156                     const int i_back_overflow = nstl::max(ID, (od * KSD + (KD - 1) * (KDD + 1) - padFront + 1)) - ID;
157                     const int kd_padding =
158                             KD - utils::div_up(i_front_overflow, (KDD + 1)) - utils::div_up(i_back_overflow, (KDD + 1));
159
160                     base_value = IC * kd_padding * kh_padding * kw_padding;
161                 } else {
162                     base_value = IC * KD * KH * KW;
163                 }
164
165                 float a_fp = base_value - (float)(2 * a);
166
167                 if (with_sum) {
168                     if (ndims == 5)
169                         a_fp += dst[dst_d.off(mb, g * OC + oc, od, oh, ow)];
170                     else if (ndims == 4)
171                         a_fp += dst[dst_d.off(mb, g * OC + oc, oh, ow)];
172                     else if (ndims == 3)
173                         a_fp += dst[dst_d.off(mb, g * OC + oc, ow)];
174                     else
175                         assert(false);
176                 }
177
178                 int eltwise_inj_idx = 0;
179                 int depthwise_inj_idx = 0;
180                 for (int i = 0; i < p.len_; i++) {
181                     auto &post_op = p.entry_[i];
182                     if (post_op.is_eltwise()) {
183                         a_fp = eltwise_injectors[eltwise_inj_idx]->compute_scalar(a_fp);
184                         eltwise_inj_idx++;
185                     } else if (post_op.is_depthwise()) {
186                         auto depthwise_weights = post_op.depthwise.weights_data;
187                         auto depthwise_bias = post_op.depthwise.biases_data;
188
189                         a_fp = depthwise_injectors[depthwise_inj_idx]->compute_scalar(a_fp,
190                                                                                       depthwise_weights + g * OC + oc,
191                                                                                       depthwise_bias + g * OC + oc);
192                         depthwise_inj_idx++;
193                     }
194                 }
195
196                 float thr = binarization_weights[g * OC + oc];
197                 auto bit = uint8_t((a_fp > thr) ? 0x01 : 0x00);
198                 bin_val |= (bit << shift);
199             }
200
201             if (ndims == 5)
202                 dst[dst_d.off(mb, g*OC + ocb*nbits, od, oh, ow) / nbits] = bin_val;
203             else if (ndims == 4)
204                 dst[dst_d.off(mb, g*OC + ocb*nbits, oh, ow) / nbits] = bin_val;
205             else if (ndims == 3)
206                 dst[dst_d.off(mb, g*OC + ocb*nbits, ow) / nbits] = bin_val;
207             else
208                 assert(false);
209         });
210     } else {
211         auto dst = reinterpret_cast<float*>(this->memory());
212
213         parallel_nd(G, MB, OC, OD, OH, OW,
214             [&](int g, int mb, int oc, int od, int oh, int ow) {
215             int32_t a = 0;
216             ker(a, g, mb, oc, od, oh, ow);
217
218             float base_value;
219             if (pad_value == 0.0f) {
220                 const int i_left_overflow = nstl::max(0, (padL - ow * KSW));
221                 const int i_right_overflow = nstl::max(IW, (ow * KSW + (KW - 1) * (KDW + 1) - padL + 1)) - IW;
222                 const int kw_padding =
223                         KW - utils::div_up(i_left_overflow, (KDW + 1)) - utils::div_up(i_right_overflow, (KDW + 1));
224
225                 const int i_top_overflow = nstl::max(0, (padT - oh * KSH));
226                 const int i_bottom_overflow = nstl::max(IH, (oh * KSH + (KH - 1) * (KDH + 1) - padT + 1)) - IH;
227                 const int kh_padding =
228                         KH - utils::div_up(i_top_overflow, (KDH + 1)) - utils::div_up(i_bottom_overflow, (KDH + 1));
229
230                 const int i_front_overflow = nstl::max(0, (padFront - od * KSD));
231                 const int i_back_overflow = nstl::max(ID, (od * KSD + (KD - 1) * (KDD + 1) - padFront + 1)) - ID;
232                 const int kd_padding =
233                         KD - utils::div_up(i_front_overflow, (KDD + 1)) - utils::div_up(i_back_overflow, (KDD + 1));
234
235                 base_value = IC * kd_padding * kh_padding * kw_padding;
236             } else {
237                 base_value = IC * KD * KH * KW;
238             }
239
240             float a_fp = base_value - (float)(2 * a);
241
242             if (with_sum) {
243                 if (ndims == 5)
244                     a_fp += dst[dst_d.off(mb, g*OC + oc, od, oh, ow)];
245                 else if (ndims == 4)
246                     a_fp += dst[dst_d.off(mb, g*OC + oc, oh, ow)];
247                 else if (ndims == 3)
248                     a_fp += dst[dst_d.off(mb, g*OC + oc, ow)];
249                 else
250                     assert(false);
251             }
252
253             int eltwise_inj_idx = 0;
254             int depthwise_inj_idx = 0;
255             for (int i = 0; i < p.len_; i++) {
256                 auto& post_op = p.entry_[i];
257                 if (post_op.is_eltwise()) {
258                     a_fp = eltwise_injectors[eltwise_inj_idx]->compute_scalar(a_fp);
259                     eltwise_inj_idx++;
260                 } else if (post_op.is_depthwise()) {
261                     auto depthwise_weights = post_op.depthwise.weights_data;
262                     auto depthwise_bias = post_op.depthwise.biases_data;
263
264                     a_fp = depthwise_injectors[depthwise_inj_idx]->compute_scalar(a_fp, depthwise_weights + g * OC + oc,
265                                                                                         depthwise_bias + g * OC + oc);
266                     depthwise_inj_idx++;
267                 }
268             }
269
270             if (ndims == 5)
271                 dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = a_fp;
272             else if (ndims == 4)
273                 dst[dst_d.off(mb, g*OC + oc, oh, ow)] = a_fp;
274             else if (ndims == 3)
275                 dst[dst_d.off(mb, g*OC + oc, ow)] = a_fp;
276             else
277                 assert(false);
278         });
279     }
280 }
281
282 }
283 }
284 }