[moco-tf] Introduce fmt helper (#6149)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 2 Aug 2019 08:21:49 +0000 (17:21 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 2 Aug 2019 08:21:49 +0000 (17:21 +0900)
This commit introduces fmt helper specialized for moco-tf.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/moco-tf/src/LogHelper.cpp
compiler/moco-tf/src/LogHelper.h
compiler/moco-tf/src/Phase.cpp
compiler/moco-tf/src/Transforms/ConstantFoldingTransform.test.cpp
compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp

index 04a4fb0..e9d50ab 100644 (file)
@@ -56,3 +56,19 @@ std::ostream &operator<<(std::ostream &os, const std::vector<int64_t> &vi64)
   }
   return os;
 }
+
+#include "TFFormattedGraph.h"
+
+namespace moco
+{
+namespace tf
+{
+
+FormattedGraph fmt(loco::Graph *g)
+{
+  auto node_summary_builder = stdex::make_unique<TFNodeSummaryBuilderFactory>();
+  return std::move(locop::fmt<locop::LinearV1>(g).with(std::move(node_summary_builder)));
+}
+
+} // namespace tf
+} // namespace moco
index 439d856..22d89f0 100644 (file)
@@ -17,6 +17,8 @@
 #ifndef __LOG_HELPER_H__
 #define __LOG_HELPER_H__
 
+#include <locop/FormattedGraph.h>
+
 #include <loco/IR/FeatureShape.h>
 #include <loco/IR/FilterShape.h>
 #include <loco/IR/TensorShape.h>
@@ -49,4 +51,18 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape
  */
 std::ostream &operator<<(std::ostream &os, const std::vector<int64_t> &vi64);
 
+namespace moco
+{
+namespace tf
+{
+
+using FormattedGraph = locop::FormattedGraphImpl<locop::Formatter::LinearV1>;
+
+FormattedGraph fmt(loco::Graph *g);
+
+static inline FormattedGraph fmt(const std::unique_ptr<loco::Graph> &g) { return fmt(g.get()); }
+
+} // namespace tf
+} // namespace moco
+
 #endif // __LOG_HELPER_H__
index 96f4654..2f4d0ad 100644 (file)
  */
 
 #include "Phase.h"
-
-#include "TFFormattedGraph.h"
+#include "LogHelper.h"
 
 #include <moco/Log.h>
-#include <locop/FormattedGraph.h>
 
 namespace
 {
@@ -41,8 +39,7 @@ void PhaseRunner<PhaseStrategy::Saturate>::run(const Phase &phase) const
   INFO(l) << "PhaseRunner<Saturate>";
 
   INFO(l) << "Initial graph";
-  INFO(l) << locop::fmt<locop::LinearV1>(_graph).with(
-      stdex::make_unique<TFNodeSummaryBuilderFactory>());
+  INFO(l) << fmt(_graph);
 
   for (bool changed = true; changed;)
   {
@@ -62,8 +59,7 @@ void PhaseRunner<PhaseStrategy::Saturate>::run(const Phase &phase) const
       }
 
       INFO(l) << "After " << transform_name(tr.get()) << " (changed: " << to_char(chg_one) << ")";
-      INFO(l) << locop::fmt<locop::LinearV1>(_graph).with(
-          stdex::make_unique<TFNodeSummaryBuilderFactory>());
+      INFO(l) << fmt(_graph);
     }
   }
 
index 9a63bba..0c57c9f 100644 (file)
 
 #include "ConstantFoldingTransform.h"
 
+#include "LogHelper.h"
 #include "TestHelper.h"
 #include "IR/TFFusedBatchNorm.h"
 #include "Importer.h"
 #include "Canonicalizer.h"
-#include "TFFormattedGraph.h"
 
 #include <loco.h>
-#include <locop/FormattedGraph.h>
 #include <moco/Log.h>
 #include <plier/tf/TestHelper.h>
 
@@ -97,8 +96,7 @@ TEST(ConstantFolding, case01)
   canonicalizer.canonicalize(graph.get());
 
   INFO(l) << "Before ConstantFolding";
-  INFO(l) << locop::fmt<locop::LinearV1>(graph).with(
-      stdex::make_unique<moco::tf::TFNodeSummaryBuilderFactory>());
+  INFO(l) << moco::tf::fmt(graph);
 
   moco::tf::ConstantFoldingTransform transform;
   while (transform.run(graph.get()) == true)
@@ -107,8 +105,7 @@ TEST(ConstantFolding, case01)
   }
 
   INFO(l) << "After ConstantFolding ";
-  INFO(l) << locop::fmt<locop::LinearV1>(graph).with(
-      stdex::make_unique<moco::tf::TFNodeSummaryBuilderFactory>());
+  INFO(l) << moco::tf::fmt(graph);
 
   auto push = moco::tf::test::find_first_node_bytype<loco::Push>(graph.get());
   auto const_gen = dynamic_cast<loco::ConstGen *>(push->from());
@@ -236,8 +233,7 @@ TEST(ConstantFolding, case02)
   canonicalizer.canonicalize(graph.get());
 
   INFO(l) << "Before ConstantFolding";
-  INFO(l) << locop::fmt<locop::LinearV1>(graph).with(
-      stdex::make_unique<moco::tf::TFNodeSummaryBuilderFactory>());
+  INFO(l) << moco::tf::fmt(graph);
 
   moco::tf::ConstantFoldingTransform transform;
   while (transform.run(graph.get()) == true)
@@ -246,8 +242,7 @@ TEST(ConstantFolding, case02)
   }
 
   INFO(l) << "After ConstantFolding ";
-  INFO(l) << locop::fmt<locop::LinearV1>(graph).with(
-      stdex::make_unique<moco::tf::TFNodeSummaryBuilderFactory>());
+  INFO(l) << moco::tf::fmt(graph);
 
   auto push = moco::tf::test::find_first_node_bytype<loco::Push>(graph.get());
   auto const_gen = dynamic_cast<loco::ConstGen *>(push->from());
index c2ccb80..de4e105 100644 (file)
 
 #include "ResolveFusedBatchNorm.h"
 
+#include "LogHelper.h"
 #include "TestHelper.h"
 #include "IR/TFFusedBatchNorm.h"
 #include "Importer.h"
-#include "TFFormattedGraph.h"
 
 #include <loco.h>
-#include <locop/FormattedGraph.h>
 #include <moco/Log.h>
 #include <stdex/Memory.h>
 #include <plier/tf/TestHelper.h>
@@ -217,15 +216,13 @@ TEST(ResolveFusedBatchNorm, fbn_resolve_basic)
   auto graph = importer.import(signature, graph_def);
 
   INFO(l) << "Before ResolveFusedBatchNorm";
-  INFO(l) << locop::fmt<locop::LinearV1>(graph).with(
-      stdex::make_unique<moco::tf::TFNodeSummaryBuilderFactory>());
+  INFO(l) << moco::tf::fmt(graph);
 
   moco::tf::ResolveFusedBatchNorm transform;
   bool changed = transform.run(graph.get());
 
   INFO(l) << "After ResolveFusedBatchNorm " << to_char(changed);
-  INFO(l) << locop::fmt<locop::LinearV1>(graph).with(
-      stdex::make_unique<moco::tf::TFNodeSummaryBuilderFactory>());
+  INFO(l) << moco::tf::fmt(graph);
 
   // Output value test will be done with mocotest-tf
   // Network structure of transformation is not important and may be changed