1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <common/utils.hpp>
18 #include <common/primitive_attr.hpp>
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "mkldnn_traits.hpp"
23 #include "math_utils.hpp"
25 #include "ref_binary_convolution.hpp"
33 void _ref_binary_convolution_fwd_t::execute_forward() const {
34 auto src = reinterpret_cast<const uint8_t*>(this->input_memory(0));
35 auto weights = reinterpret_cast<const uint8_t*>(this->input_memory(1));
37 const memory_desc_wrapper src_d(pd()->src_pd());
38 const memory_desc_wrapper dst_d(pd()->dst_pd());
39 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
41 const bool with_groups = pd()->with_groups();
43 const int G = pd()->G();
44 const int MB = pd()->MB();
45 const int OD = pd()->OD();
46 const int OH = pd()->OH();
47 const int OW = pd()->OW();
48 const int ID = pd()->ID();
49 const int IH = pd()->IH();
50 const int IW = pd()->IW();
52 const int OC = pd()->OC() / G;
53 const int IC = pd()->IC() / G;
54 const int KD = pd()->KD();
55 const int KH = pd()->KH();
56 const int KW = pd()->KW();
58 const int KSD = pd()->KSD();
59 const int KSH = pd()->KSH();
60 const int KSW = pd()->KSW();
62 const int KDD = pd()->KDD();
63 const int KDH = pd()->KDH();
64 const int KDW = pd()->KDW();
66 const int padFront = pd()->padFront();
67 const int padT = pd()->padT();
68 const int padL = pd()->padL();
70 const float pad_value = pd()->pad_value();
72 const int ndims = pd()->cdesc()->src_desc.ndims;
76 const auto &p = pd()->attr()->post_ops_;
77 bool with_sum = p.find(primitive_kind::sum) != -1;
78 bool with_binarization = p.find(primitive_kind::binarization) != -1;
80 auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
81 return (uint8_t)((val >> bit) & 0x0001);
84 auto ker = [=](int32_t &d, int g, int mb, int oc, int od, int oh, int ow) {
85 for (int ic = 0; ic < IC; ++ic)
86 for (int kd = 0; kd < KD; ++kd)
87 for (int kh = 0; kh < KH; ++kh)
88 for (int kw = 0; kw < KW; ++kw) {
89 const int id = od * KSD - padFront + kd * (1 + KDD);
90 const int ih = oh * KSH - padT + kh * (1 + KDH);
91 const int iw = ow * KSW - padL + kw * (1 + KDW);
96 iidx = src_d.off(mb, g * IC + ic, id, ih, iw);
97 widx = with_groups ? weights_d.off(g, oc, ic, kd, kh, kw)
98 : weights_d.off(oc, ic, kd, kh, kw);
99 } else if (ndims == 4) {
100 iidx = src_d.off(mb, g * IC + ic, ih, iw);
101 widx = with_groups ? weights_d.off(g, oc, ic, kh, kw)
102 : weights_d.off(oc, ic, kh, kw);
103 } else if (ndims == 3) {
104 iidx = src_d.off(mb, g * IC + ic, iw);
105 widx = with_groups ? weights_d.off(g, oc, ic, kw)
106 : weights_d.off(oc, ic, kw);
113 if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 || iw >= IW) {
117 s = pad_value == 1.0f ? (uint8_t)1 : (uint8_t)0;
120 s = extract_bit(src[iidx/nbits], (uint8_t)(iidx % nbits));
123 uint8_t w = extract_bit(weights[widx/nbits], (uint8_t)(widx % nbits));
125 d += (int32_t)(s ^ w);
129 if (with_binarization) {
130 auto dst = reinterpret_cast<uint8_t*>(this->memory());
132 int binarization_idx = p.find(primitive_kind::binarization);
133 const float* binarization_weights = p.entry_[binarization_idx].binarization.weights_data;
134 const uint32_t* binarization_output_mask = (uint32_t*)p.entry_[binarization_idx].binarization.output_mask_data;
136 parallel_nd(G, MB, utils::div_up(OC, nbits), OD, OH, OW,
137 [&](int g, int mb, int ocb, int od, int oh, int ow) {
139 uint8_t bin_val = 0x00;
140 for (int oc = ocb * nbits, shift = 0; oc < std::min(OC, (ocb + 1) * nbits); oc++, shift++) {
142 ker(a, g, mb, oc, od, oh, ow);
145 if (pad_value == 0.0f) {
146 const int i_left_overflow = nstl::max(0, (padL - ow * KSW));
147 const int i_right_overflow = nstl::max(IW, (ow * KSW + (KW - 1) * (KDW + 1) - padL + 1)) - IW;
148 const int kw_padding =
149 KW - utils::div_up(i_left_overflow, (KDW + 1)) - utils::div_up(i_right_overflow, (KDW + 1));
151 const int i_top_overflow = nstl::max(0, (padT - oh * KSH));
152 const int i_bottom_overflow = nstl::max(IH, (oh * KSH + (KH - 1) * (KDH + 1) - padT + 1)) - IH;
153 const int kh_padding =
154 KH - utils::div_up(i_top_overflow, (KDH + 1)) - utils::div_up(i_bottom_overflow, (KDH + 1));
156 const int i_front_overflow = nstl::max(0, (padFront - od * KSD));
157 const int i_back_overflow = nstl::max(ID, (od * KSD + (KD - 1) * (KDD + 1) - padFront + 1)) - ID;
158 const int kd_padding =
159 KD - utils::div_up(i_front_overflow, (KDD + 1)) - utils::div_up(i_back_overflow, (KDD + 1));
161 base_value = IC * kd_padding * kh_padding * kw_padding;
163 base_value = IC * KD * KH * KW;
166 float a_fp = base_value - (float)(2 * a);
170 a_fp += dst[dst_d.off(mb, g * OC + oc, od, oh, ow)];
172 a_fp += dst[dst_d.off(mb, g * OC + oc, oh, ow)];
174 a_fp += dst[dst_d.off(mb, g * OC + oc, ow)];
179 int eltwise_inj_idx = 0;
180 int depthwise_inj_idx = 0;
181 for (int i = 0; i < p.len_; i++) {
182 auto &post_op = p.entry_[i];
183 if (post_op.is_eltwise()) {
184 a_fp = eltwise_injectors[eltwise_inj_idx]->compute_scalar(a_fp);
186 } else if (post_op.is_depthwise()) {
187 auto depthwise_weights = post_op.depthwise.weights_data;
188 auto depthwise_bias = post_op.depthwise.biases_data;
190 a_fp = depthwise_injectors[depthwise_inj_idx]->compute_scalar(a_fp,
191 depthwise_weights + g * OC + oc,
192 depthwise_bias + g * OC + oc);
197 float thr = binarization_weights[g * OC + oc];
198 uint32_t out_mask = binarization_output_mask[g * OC + oc];
199 uint32_t res = (a_fp > thr) ? 0xffffffff : 0x00000000;
201 auto bit = uint8_t((res == out_mask) ? 0x01 : 0x00);
202 bin_val |= (bit << shift);
206 dst[dst_d.off(mb, g*OC + ocb*nbits, od, oh, ow) / nbits] = bin_val;
208 dst[dst_d.off(mb, g*OC + ocb*nbits, oh, ow) / nbits] = bin_val;
210 dst[dst_d.off(mb, g*OC + ocb*nbits, ow) / nbits] = bin_val;
215 auto dst = reinterpret_cast<float*>(this->memory());
217 parallel_nd(G, MB, OC, OD, OH, OW,
218 [&](int g, int mb, int oc, int od, int oh, int ow) {
220 ker(a, g, mb, oc, od, oh, ow);
223 if (pad_value == 0.0f) {
224 const int i_left_overflow = nstl::max(0, (padL - ow * KSW));
225 const int i_right_overflow = nstl::max(IW, (ow * KSW + (KW - 1) * (KDW + 1) - padL + 1)) - IW;
226 const int kw_padding =
227 KW - utils::div_up(i_left_overflow, (KDW + 1)) - utils::div_up(i_right_overflow, (KDW + 1));
229 const int i_top_overflow = nstl::max(0, (padT - oh * KSH));
230 const int i_bottom_overflow = nstl::max(IH, (oh * KSH + (KH - 1) * (KDH + 1) - padT + 1)) - IH;
231 const int kh_padding =
232 KH - utils::div_up(i_top_overflow, (KDH + 1)) - utils::div_up(i_bottom_overflow, (KDH + 1));
234 const int i_front_overflow = nstl::max(0, (padFront - od * KSD));
235 const int i_back_overflow = nstl::max(ID, (od * KSD + (KD - 1) * (KDD + 1) - padFront + 1)) - ID;
236 const int kd_padding =
237 KD - utils::div_up(i_front_overflow, (KDD + 1)) - utils::div_up(i_back_overflow, (KDD + 1));
239 base_value = IC * kd_padding * kh_padding * kw_padding;
241 base_value = IC * KD * KH * KW;
244 float a_fp = base_value - (float)(2 * a);
248 a_fp += dst[dst_d.off(mb, g*OC + oc, od, oh, ow)];
250 a_fp += dst[dst_d.off(mb, g*OC + oc, oh, ow)];
252 a_fp += dst[dst_d.off(mb, g*OC + oc, ow)];
257 int eltwise_inj_idx = 0;
258 int depthwise_inj_idx = 0;
259 for (int i = 0; i < p.len_; i++) {
260 auto& post_op = p.entry_[i];
261 if (post_op.is_eltwise()) {
262 a_fp = eltwise_injectors[eltwise_inj_idx]->compute_scalar(a_fp);
264 } else if (post_op.is_depthwise()) {
265 auto depthwise_weights = post_op.depthwise.weights_data;
266 auto depthwise_bias = post_op.depthwise.biases_data;
268 a_fp = depthwise_injectors[depthwise_inj_idx]->compute_scalar(a_fp, depthwise_weights + g * OC + oc,
269 depthwise_bias + g * OC + oc);
275 dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = a_fp;
277 dst[dst_d.off(mb, g*OC + oc, oh, ow)] = a_fp;
279 dst[dst_d.off(mb, g*OC + oc, ow)] = a_fp;