Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_binary_convolution.cpp
index 2c9cbde..3ccdf95 100644 (file)
@@ -131,6 +131,7 @@ void _ref_binary_convolution_fwd_t::execute_forward() const {
 
         int binarization_idx = p.find(primitive_kind::binarization);
         const float* binarization_weights = p.entry_[binarization_idx].binarization.weights_data;
+        const uint32_t* binarization_output_mask = (uint32_t*)p.entry_[binarization_idx].binarization.output_mask_data;
 
         parallel_nd(G, MB, utils::div_up(OC, nbits), OD, OH, OW,
             [&](int g, int mb, int ocb, int od, int oh, int ow) {
@@ -194,7 +195,10 @@ void _ref_binary_convolution_fwd_t::execute_forward() const {
                 }
 
                 float thr = binarization_weights[g * OC + oc];
-                auto bit = uint8_t((a_fp > thr) ? 0x01 : 0x00);
+                uint32_t out_mask = binarization_output_mask[g * OC + oc];
+                uint32_t res = (a_fp > thr) ? 0xffffffff : 0x00000000;
+
+                auto bit = uint8_t((res == out_mask) ? 0x01 : 0x00);
                 bin_val |= (bit << shift);
             }