template <typename src_data_t>
void check_binarization_fwd(const binarization_test_params<src_data_t> &p,
- const memory::desc &src_md, const memory &src, const memory &weights, const memory &dst) {
+ const memory::desc &src_md, const memory &src, const memory &weights,
+ const memory &output_low, const memory &output_high, const memory &dst) {
auto src_data = (src_data_t*)src.get_data_handle();
auto weights_data = (src_data_t*)weights.get_data_handle();
+ auto output_low_data = (float*)output_low.get_data_handle();
+ auto output_high_data = (float*)output_high.get_data_handle();
auto dst_data = (uint8_t*)dst.get_data_handle();
const memory::desc src_d = src.get_primitive_desc().desc();
const memory::desc weights_d = weights.get_primitive_desc().desc();
+ const memory::desc output_low_d = output_low.get_primitive_desc().desc();
+ const memory::desc output_high_d = output_high.get_primitive_desc().desc();
const memory::desc dst_d = dst.get_primitive_desc().desc();
int N = src_md.data.ndims > 0 ? src_md.data.dims[0] : 1;
src_data_t s_val = src_data[map_index(src_d, src_idx)];
src_data_t w_val = weights_data[map_index(weights_d, wei_idx)];
+ src_data_t out_low = output_low_data[map_index(output_low_d, wei_idx)];
+ src_data_t out_high = output_high_data[map_index(output_high_d, wei_idx)];
- auto bit = uint8_t((s_val > w_val) ? 0x01 : 0x00);
+ auto bit = uint8_t((s_val > w_val) ? out_high : out_low);
bin_val |= (bit << shift);
}
auto src_desc = create_md(src_dims, src_data_type, p.data_format);
auto weights_desc = create_md(wei_dims, src_data_type, memory::format::x);
+ auto output_low_desc = create_md(wei_dims, src_data_type, memory::format::x);
+ auto output_high_desc = create_md(wei_dims, src_data_type, memory::format::x);
+ auto output_mask_desc = create_md(wei_dims, src_data_type, memory::format::x);
auto dst_desc = create_md(dst_dims, memory::data_type::bin, p.data_format);
auto src = test_memory(src_desc, eng);
auto weights = test_memory(weights_desc, eng);
+ auto output_low = test_memory(output_low_desc, eng);
+ auto output_high = test_memory(output_high_desc, eng);
+ auto output_mask = test_memory(output_mask_desc, eng);
auto dst = test_memory(dst_desc, eng);
fill_data<src_data_t>(src.get_size() / sizeof(src_data_t), (src_data_t *)src.get().get_data_handle(),
src_data_t(0), src_data_t(1));
fill_data<src_data_t>(weights.get_size() / sizeof(src_data_t), (src_data_t *)weights.get().get_data_handle(),
src_data_t(0), src_data_t(1));
+ fill_data<src_data_t>(output_low.get_size() / sizeof(src_data_t), (src_data_t *)output_low.get().get_data_handle(),
+ src_data_t(0), src_data_t(1));
fill_data<uint8_t>(dst.get_size() / sizeof(uint8_t), (uint8_t*)dst.get().get_data_handle());
+ src_data_t* p_output_low = (src_data_t *)output_low.get().get_data_handle();
+ src_data_t* p_output_high = (src_data_t *)output_high.get().get_data_handle();
+ uint32_t* p_output_mask = (uint32_t *)output_mask.get().get_data_handle();
+ for (int i = 0; i < src_dims[1]; i++) {
+ p_output_low[i] = p_output_low[i] >= 0 ? 1 : 0;
+ p_output_high[i] = p_output_low[i] == 1 ? 0 : 1;
+ p_output_mask[i] = p_output_high[i] == 1 ? 0xffffffff : 0x00000000;
+ }
+
std::vector<primitive> pipeline;
- auto binarization_desc = binarization_forward::desc(prop_kind::forward_training, p.alg_kind, src_desc, weights_desc, dst_desc);
+ auto binarization_desc = binarization_forward::desc(prop_kind::forward_training, p.alg_kind, src_desc, weights_desc, output_high_desc, dst_desc);
auto binarization_prim_desc = binarization_forward::primitive_desc(binarization_desc, eng);
- auto binarization = binarization_forward(binarization_prim_desc, src.get(), weights.get(), dst.get());
+ auto binarization = binarization_forward(binarization_prim_desc, src.get(), weights.get(), output_mask.get(), dst.get());
pipeline.push_back(binarization);
auto s = stream(stream::kind::lazy);
s.submit(pipeline).wait();
- check_binarization_fwd(p, src_desc, src.get(), weights.get(), dst.get());
+ check_binarization_fwd(p, src_desc, src.get(), weights.get(), output_low.get(), output_high.get(), dst.get());
}
};