Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_binarization.cpp
index 4fa9372..dc0b8dd 100644 (file)
@@ -34,11 +34,13 @@ template <impl::data_type_t src_type>
 void ref_binarization_fwd_t<src_type>::execute_forward() const {
     auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
     auto weights = reinterpret_cast<const src_data_t*>(this->input_memory(1));
+    auto output_mask = reinterpret_cast<const uint32_t*>(this->input_memory(2));
     auto dst = reinterpret_cast<uint8_t*>(this->memory());
 
     const memory_desc_wrapper src_d(pd()->src_pd());
     const memory_desc_wrapper dst_d(pd()->dst_pd());
     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper output_mask_d(pd()->weights_pd(1));
 
     int nbits = 8;
 
@@ -61,11 +63,15 @@ void ref_binarization_fwd_t<src_type>::execute_forward() const {
                                 : src_d.off(n, c);
 
             size_t wei_off = weights_d.off(c);
+            size_t out_mask_off = output_mask_d.off(c);
 
             float val = src[src_off];
             float thr = weights[wei_off];
+            uint32_t out_mask = output_mask[out_mask_off];
 
-            auto bit = uint8_t((val > thr) ? 0x01 : 0x00);
+            uint32_t res = (val > thr) ? 0xffffffff : 0x00000000;
+
+            auto bit = uint8_t(res == out_mask);
             bin_val |= (bit << shift);
         }