Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_binarization.cpp
index e720faf..8b748b6 100644 (file)
@@ -31,13 +31,18 @@ struct binarization_test_params {
 
 template <typename src_data_t>
 void check_binarization_fwd(const binarization_test_params<src_data_t> &p,
-        const memory::desc &src_md, const memory &src, const memory &weights, const memory &dst) {
+        const memory::desc &src_md, const memory &src, const memory &weights,
+        const memory &output_low, const memory &output_high, const memory &dst) {
     auto src_data = (src_data_t*)src.get_data_handle();
     auto weights_data = (src_data_t*)weights.get_data_handle();
+    auto output_low_data = (float*)output_low.get_data_handle();
+    auto output_high_data = (float*)output_high.get_data_handle();
     auto dst_data = (uint8_t*)dst.get_data_handle();
 
     const memory::desc src_d = src.get_primitive_desc().desc();
     const memory::desc weights_d = weights.get_primitive_desc().desc();
+    const memory::desc output_low_d = output_low.get_primitive_desc().desc();
+    const memory::desc output_high_d = output_high.get_primitive_desc().desc();
     const memory::desc dst_d = dst.get_primitive_desc().desc();
 
     int N = src_md.data.ndims > 0 ? src_md.data.dims[0] : 1;
@@ -63,8 +68,10 @@ void check_binarization_fwd(const binarization_test_params<src_data_t> &p,
 
                         src_data_t s_val = src_data[map_index(src_d, src_idx)];
                         src_data_t w_val = weights_data[map_index(weights_d, wei_idx)];
+                        src_data_t out_low = output_low_data[map_index(output_low_d, wei_idx)];
+                        src_data_t out_high = output_high_data[map_index(output_high_d, wei_idx)];
 
-                        auto bit = uint8_t((s_val > w_val) ? 0x01 : 0x00);
+                        auto bit = uint8_t((s_val > w_val) ? out_high : out_low);
                         bin_val |= (bit << shift);
                     }
 
@@ -95,28 +102,45 @@ protected:
 
         auto src_desc = create_md(src_dims, src_data_type, p.data_format);
         auto weights_desc = create_md(wei_dims, src_data_type, memory::format::x);
+        auto output_low_desc = create_md(wei_dims, src_data_type, memory::format::x);
+        auto output_high_desc = create_md(wei_dims, src_data_type, memory::format::x);
+        auto output_mask_desc = create_md(wei_dims, src_data_type, memory::format::x);
         auto dst_desc = create_md(dst_dims, memory::data_type::bin, p.data_format);
 
         auto src = test_memory(src_desc, eng);
         auto weights = test_memory(weights_desc, eng);
+        auto output_low = test_memory(output_low_desc, eng);
+        auto output_high = test_memory(output_high_desc, eng);
+        auto output_mask = test_memory(output_mask_desc, eng);
         auto dst = test_memory(dst_desc, eng);
 
         fill_data<src_data_t>(src.get_size() / sizeof(src_data_t), (src_data_t *)src.get().get_data_handle(),
                               src_data_t(0), src_data_t(1));
         fill_data<src_data_t>(weights.get_size() / sizeof(src_data_t), (src_data_t *)weights.get().get_data_handle(),
                               src_data_t(0), src_data_t(1));
+        fill_data<src_data_t>(output_low.get_size() / sizeof(src_data_t), (src_data_t *)output_low.get().get_data_handle(),
+                              src_data_t(0), src_data_t(1));
         fill_data<uint8_t>(dst.get_size() / sizeof(uint8_t), (uint8_t*)dst.get().get_data_handle());
 
+        src_data_t* p_output_low = (src_data_t *)output_low.get().get_data_handle();
+        src_data_t* p_output_high = (src_data_t *)output_high.get().get_data_handle();
+        uint32_t* p_output_mask = (uint32_t *)output_mask.get().get_data_handle();
+        for (int i = 0; i < src_dims[1]; i++) {
+            p_output_low[i] = p_output_low[i] >= 0 ? 1 : 0;
+            p_output_high[i] = p_output_low[i] == 1 ? 0 : 1;
+            p_output_mask[i] = p_output_high[i] == 1 ? 0xffffffff : 0x00000000;
+        }
+
         std::vector<primitive> pipeline;
-        auto binarization_desc = binarization_forward::desc(prop_kind::forward_training, p.alg_kind, src_desc, weights_desc, dst_desc);
+        auto binarization_desc = binarization_forward::desc(prop_kind::forward_training, p.alg_kind, src_desc, weights_desc, output_high_desc, dst_desc);
         auto binarization_prim_desc = binarization_forward::primitive_desc(binarization_desc, eng);
-        auto binarization = binarization_forward(binarization_prim_desc, src.get(), weights.get(), dst.get());
+        auto binarization = binarization_forward(binarization_prim_desc, src.get(), weights.get(), output_mask.get(), dst.get());
 
         pipeline.push_back(binarization);
         auto s = stream(stream::kind::lazy);
         s.submit(pipeline).wait();
 
-        check_binarization_fwd(p, src_desc, src.get(), weights.get(), dst.get());
+        check_binarization_fwd(p, src_desc, src.get(), weights.get(), output_low.get(), output_high.get(), dst.get());
     }
 };