[moco-tf] Update FixShapeTransform test with AvgPool (#7041)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 2 Sep 2019 22:06:44 +0000 (07:06 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 2 Sep 2019 22:06:44 +0000 (07:06 +0900)
This will update FixShapeTransform test to test with TFAvgPool

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.test.cpp

index 4e51420..8b77a3b 100644 (file)
@@ -21,6 +21,8 @@
 #include "Annotations/PaddingData.h"
 #include "Annotations/ShapeInferenceData.h"
 
+#include "Dialect/TFNodes.h"
+
 #include <loco.h>
 
 #include <stdex/Memory.h>
@@ -40,83 +42,77 @@ TEST(FixShapeTransform, ctor)
 namespace
 {
 
-loco::AvgPool2D *avgpool2d_network_simple1331(loco::Graph *graph)
+moco::tf::TFAvgPool *avgpool_network_simple1331(loco::Graph *graph)
 {
-  auto avgpool2d_node = graph->nodes()->create<loco::AvgPool2D>();
-  avgpool2d_node->convention(loco::AvgPool2D::Convention::Valid);
-
-  auto window = avgpool2d_node->window();
-  window->vertical(3);
-  window->horizontal(3);
+  auto avgpool_node = graph->nodes()->create<moco::tf::TFAvgPool>();
 
-  auto stride = avgpool2d_node->stride();
-  stride->vertical(1);
-  stride->horizontal(1);
+  avgpool_node->data_layout("NHWC");
+  avgpool_node->ksize({1, 3, 3, 1});
+  avgpool_node->strides({1, 1, 1, 1});
 
-  // Dummy const node as ifm, just to fake FixShapeTransform for AvgPool2D.
+  // Dummy const node as ifm, just to fake FixShapeTransform for TFAvgPool.
   // FixShapeTransform only cares about ShapeInferenceData of ifm()
-  auto const_node = graph->nodes()->create<loco::ConstGen>();
+  auto const_node = graph->nodes()->create<moco::tf::TFConst>();
   {
     auto shapedata = stdex::make_unique<moco::tf::ShapeInferenceData>();
-    loco::FeatureShape cshape;
-    cshape.count() = 1;
-    cshape.height() = 3;
-    cshape.width() = 3;
-    cshape.depth() = 1;
-    shapedata->feature_shape(cshape);
+    loco::TensorShape tshape;
+    tshape.rank(4);
+    tshape.dim(0).set(1);
+    tshape.dim(1).set(3);
+    tshape.dim(2).set(3);
+    tshape.dim(3).set(1);
+    shapedata->tensor_shape(tshape);
     const_node->annot(std::move(shapedata));
   }
-  avgpool2d_node->ifm(const_node);
+  avgpool_node->value(const_node);
 
-  setup_output_node(graph, avgpool2d_node);
+  setup_output_node(graph, avgpool_node);
 
-  return avgpool2d_node;
+  return avgpool_node;
 }
 
 } // namespace
 
-TEST(FixShapeTransform, avgpool2d_same)
+TEST(FixShapeTransform, avgpool_same)
 {
   moco::tf::FixShapeTransform fstransform;
   loco::Graph graph;
 
-  auto avgpool2d_node = avgpool2d_network_simple1331(&graph);
-
-  auto padding_data = stdex::make_unique<moco::tf::PaddingData>("SAME");
-  avgpool2d_node->annot(std::move(padding_data));
+  auto avgpool_node = avgpool_network_simple1331(&graph);
+  avgpool_node->padding("SAME");
 
   moco::tf::FixShapeTransform transform;
   transform.run(&graph);
 
-  auto shapedata = avgpool2d_node->annot<moco::tf::ShapeInferenceData>();
+  auto shapedata = avgpool_node->annot<moco::tf::ShapeInferenceData>();
   ASSERT_NE(shapedata, nullptr);
-  auto fshape = shapedata->feature_shape();
-  ASSERT_EQ(fshape.count(), 1);
-  ASSERT_EQ(fshape.height(), 3);
-  ASSERT_EQ(fshape.width(), 3);
-  ASSERT_EQ(fshape.depth(), 1);
+  auto tshape = shapedata->tensor_shape();
+  ASSERT_EQ(tshape.rank(), 4);
+  ASSERT_EQ(tshape.dim(0).value(), 1);
+  ASSERT_EQ(tshape.dim(1).value(), 3);
+  ASSERT_EQ(tshape.dim(2).value(), 3);
+  ASSERT_EQ(tshape.dim(3).value(), 1);
 }
 
-TEST(FixShapeTransform, avgpool2d_valid)
+TEST(FixShapeTransform, avgpool_valid)
 {
   moco::tf::FixShapeTransform fstransform;
   loco::Graph graph;
 
-  auto avgpool2d_node = avgpool2d_network_simple1331(&graph);
-
-  auto padding_data = stdex::make_unique<moco::tf::PaddingData>("VALID");
-  avgpool2d_node->annot(std::move(padding_data));
+  auto avgpool_node = avgpool_network_simple1331(&graph);
+  avgpool_node->padding("VALID");
 
   moco::tf::FixShapeTransform transform;
   transform.run(&graph);
 
-  auto shapedata = avgpool2d_node->annot<moco::tf::ShapeInferenceData>();
+  auto shapedata = avgpool_node->annot<moco::tf::ShapeInferenceData>();
   ASSERT_NE(shapedata, nullptr);
-  auto fshape = shapedata->feature_shape();
-  ASSERT_EQ(fshape.count(), 1);
-  ASSERT_EQ(fshape.height(), 1);
-  ASSERT_EQ(fshape.width(), 1);
-  ASSERT_EQ(fshape.depth(), 1);
+  auto tshape = shapedata->tensor_shape();
+  ASSERT_EQ(tshape.rank(), 4);
+  ASSERT_EQ(tshape.dim(0).value(), 1);
+  ASSERT_EQ(tshape.dim(1).value(), 1);
+  ASSERT_EQ(tshape.dim(2).value(), 1);
+  ASSERT_EQ(tshape.dim(3).value(), 1);
 }
 
 namespace