[exo-tflite] removing old graph builder (#7524)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 18 Sep 2019 06:05:07 +0000 (15:05 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 18 Sep 2019 06:05:07 +0000 (15:05 +0900)
Old graph builder (`PushPullGraph<..>`) is now replaced with `TestGraph`.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
compiler/exo-tflite/src/TestGraph.h

index bdacaf0..679a359 100644 (file)
@@ -90,12 +90,15 @@ TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
 // https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow
 TEST(TFLShapeInferenceRuleTest, avgpool2d_valid)
 {
-  exo::test::PullPushGraph<locoex::TFLAveragePool2D> test_graph;
-  auto pull = test_graph.pull;
+  exo::test::TestGraph graph;
+  auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull);
+  graph.complete();
+
+  auto pull = graph.pull;
   {
     pull->shape({1, 4, 3, 1});
   }
-  auto tfl_node = test_graph.middle_node;
+  // setting TFLAveragePool2D
   {
     tfl_node->filter()->h(2);
     tfl_node->filter()->w(2);
@@ -114,7 +117,7 @@ TEST(TFLShapeInferenceRuleTest, avgpool2d_valid)
   rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
       .bind(locoex::TFLDialect::get(), &tfl_rule);
 
-  loco::apply(&rules).to(test_graph.g.get());
+  loco::apply(&rules).to(graph.g.get());
 
   // Verify
   {
@@ -132,13 +135,16 @@ TEST(TFLShapeInferenceRuleTest, avgpool2d_valid)
 
 TEST(TFLShapeInferenceRuleTest, avgpool2d_same)
 {
-  exo::test::PullPushGraph<locoex::TFLAveragePool2D> test_graph;
-  auto pull = test_graph.pull;
+  exo::test::TestGraph graph;
+  auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull);
+  graph.complete();
+
+  auto pull = graph.pull;
   {
     pull->shape({1, 4, 3, 1});
   }
 
-  auto tfl_node = test_graph.middle_node;
+  // setting TFLAveragePool2D
   {
     tfl_node->filter()->h(2);
     tfl_node->filter()->w(2);
@@ -158,7 +164,7 @@ TEST(TFLShapeInferenceRuleTest, avgpool2d_same)
   rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
       .bind(locoex::TFLDialect::get(), &tfl_rule);
 
-  loco::apply(&rules).to(test_graph.g.get());
+  loco::apply(&rules).to(graph.g.get());
 
   // Verify
   {
index 79d4d4b..867c515 100644 (file)
@@ -30,52 +30,6 @@ namespace exo
 namespace test
 {
 
-// THIS WILL BE DEPRECATED. USE TestGraph instead.
-// graph to build [Pull - some node of type T - Push]
-template <typename T> struct PullPushGraph
-{
-public:
-  std::unique_ptr<loco::Graph> g;
-  loco::Pull *pull;
-  loco::Push *push;
-  T *middle_node;
-
-  PullPushGraph()
-  {
-    // g = Pull - T - Push
-    g = loco::make_graph();
-
-    pull = g->nodes()->create<loco::Pull>();
-
-    middle_node = g->nodes()->create<T>();
-    {
-      setInput();
-    }
-
-    push = g->nodes()->create<loco::Push>();
-    {
-      push->from(middle_node);
-    }
-
-    auto input = g->inputs()->create();
-    {
-      input->name("input");
-      loco::link(input, pull);
-    }
-    auto output = g->outputs()->create();
-    {
-      output->name("output");
-      loco::link(output, push);
-    }
-  }
-
-private:
-  void setInput(); // set the input of T
-};
-
-// setInput of TFL nodes
-template <> void PullPushGraph<locoex::TFLAveragePool2D>::setInput() { middle_node->value(pull); }
-
 class TestGraph
 {
 public: