[moco/tf] Introduce TFNodeSummaryBuilderBase (#4335)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 18 Jul 2019 04:48:14 +0000 (13:48 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 18 Jul 2019 04:48:14 +0000 (13:48 +0900)
This commit introduces TFNodeSummaryBuilderBase which provides a default
pretty printer for all TF nodes.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/moco-tf/src/TFFormattedGraph.cpp

index 13907af..2730ef5 100644 (file)
@@ -28,10 +28,11 @@ namespace
 
 using namespace moco::tf;
 
-class TFNodeSummaryBuilder final : public locop::NodeSummaryBuilder
+/// TFNodeSummaryBuilder with default implementation
+class TFNodeSummaryBuilderBase : public locop::NodeSummaryBuilder
 {
 public:
-  TFNodeSummaryBuilder(const locop::SymbolTable *tbl) : _tbl{tbl}
+  TFNodeSummaryBuilderBase(const locop::SymbolTable *tbl) : _tbl{tbl}
   {
     // DO NOTHING
   }
@@ -39,16 +40,37 @@ public:
 public:
   bool build(const loco::Node *, locop::NodeSummary &s) const final;
 
-private:
-#define TENSORFLOW_NODE(OPCODE, CLASS) bool summary(const CLASS *, locop::NodeSummary &) const;
+protected:
+#define TENSORFLOW_NODE(OPCODE, CLASS) \
+  virtual bool summary(const CLASS *, locop::NodeSummary &) const { return false; }
 #include "Dialect/TFNodes.lst"
 #undef TENSORFLOW_NODE
 
-private:
+protected:
   const locop::SymbolTable *_tbl;
 };
 
-bool TFNodeSummaryBuilder::build(const loco::Node *node, locop::NodeSummary &s) const
+class TFNodeSummaryBuilder final : public TFNodeSummaryBuilderBase
+{
+public:
+  TFNodeSummaryBuilder(const locop::SymbolTable *tbl) : TFNodeSummaryBuilderBase(_tbl)
+  {
+    // DO NOTHING
+  }
+
+private:
+#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final
+  IMPLEMENT(TFAdd);
+  IMPLEMENT(TFBiasAdd);
+  IMPLEMENT(TFConst);
+  IMPLEMENT(TFConv2D);
+  IMPLEMENT(TFFusedBatchNorm);
+  IMPLEMENT(TFMul);
+  IMPLEMENT(TFRsqrt);
+#undef IMPLEMENT
+};
+
+bool TFNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const
 {
   if (node->dialect() != TFDialect::get())
     return false;