From 59aeabf5b21fa753581a11c8d0a080a0c179c217 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 4 Nov 2019 07:09:46 +0900 Subject: [PATCH] [exo] TFLMean IR (#8698) This commit introduces TFLMean IR Signed-off-by: Cheongyo Bahk --- compiler/exo/src/Dialect/IR/TFLNodes.h | 17 +++++++++++++++++ compiler/exo/src/Dialect/IR/TFLNodes.lst | 1 + compiler/exo/src/TFLFormattedGraph.cpp | 9 +++++++++ 3 files changed, 27 insertions(+) diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.h b/compiler/exo/src/Dialect/IR/TFLNodes.h index d438241..d4ef31a 100644 --- a/compiler/exo/src/Dialect/IR/TFLNodes.h +++ b/compiler/exo/src/Dialect/IR/TFLNodes.h @@ -357,6 +357,23 @@ private: Filter _filter; }; +class TFLMean final : public FixedArityNode<2, TFLNodeImpl> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *reduction_indices(void) const { return at(1)->node(); } + void reduction_indices(loco::Node *node) { at(1)->node(node); } + +public: + bool keep_dims(void) const { return _keep_dims; } + void keep_dims(bool keep_dims) { _keep_dims = keep_dims; } + +private: + bool _keep_dims = false; +}; + /** * @brief MUL in TensorFlow Lite */ diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.lst b/compiler/exo/src/Dialect/IR/TFLNodes.lst index b04b093..20584dc 100644 --- a/compiler/exo/src/Dialect/IR/TFLNodes.lst +++ b/compiler/exo/src/Dialect/IR/TFLNodes.lst @@ -13,6 +13,7 @@ TFL_NODE(CONV_2D, locoex::TFLConv2D) TFL_NODE(DEPTHWISE_CONV_2D, locoex::TFLDepthwiseConv2D) TFL_NODE(DIV, locoex::TFLDiv) TFL_NODE(MAX_POOL_2D, locoex::TFLMaxPool2D) +TFL_NODE(MEAN, locoex::TFLMean) TFL_NODE(MUL, locoex::TFLMul) TFL_NODE(RELU, locoex::TFLRelu) TFL_NODE(RELU6, locoex::TFLRelu6) diff --git a/compiler/exo/src/TFLFormattedGraph.cpp b/compiler/exo/src/TFLFormattedGraph.cpp index b733a5b..eadf28a 100644 --- a/compiler/exo/src/TFLFormattedGraph.cpp +++ b/compiler/exo/src/TFLFormattedGraph.cpp @@ -262,6 +262,15 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaxPool2D *node, locop::Nod return true; } +bool TFLNodeSummaryBuilder::summary(const locoex::TFLMean *node, locop::NodeSummary &s) const +{ + s.args().append("input", tbl()->lookup(node->input())); + s.args().append("reduction_indices", tbl()->lookup(node->reduction_indices())); + s.args().append("keep_dims", node->keep_dims() ? "true" : "false"); + s.state(locop::NodeSummary::State::Complete); + return true; +} + bool TFLNodeSummaryBuilder::summary(const locoex::TFLMul *node, locop::NodeSummary &s) const { auto fused = node->fusedActivationFunction(); -- 2.7.4