updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_binary_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <common/utils.hpp>
18 #include <common/primitive_attr.hpp>
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "mkldnn_traits.hpp"
23 #include "math_utils.hpp"
24
25 #include "ref_binary_convolution.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using math::saturate;
32
33 void _ref_binary_convolution_fwd_t::execute_forward() const {
34     auto src = reinterpret_cast<const uint8_t*>(this->input_memory(0));
35     auto weights = reinterpret_cast<const uint8_t*>(this->input_memory(1));
36
37     const memory_desc_wrapper src_d(pd()->src_pd());
38     const memory_desc_wrapper dst_d(pd()->dst_pd());
39     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
40
41     const bool with_groups = pd()->with_groups();
42
43     const int G = pd()->G();
44     const int MB = pd()->MB();
45     const int OD = pd()->OD();
46     const int OH = pd()->OH();
47     const int OW = pd()->OW();
48     const int ID = pd()->ID();
49     const int IH = pd()->IH();
50     const int IW = pd()->IW();
51
52     const int OC = pd()->OC() / G;
53     const int IC = pd()->IC() / G;
54     const int KD = pd()->KD();
55     const int KH = pd()->KH();
56     const int KW = pd()->KW();
57
58     const int KSD = pd()->KSD();
59     const int KSH = pd()->KSH();
60     const int KSW = pd()->KSW();
61
62     const int KDD = pd()->KDD();
63     const int KDH = pd()->KDH();
64     const int KDW = pd()->KDW();
65
66     const int padFront = pd()->padFront();
67     const int padT = pd()->padT();
68     const int padL = pd()->padL();
69
70     const float pad_value = pd()->pad_value();
71
72     const int ndims = pd()->cdesc()->src_desc.ndims;
73
74     const int nbits = 8;
75
76     const auto &p = pd()->attr()->post_ops_;
77     bool with_sum = p.find(primitive_kind::sum) != -1;
78     bool with_binarization = p.find(primitive_kind::binarization) != -1;
79
80     auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
81         return (uint8_t)((val >> bit) & 0x0001);
82     };
83
84     auto ker = [=](int32_t &d, int g, int mb, int oc, int od, int oh, int ow) {
85         for (int ic = 0; ic < IC; ++ic)
86         for (int kd = 0; kd < KD; ++kd)
87         for (int kh = 0; kh < KH; ++kh)
88         for (int kw = 0; kw < KW; ++kw) {
89             const int id = od * KSD - padFront + kd * (1 + KDD);
90             const int ih = oh * KSH - padT + kh * (1 + KDH);
91             const int iw = ow * KSW - padL + kw * (1 + KDW);
92
93             size_t iidx = 0;
94             size_t widx = 0;
95             if (ndims == 5) {
96                 iidx = src_d.off(mb, g * IC + ic, id, ih, iw);
97                 widx = with_groups ? weights_d.off(g, oc, ic, kd, kh, kw)
98                                    : weights_d.off(oc, ic, kd, kh, kw);
99             } else if (ndims == 4) {
100                 iidx = src_d.off(mb, g * IC + ic, ih, iw);
101                 widx = with_groups ? weights_d.off(g, oc, ic, kh, kw)
102                                    : weights_d.off(oc, ic, kh, kw);
103             } else if (ndims == 3) {
104                 iidx = src_d.off(mb, g * IC + ic, iw);
105                 widx = with_groups ? weights_d.off(g, oc, ic, kw)
106                                    : weights_d.off(oc, ic, kw);
107             } else {
108                 assert(false);
109             }
110
111
112             uint8_t s;
113             if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 || iw >= IW) {
114                 if (pad_value == 0)
115                     continue;
116                 else {
117                     s = pad_value == 1.0f ? (uint8_t)1 : (uint8_t)0;
118                 }
119             }  else {
120                 s = extract_bit(src[iidx/nbits], (uint8_t)(iidx % nbits));
121             }
122
123             uint8_t w = extract_bit(weights[widx/nbits], (uint8_t)(widx % nbits));
124
125             d += (int32_t)(s ^ w);
126        }
127     };
128
129     if (with_binarization) {
130         auto dst = reinterpret_cast<uint8_t*>(this->memory());
131
132         int binarization_idx = p.find(primitive_kind::binarization);
133         const float* binarization_weights = p.entry_[binarization_idx].binarization.weights_data;
134         const uint32_t* binarization_output_mask = (uint32_t*)p.entry_[binarization_idx].binarization.output_mask_data;
135
136         parallel_nd(G, MB, utils::div_up(OC, nbits), OD, OH, OW,
137             [&](int g, int mb, int ocb, int od, int oh, int ow) {
138
139             uint8_t bin_val = 0x00;
140             for (int oc = ocb * nbits, shift = 0; oc < std::min(OC, (ocb + 1) * nbits); oc++, shift++) {
141                 int32_t a = 0;
142                 ker(a, g, mb, oc, od, oh, ow);
143
144                 float base_value;
145                 if (pad_value == 0.0f) {
146                     const int i_left_overflow = nstl::max(0, (padL - ow * KSW));
147                     const int i_right_overflow = nstl::max(IW, (ow * KSW + (KW - 1) * (KDW + 1) - padL + 1)) - IW;
148                     const int kw_padding =
149                             KW - utils::div_up(i_left_overflow, (KDW + 1)) - utils::div_up(i_right_overflow, (KDW + 1));
150
151                     const int i_top_overflow = nstl::max(0, (padT - oh * KSH));
152                     const int i_bottom_overflow = nstl::max(IH, (oh * KSH + (KH - 1) * (KDH + 1) - padT + 1)) - IH;
153                     const int kh_padding =
154                             KH - utils::div_up(i_top_overflow, (KDH + 1)) - utils::div_up(i_bottom_overflow, (KDH + 1));
155
156                     const int i_front_overflow = nstl::max(0, (padFront - od * KSD));
157                     const int i_back_overflow = nstl::max(ID, (od * KSD + (KD - 1) * (KDD + 1) - padFront + 1)) - ID;
158                     const int kd_padding =
159                             KD - utils::div_up(i_front_overflow, (KDD + 1)) - utils::div_up(i_back_overflow, (KDD + 1));
160
161                     base_value = IC * kd_padding * kh_padding * kw_padding;
162                 } else {
163                     base_value = IC * KD * KH * KW;
164                 }
165
166                 float a_fp = base_value - (float)(2 * a);
167
168                 if (with_sum) {
169                     if (ndims == 5)
170                         a_fp += dst[dst_d.off(mb, g * OC + oc, od, oh, ow)];
171                     else if (ndims == 4)
172                         a_fp += dst[dst_d.off(mb, g * OC + oc, oh, ow)];
173                     else if (ndims == 3)
174                         a_fp += dst[dst_d.off(mb, g * OC + oc, ow)];
175                     else
176                         assert(false);
177                 }
178
179                 int eltwise_inj_idx = 0;
180                 int depthwise_inj_idx = 0;
181                 for (int i = 0; i < p.len_; i++) {
182                     auto &post_op = p.entry_[i];
183                     if (post_op.is_eltwise()) {
184                         a_fp = eltwise_injectors[eltwise_inj_idx]->compute_scalar(a_fp);
185                         eltwise_inj_idx++;
186                     } else if (post_op.is_depthwise()) {
187                         auto depthwise_weights = post_op.depthwise.weights_data;
188                         auto depthwise_bias = post_op.depthwise.biases_data;
189
190                         a_fp = depthwise_injectors[depthwise_inj_idx]->compute_scalar(a_fp,
191                                                                                       depthwise_weights + g * OC + oc,
192                                                                                       depthwise_bias + g * OC + oc);
193                         depthwise_inj_idx++;
194                     }
195                 }
196
197                 float thr = binarization_weights[g * OC + oc];
198                 uint32_t out_mask = binarization_output_mask[g * OC + oc];
199                 uint32_t res = (a_fp > thr) ? 0xffffffff : 0x00000000;
200
201                 auto bit = uint8_t((res == out_mask) ? 0x01 : 0x00);
202                 bin_val |= (bit << shift);
203             }
204
205             if (ndims == 5)
206                 dst[dst_d.off(mb, g*OC + ocb*nbits, od, oh, ow) / nbits] = bin_val;
207             else if (ndims == 4)
208                 dst[dst_d.off(mb, g*OC + ocb*nbits, oh, ow) / nbits] = bin_val;
209             else if (ndims == 3)
210                 dst[dst_d.off(mb, g*OC + ocb*nbits, ow) / nbits] = bin_val;
211             else
212                 assert(false);
213         });
214     } else {
215         auto dst = reinterpret_cast<float*>(this->memory());
216
217         parallel_nd(G, MB, OC, OD, OH, OW,
218             [&](int g, int mb, int oc, int od, int oh, int ow) {
219             int32_t a = 0;
220             ker(a, g, mb, oc, od, oh, ow);
221
222             float base_value;
223             if (pad_value == 0.0f) {
224                 const int i_left_overflow = nstl::max(0, (padL - ow * KSW));
225                 const int i_right_overflow = nstl::max(IW, (ow * KSW + (KW - 1) * (KDW + 1) - padL + 1)) - IW;
226                 const int kw_padding =
227                         KW - utils::div_up(i_left_overflow, (KDW + 1)) - utils::div_up(i_right_overflow, (KDW + 1));
228
229                 const int i_top_overflow = nstl::max(0, (padT - oh * KSH));
230                 const int i_bottom_overflow = nstl::max(IH, (oh * KSH + (KH - 1) * (KDH + 1) - padT + 1)) - IH;
231                 const int kh_padding =
232                         KH - utils::div_up(i_top_overflow, (KDH + 1)) - utils::div_up(i_bottom_overflow, (KDH + 1));
233
234                 const int i_front_overflow = nstl::max(0, (padFront - od * KSD));
235                 const int i_back_overflow = nstl::max(ID, (od * KSD + (KD - 1) * (KDD + 1) - padFront + 1)) - ID;
236                 const int kd_padding =
237                         KD - utils::div_up(i_front_overflow, (KDD + 1)) - utils::div_up(i_back_overflow, (KDD + 1));
238
239                 base_value = IC * kd_padding * kh_padding * kw_padding;
240             } else {
241                 base_value = IC * KD * KH * KW;
242             }
243
244             float a_fp = base_value - (float)(2 * a);
245
246             if (with_sum) {
247                 if (ndims == 5)
248                     a_fp += dst[dst_d.off(mb, g*OC + oc, od, oh, ow)];
249                 else if (ndims == 4)
250                     a_fp += dst[dst_d.off(mb, g*OC + oc, oh, ow)];
251                 else if (ndims == 3)
252                     a_fp += dst[dst_d.off(mb, g*OC + oc, ow)];
253                 else
254                     assert(false);
255             }
256
257             int eltwise_inj_idx = 0;
258             int depthwise_inj_idx = 0;
259             for (int i = 0; i < p.len_; i++) {
260                 auto& post_op = p.entry_[i];
261                 if (post_op.is_eltwise()) {
262                     a_fp = eltwise_injectors[eltwise_inj_idx]->compute_scalar(a_fp);
263                     eltwise_inj_idx++;
264                 } else if (post_op.is_depthwise()) {
265                     auto depthwise_weights = post_op.depthwise.weights_data;
266                     auto depthwise_bias = post_op.depthwise.biases_data;
267
268                     a_fp = depthwise_injectors[depthwise_inj_idx]->compute_scalar(a_fp, depthwise_weights + g * OC + oc,
269                                                                                         depthwise_bias + g * OC + oc);
270                     depthwise_inj_idx++;
271                 }
272             }
273
274             if (ndims == 5)
275                 dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = a_fp;
276             else if (ndims == 4)
277                 dst[dst_d.off(mb, g*OC + oc, oh, ow)] = a_fp;
278             else if (ndims == 3)
279                 dst[dst_d.off(mb, g*OC + oc, ow)] = a_fp;
280             else
281                 assert(false);
282         });
283     }
284 }
285
286 }
287 }
288 }