void ref_binarization_fwd_t<src_type>::execute_forward() const {
auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
auto weights = reinterpret_cast<const src_data_t*>(this->input_memory(1));
+ auto output_mask = reinterpret_cast<const uint32_t*>(this->input_memory(2));
auto dst = reinterpret_cast<uint8_t*>(this->memory());
const memory_desc_wrapper src_d(pd()->src_pd());
const memory_desc_wrapper dst_d(pd()->dst_pd());
const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper output_mask_d(pd()->weights_pd(1));
int nbits = 8;
: src_d.off(n, c);
size_t wei_off = weights_d.off(c);
+ size_t out_mask_off = output_mask_d.off(c);
float val = src[src_off];
float thr = weights[wei_off];
+ uint32_t out_mask = output_mask[out_mask_off];
- auto bit = uint8_t((val > thr) ? 0x01 : 0x00);
+ uint32_t res = (val > thr) ? 0xffffffff : 0x00000000;
+
+ auto bit = uint8_t(res == out_mask);
bin_val |= (bit << shift);
}