using namespace moco::tf;
-class TFNodeSummaryBuilder final : public locop::NodeSummaryBuilder
+/// TFNodeSummaryBuilder with default implementation
+class TFNodeSummaryBuilderBase : public locop::NodeSummaryBuilder
{
public:
- TFNodeSummaryBuilder(const locop::SymbolTable *tbl) : _tbl{tbl}
+ TFNodeSummaryBuilderBase(const locop::SymbolTable *tbl) : _tbl{tbl}
{
// DO NOTHING
}
public:
bool build(const loco::Node *, locop::NodeSummary &s) const final;
-private:
-#define TENSORFLOW_NODE(OPCODE, CLASS) bool summary(const CLASS *, locop::NodeSummary &) const;
+protected:
+#define TENSORFLOW_NODE(OPCODE, CLASS) \
+ virtual bool summary(const CLASS *, locop::NodeSummary &) const { return false; }
#include "Dialect/TFNodes.lst"
#undef TENSORFLOW_NODE
-private:
+protected:
const locop::SymbolTable *_tbl;
};
-bool TFNodeSummaryBuilder::build(const loco::Node *node, locop::NodeSummary &s) const
+class TFNodeSummaryBuilder final : public TFNodeSummaryBuilderBase
+{
+public:
+ TFNodeSummaryBuilder(const locop::SymbolTable *tbl) : TFNodeSummaryBuilderBase(_tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final
+ IMPLEMENT(TFAdd);
+ IMPLEMENT(TFBiasAdd);
+ IMPLEMENT(TFConst);
+ IMPLEMENT(TFConv2D);
+ IMPLEMENT(TFFusedBatchNorm);
+ IMPLEMENT(TFMul);
+ IMPLEMENT(TFRsqrt);
+#undef IMPLEMENT
+};
+
+bool TFNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const
{
if (node->dialect() != TFDialect::get())
return false;