[moco/tf] Add graph transformation step (#3674)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 3 Jun 2019 23:51:47 +0000 (08:51 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 3 Jun 2019 23:51:47 +0000 (08:51 +0900)
This will add graph transformation step after loading has finished

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco/lib/frontend/tf/src/Frontend.cpp

index 4a32c26..ce47a20 100644 (file)
@@ -20,6 +20,8 @@
 #include "GraphBuilderContext.h"
 #include "GraphBuilderRegistry.h"
 
+#include "Transforms/Transform.h"
+
 #include <loco/IR/Verifier.h>
 #include <cwrap/Fildes.h>
 #include <stdex/Memory.h>
@@ -221,6 +223,52 @@ void convert_graph(const moco::tf::ModelSignature &signature, tensorflow::GraphD
   assert(loco::valid(graph));
 }
 
+void transform_graph(loco::Graph *graph)
+{
+  std::vector<std::unique_ptr<moco::tf::Transform>> prepare;
+  std::vector<std::unique_ptr<moco::tf::Transform>> transforms;
+  std::vector<std::unique_ptr<moco::tf::Transform>> finalize;
+
+  // Transforms that run only once for preparation and finalization
+  {
+      // TODO add one time preparation when needed
+
+      // TODO add one time finalization when needed
+  }
+
+  // Transforms that run multiple times until there is no transform occured
+  {
+    // TODO add more TensorFlow related transformations
+  }
+
+  // Run preparation
+  for (auto &tr : prepare)
+  {
+    tr->run(graph);
+  }
+
+  bool changed;
+  do
+  {
+    changed = false;
+
+    for (auto &tr : transforms)
+    {
+      changed |= tr->run(graph);
+    }
+
+  } while (changed);
+
+  // Run finalize to cleanup temporary annotations
+  for (auto &tr : finalize)
+  {
+    tr->run(graph);
+  }
+
+  // validate graph
+  assert(loco::valid(graph));
+}
+
 } // namespace
 
 namespace moco
@@ -244,6 +292,8 @@ std::unique_ptr<loco::Graph> Frontend::load(const ModelSignature &signature, con
 
   convert_graph(signature, tf_graph_def, graph.get());
 
+  transform_graph(graph.get());
+
   return std::move(graph);
 }
 
@@ -258,6 +308,8 @@ std::unique_ptr<loco::Graph> Frontend::load(const ModelSignature &signature, std
 
   convert_graph(signature, tf_graph_def, graph.get());
 
+  transform_graph(graph.get());
+
   return std::move(graph);
 }