X-Git-Url: http://review.tizen.org/git/?a=blobdiff_plain;f=inference-engine%2Fthirdparty%2Fmkl-dnn%2Fsrc%2Fcpu%2Fref_binarization.cpp;h=dc0b8ddccc908750c66a020908ff31889612f2fb;hb=0ef92871b6dd9a9ceed16d184c4595d2618d526f;hp=4fa937208884f2654090199de3e0bb4f0338ca54;hpb=e206d06f18f9d3c3db29f22c3d496cb3627a16c7;p=platform%2Fupstream%2Fdldt.git diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_binarization.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_binarization.cpp index 4fa9372..dc0b8dd 100644 --- a/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_binarization.cpp +++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/ref_binarization.cpp @@ -34,11 +34,13 @@ template void ref_binarization_fwd_t::execute_forward() const { auto src = reinterpret_cast(this->input_memory(0)); auto weights = reinterpret_cast(this->input_memory(1)); + auto output_mask = reinterpret_cast(this->input_memory(2)); auto dst = reinterpret_cast(this->memory()); const memory_desc_wrapper src_d(pd()->src_pd()); const memory_desc_wrapper dst_d(pd()->dst_pd()); const memory_desc_wrapper weights_d(pd()->weights_pd(0)); + const memory_desc_wrapper output_mask_d(pd()->weights_pd(1)); int nbits = 8; @@ -61,11 +63,15 @@ void ref_binarization_fwd_t::execute_forward() const { : src_d.off(n, c); size_t wei_off = weights_d.off(c); + size_t out_mask_off = output_mask_d.off(c); float val = src[src_off]; float thr = weights[wei_off]; + uint32_t out_mask = output_mask[out_mask_off]; - auto bit = uint8_t((val > thr) ? 0x01 : 0x00); + uint32_t res = (val > thr) ? 0xffffffff : 0x00000000; + + auto bit = uint8_t(res == out_mask); bin_val |= (bit << shift); }