Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_binarization.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19 #include <common/utils.hpp>
20
21 #include "c_types_map.hpp"
22 #include "type_helpers.hpp"
23 #include "mkldnn_thread.hpp"
24
25 #include "ref_binarization.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace alg_kind;
32
33 template <impl::data_type_t src_type>
34 void ref_binarization_fwd_t<src_type>::execute_forward() const {
35     auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
36     auto weights = reinterpret_cast<const src_data_t*>(this->input_memory(1));
37     auto dst = reinterpret_cast<uint8_t*>(this->memory());
38
39     const memory_desc_wrapper src_d(pd()->src_pd());
40     const memory_desc_wrapper dst_d(pd()->dst_pd());
41     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
42
43     int nbits = 8;
44
45     const int MB = pd()->MB();
46     const int C = pd()->C();
47     const int CB = utils::div_up(C, nbits);
48     const int D = pd()->D();
49     const int H = pd()->H();
50     const int W = pd()->W();
51
52     parallel_nd(MB, CB, D, H, W,
53         [&](int n, int cb, int d, int h, int w) {
54
55         uint8_t bin_val = 0x00;
56         for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
57             size_t src_off = src_d.ndims() == 4
58                               ? src_d.off(n, c, h, w)
59                               : src_d.ndims() == 5
60                                 ? src_d.off(n, c, d, h, w)
61                                 : src_d.off(n, c);
62
63             size_t wei_off = weights_d.off(c);
64
65             float val = src[src_off];
66             float thr = weights[wei_off];
67
68             auto bit = uint8_t((val > thr) ? 0x01 : 0x00);
69             bin_val |= (bit << shift);
70         }
71
72         size_t dst_off = dst_d.ndims() == 4
73                            ? dst_d.off(n, cb*nbits, h, w)
74                            : dst_d.ndims() == 5
75                              ? dst_d.off(n, cb, d, h, w)
76                              : dst_d.off(n, cb);
77
78         dst[dst_off / nbits] = bin_val;
79     });
80 }
81
82 template struct ref_binarization_fwd_t<data_type::f32>;
83
84 }
85 }
86 }