Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_depthwise.cpp
index 1325398..932ec73 100644 (file)
@@ -95,7 +95,8 @@ private:
     std::shared_ptr<memory> bias;
     std::shared_ptr<memory> dst;
     std::shared_ptr<memory> workspace;
-    std::shared_ptr<memory::desc> data_desc;
+    std::shared_ptr<memory::desc> src_desc;
+    std::shared_ptr<memory::desc> dst_desc;
     std::shared_ptr<memory::desc> weights_desc;
     std::shared_ptr<memory::desc> bias_desc;
     std::shared_ptr<depthwise_forward::primitive_desc> depthwise_prim_desc;
@@ -126,9 +127,10 @@ protected:
 
         memory::dims dims = p.data_format == mkldnn_nc ? memory::dims({p.dims[0], p.dims[1]}) : p.dims;
 
-        data_desc.reset(new memory::desc(dims, data_type, p.data_format));
-        src.reset(new memory({*data_desc, *eng}));
-        dst.reset(new memory({*data_desc, *eng}));
+        src_desc.reset(new memory::desc(dims, data_type, p.data_format));
+        dst_desc.reset(new memory::desc(dims, data_type, p.data_format));
+        src.reset(new memory({*src_desc, *eng}));
+        dst.reset(new memory({*dst_desc, *eng}));
         fill_data<data_t>(data_size, (data_t *)src->get_data_handle(),
                           data_t(0), data_t(1));
 
@@ -146,8 +148,8 @@ protected:
 
         std::vector<primitive> pipeline;
         auto depthwise_desc = with_bias
-                              ? depthwise_forward::desc(prop_kind::forward_training, p.alg_kind, *data_desc, *data_desc, *weights_desc, *bias_desc)
-                              : depthwise_forward::desc(prop_kind::forward_training, p.alg_kind, *data_desc, *data_desc, *weights_desc);
+                              ? depthwise_forward::desc(prop_kind::forward_training, p.alg_kind, *src_desc, *dst_desc, *weights_desc, *bias_desc)
+                              : depthwise_forward::desc(prop_kind::forward_training, p.alg_kind, *src_desc, *dst_desc, *weights_desc);
         depthwise_prim_desc.reset(new depthwise_forward::primitive_desc(depthwise_desc, *eng));
 
         auto depthwise = with_bias
@@ -158,7 +160,7 @@ protected:
         auto s = stream(stream::kind::lazy);
         s.submit(pipeline).wait();
 
-        check_depthwise_fwd(p, *data_desc, *src, *weights, *bias, with_bias, *dst);
+        check_depthwise_fwd(p, *src_desc, *src, *weights, *bias, with_bias, *dst);
     }
 };