1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include <common/utils.hpp>
21 #include "c_types_map.hpp"
22 #include "type_helpers.hpp"
23 #include "mkldnn_thread.hpp"
25 #include "ref_binarization.hpp"
31 using namespace alg_kind;
33 template <impl::data_type_t src_type>
34 void ref_binarization_fwd_t<src_type>::execute_forward() const {
35 auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
36 auto weights = reinterpret_cast<const src_data_t*>(this->input_memory(1));
37 auto output_mask = reinterpret_cast<const uint32_t*>(this->input_memory(2));
38 auto dst = reinterpret_cast<uint8_t*>(this->memory());
40 const memory_desc_wrapper src_d(pd()->src_pd());
41 const memory_desc_wrapper dst_d(pd()->dst_pd());
42 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
43 const memory_desc_wrapper output_mask_d(pd()->weights_pd(1));
47 const int MB = pd()->MB();
48 const int C = pd()->C();
49 const int CB = utils::div_up(C, nbits);
50 const int D = pd()->D();
51 const int H = pd()->H();
52 const int W = pd()->W();
54 parallel_nd(MB, CB, D, H, W,
55 [&](int n, int cb, int d, int h, int w) {
57 uint8_t bin_val = 0x00;
58 for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
59 size_t src_off = src_d.ndims() == 4
60 ? src_d.off(n, c, h, w)
62 ? src_d.off(n, c, d, h, w)
65 size_t wei_off = weights_d.off(c);
66 size_t out_mask_off = output_mask_d.off(c);
68 float val = src[src_off];
69 float thr = weights[wei_off];
70 uint32_t out_mask = output_mask[out_mask_off];
72 uint32_t res = (val > thr) ? 0xffffffff : 0x00000000;
74 auto bit = uint8_t(res == out_mask);
75 bin_val |= (bit << shift);
78 size_t dst_off = dst_d.ndims() == 4
79 ? dst_d.off(n, cb*nbits, h, w)
81 ? dst_d.off(n, cb, d, h, w)
84 dst[dst_off / nbits] = bin_val;
88 template struct ref_binarization_fwd_t<data_type::f32>;