X-Git-Url: http://review.tizen.org/git/?a=blobdiff_plain;f=inference-engine%2Fthirdparty%2Fmkl-dnn%2Ftests%2Fgtests%2Ftest_binarization.cpp;h=8b748b6cc034e264527accec9da41d65650bca2f;hb=0ef92871b6dd9a9ceed16d184c4595d2618d526f;hp=e720faf53125c49902c3c077f41268fe45a58f07;hpb=72660e9a4d683dc6a0c50e9fad96e59b7edd1f71;p=platform%2Fupstream%2Fdldt.git diff --git a/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_binarization.cpp b/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_binarization.cpp index e720faf..8b748b6 100644 --- a/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_binarization.cpp +++ b/inference-engine/thirdparty/mkl-dnn/tests/gtests/test_binarization.cpp @@ -31,13 +31,18 @@ struct binarization_test_params { template void check_binarization_fwd(const binarization_test_params &p, - const memory::desc &src_md, const memory &src, const memory &weights, const memory &dst) { + const memory::desc &src_md, const memory &src, const memory &weights, + const memory &output_low, const memory &output_high, const memory &dst) { auto src_data = (src_data_t*)src.get_data_handle(); auto weights_data = (src_data_t*)weights.get_data_handle(); + auto output_low_data = (float*)output_low.get_data_handle(); + auto output_high_data = (float*)output_high.get_data_handle(); auto dst_data = (uint8_t*)dst.get_data_handle(); const memory::desc src_d = src.get_primitive_desc().desc(); const memory::desc weights_d = weights.get_primitive_desc().desc(); + const memory::desc output_low_d = output_low.get_primitive_desc().desc(); + const memory::desc output_high_d = output_high.get_primitive_desc().desc(); const memory::desc dst_d = dst.get_primitive_desc().desc(); int N = src_md.data.ndims > 0 ? src_md.data.dims[0] : 1; @@ -63,8 +68,10 @@ void check_binarization_fwd(const binarization_test_params &p, src_data_t s_val = src_data[map_index(src_d, src_idx)]; src_data_t w_val = weights_data[map_index(weights_d, wei_idx)]; + src_data_t out_low = output_low_data[map_index(output_low_d, wei_idx)]; + src_data_t out_high = output_high_data[map_index(output_high_d, wei_idx)]; - auto bit = uint8_t((s_val > w_val) ? 0x01 : 0x00); + auto bit = uint8_t((s_val > w_val) ? out_high : out_low); bin_val |= (bit << shift); } @@ -95,28 +102,45 @@ protected: auto src_desc = create_md(src_dims, src_data_type, p.data_format); auto weights_desc = create_md(wei_dims, src_data_type, memory::format::x); + auto output_low_desc = create_md(wei_dims, src_data_type, memory::format::x); + auto output_high_desc = create_md(wei_dims, src_data_type, memory::format::x); + auto output_mask_desc = create_md(wei_dims, src_data_type, memory::format::x); auto dst_desc = create_md(dst_dims, memory::data_type::bin, p.data_format); auto src = test_memory(src_desc, eng); auto weights = test_memory(weights_desc, eng); + auto output_low = test_memory(output_low_desc, eng); + auto output_high = test_memory(output_high_desc, eng); + auto output_mask = test_memory(output_mask_desc, eng); auto dst = test_memory(dst_desc, eng); fill_data(src.get_size() / sizeof(src_data_t), (src_data_t *)src.get().get_data_handle(), src_data_t(0), src_data_t(1)); fill_data(weights.get_size() / sizeof(src_data_t), (src_data_t *)weights.get().get_data_handle(), src_data_t(0), src_data_t(1)); + fill_data(output_low.get_size() / sizeof(src_data_t), (src_data_t *)output_low.get().get_data_handle(), + src_data_t(0), src_data_t(1)); fill_data(dst.get_size() / sizeof(uint8_t), (uint8_t*)dst.get().get_data_handle()); + src_data_t* p_output_low = (src_data_t *)output_low.get().get_data_handle(); + src_data_t* p_output_high = (src_data_t *)output_high.get().get_data_handle(); + uint32_t* p_output_mask = (uint32_t *)output_mask.get().get_data_handle(); + for (int i = 0; i < src_dims[1]; i++) { + p_output_low[i] = p_output_low[i] >= 0 ? 1 : 0; + p_output_high[i] = p_output_low[i] == 1 ? 0 : 1; + p_output_mask[i] = p_output_high[i] == 1 ? 0xffffffff : 0x00000000; + } + std::vector pipeline; - auto binarization_desc = binarization_forward::desc(prop_kind::forward_training, p.alg_kind, src_desc, weights_desc, dst_desc); + auto binarization_desc = binarization_forward::desc(prop_kind::forward_training, p.alg_kind, src_desc, weights_desc, output_high_desc, dst_desc); auto binarization_prim_desc = binarization_forward::primitive_desc(binarization_desc, eng); - auto binarization = binarization_forward(binarization_prim_desc, src.get(), weights.get(), dst.get()); + auto binarization = binarization_forward(binarization_prim_desc, src.get(), weights.get(), output_mask.get(), dst.get()); pipeline.push_back(binarization); auto s = stream(stream::kind::lazy); s.submit(pipeline).wait(); - check_binarization_fwd(p, src_desc, src.get(), weights.get(), dst.get()); + check_binarization_fwd(p, src_desc, src.get(), weights.get(), output_low.get(), output_high.get(), dst.get()); } };