1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include <common/utils.hpp>
21 #include "c_types_map.hpp"
22 #include "type_helpers.hpp"
23 #include "mkldnn_thread.hpp"
25 #include "ref_binarization.hpp"
31 using namespace alg_kind;
33 template <impl::data_type_t src_type>
34 void ref_binarization_fwd_t<src_type>::execute_forward() const {
35 auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
36 auto weights = reinterpret_cast<const src_data_t*>(this->input_memory(1));
37 auto dst = reinterpret_cast<uint8_t*>(this->memory());
39 const memory_desc_wrapper src_d(pd()->src_pd());
40 const memory_desc_wrapper dst_d(pd()->dst_pd());
41 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
45 const int MB = pd()->MB();
46 const int C = pd()->C();
47 const int CB = utils::div_up(C, nbits);
48 const int D = pd()->D();
49 const int H = pd()->H();
50 const int W = pd()->W();
52 parallel_nd(MB, CB, D, H, W,
53 [&](int n, int cb, int d, int h, int w) {
55 uint8_t bin_val = 0x00;
56 for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
57 size_t src_off = src_d.ndims() == 4
58 ? src_d.off(n, c, h, w)
60 ? src_d.off(n, c, d, h, w)
63 size_t wei_off = weights_d.off(c);
65 float val = src[src_off];
66 float thr = weights[wei_off];
68 auto bit = uint8_t((val > thr) ? 0x01 : 0x00);
69 bin_val |= (bit << shift);
72 size_t dst_off = dst_d.ndims() == 4
73 ? dst_d.off(n, cb*nbits, h, w)
75 ? dst_d.off(n, cb, d, h, w)
78 dst[dst_off / nbits] = bin_val;
// Explicit instantiation of the reference binarization primitive for f32
// source/threshold data; no other source data type is instantiated here.
82 template struct ref_binarization_fwd_t<data_type::f32>;