int binarization_idx = p.find(primitive_kind::binarization);
const float* binarization_weights = p.entry_[binarization_idx].binarization.weights_data;
+ const uint32_t* binarization_output_mask = (uint32_t*)p.entry_[binarization_idx].binarization.output_mask_data;
parallel_nd(G, MB, utils::div_up(OC, nbits), OD, OH, OW,
[&](int g, int mb, int ocb, int od, int oh, int ow) {
}
float thr = binarization_weights[g * OC + oc];
- auto bit = uint8_t((a_fp > thr) ? 0x01 : 0x00);
+ uint32_t out_mask = binarization_output_mask[g * OC + oc];
+ uint32_t res = (a_fp > thr) ? 0xffffffff : 0x00000000;
+
+ auto bit = uint8_t((res == out_mask) ? 0x01 : 0x00);
bin_val |= (bit << shift);
}