[moco-tf] Revise FixShape test with Conv2D (#7116)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 4 Sep 2019 01:51:07 +0000 (10:51 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 4 Sep 2019 01:51:07 +0000 (10:51 +0900)
This will revise FixShapeTransform test with Conv2D to use moco::tf::TFConv2D from loco::Conv2D IR

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.test.cpp

index 8b77a3b..1cb0a4d 100644 (file)
@@ -125,33 +125,34 @@ void conv2d_test(const std::array<uint32_t, 4> ifm_shape, const std::array<uint3
   moco::tf::FixShapeTransform fstransform;
   loco::Graph graph;
 
-  auto conv2d_node = graph.nodes()->create<loco::Conv2D>();
+  auto conv2d_node = graph.nodes()->create<moco::tf::TFConv2D>();
+  conv2d_node->data_layout("NHWC");
+  conv2d_node->strides({1, stride_h_w[0], stride_h_w[1], 1});
+  conv2d_node->padding(padding);
 
-  auto stride = conv2d_node->stride();
-  stride->vertical(stride_h_w[0]);
-  stride->horizontal(stride_h_w[1]);
-
-  auto ifm_node = graph.nodes()->create<loco::ConstGen>();
+  auto ifm_node = graph.nodes()->create<moco::tf::TFConst>();
   {
     auto shapedata = stdex::make_unique<moco::tf::ShapeInferenceData>();
-    loco::FeatureShape cshape;
-    cshape.count() = ifm_shape[0];
-    cshape.height() = ifm_shape[1];
-    cshape.width() = ifm_shape[2];
-    cshape.depth() = ifm_shape[3];
-    shapedata->feature_shape(cshape);
+    loco::TensorShape tshape;
+    tshape.rank(4);
+    tshape.dim(0).set(ifm_shape[0]);
+    tshape.dim(1).set(ifm_shape[1]);
+    tshape.dim(2).set(ifm_shape[2]);
+    tshape.dim(3).set(ifm_shape[3]);
+    shapedata->tensor_shape(tshape);
     ifm_node->annot(std::move(shapedata));
   }
 
   auto ker_node = graph.nodes()->create<loco::ConstGen>();
   {
     auto shapedata = stdex::make_unique<moco::tf::ShapeInferenceData>();
-    loco::FilterShape cshape;
-    cshape.height() = ker_shape[0];
-    cshape.width() = ker_shape[1];
-    cshape.depth() = ker_shape[2];
-    cshape.count() = ker_shape[3];
-    shapedata->filter_shape(cshape);
+    loco::TensorShape tshape;
+    tshape.rank(4);
+    tshape.dim(0).set(ker_shape[0]);
+    tshape.dim(1).set(ker_shape[1]);
+    tshape.dim(2).set(ker_shape[2]);
+    tshape.dim(3).set(ker_shape[3]);
+    shapedata->tensor_shape(tshape);
     ker_node->annot(std::move(shapedata));
   }
 
@@ -160,19 +161,17 @@ void conv2d_test(const std::array<uint32_t, 4> ifm_shape, const std::array<uint3
 
   setup_output_node(&graph, conv2d_node);
 
-  auto padding_data = stdex::make_unique<moco::tf::PaddingData>(padding);
-  conv2d_node->annot(std::move(padding_data));
-
   moco::tf::FixShapeTransform transform;
   transform.run(&graph);
 
   auto shapedata = conv2d_node->annot<moco::tf::ShapeInferenceData>();
   ASSERT_NE(shapedata, nullptr);
-  auto fshape = shapedata->feature_shape();
-  ASSERT_EQ(fshape.count(), expected_shape[0]);
-  ASSERT_EQ(fshape.height(), expected_shape[1]);
-  ASSERT_EQ(fshape.width(), expected_shape[2]);
-  ASSERT_EQ(fshape.depth(), expected_shape[3]);
+  auto tshape = shapedata->tensor_shape();
+  ASSERT_EQ(tshape.rank(), 4);
+  ASSERT_EQ(tshape.dim(0).value(), expected_shape[0]);
+  ASSERT_EQ(tshape.dim(1).value(), expected_shape[1]);
+  ASSERT_EQ(tshape.dim(2).value(), expected_shape[2]);
+  ASSERT_EQ(tshape.dim(3).value(), expected_shape[3]);
 }
 
 } // namespace