From befe6681c2d545f204fddece98d68b36b24cd386 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 18 Jul 2019 13:48:14 +0900 Subject: [PATCH] [moco/tf] Introduce TFNodeSummaryBuilderBase (#4335) This commit introduces TFNodeSummaryBuilderBase which provides a default pretty printer for all TF nodes. Signed-off-by: Jonghyun Park --- contrib/moco-tf/src/TFFormattedGraph.cpp | 34 ++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/contrib/moco-tf/src/TFFormattedGraph.cpp b/contrib/moco-tf/src/TFFormattedGraph.cpp index 13907af..2730ef5 100644 --- a/contrib/moco-tf/src/TFFormattedGraph.cpp +++ b/contrib/moco-tf/src/TFFormattedGraph.cpp @@ -28,10 +28,11 @@ namespace using namespace moco::tf; -class TFNodeSummaryBuilder final : public locop::NodeSummaryBuilder +/// TFNodeSummaryBuilder with default implementation +class TFNodeSummaryBuilderBase : public locop::NodeSummaryBuilder { public: - TFNodeSummaryBuilder(const locop::SymbolTable *tbl) : _tbl{tbl} + TFNodeSummaryBuilderBase(const locop::SymbolTable *tbl) : _tbl{tbl} { // DO NOTHING } @@ -39,16 +40,37 @@ public: public: bool build(const loco::Node *, locop::NodeSummary &s) const final; -private: -#define TENSORFLOW_NODE(OPCODE, CLASS) bool summary(const CLASS *, locop::NodeSummary &) const; +protected: +#define TENSORFLOW_NODE(OPCODE, CLASS) \ + virtual bool summary(const CLASS *, locop::NodeSummary &) const { return false; } #include "Dialect/TFNodes.lst" #undef TENSORFLOW_NODE -private: +protected: const locop::SymbolTable *_tbl; }; -bool TFNodeSummaryBuilder::build(const loco::Node *node, locop::NodeSummary &s) const +class TFNodeSummaryBuilder final : public TFNodeSummaryBuilderBase +{ +public: + TFNodeSummaryBuilder(const locop::SymbolTable *tbl) : TFNodeSummaryBuilderBase(_tbl) + { + // DO NOTHING + } + +private: +#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final + IMPLEMENT(TFAdd); + IMPLEMENT(TFBiasAdd); + IMPLEMENT(TFConst); + IMPLEMENT(TFConv2D); + IMPLEMENT(TFFusedBatchNorm); + IMPLEMENT(TFMul); + IMPLEMENT(TFRsqrt); +#undef IMPLEMENT +}; + +bool TFNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const { if (node->dialect() != TFDialect::get()) return false; -- 2.7.4