From 9482a9a70fe1a33d584f8cd0b6f702a4e941acb5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 31 Oct 2019 14:14:49 +0900 Subject: [PATCH] [exo] TFL nodes for Sqrt, Rsqrt, SquaredDifference (#8623) This commit introduces simple TFL-dialect nodes: - locoex::TFLSqrt - locoex::TFLRsqrt - locoex::TFLSquaredDifference Signed-off-by: Cheongyo Bahk --- compiler/exo/src/Dialect/IR/TFLNodes.h | 34 +++++++++++++++++++++++++++++++- compiler/exo/src/Dialect/IR/TFLNodes.lst | 4 +++- compiler/exo/src/TFLFormattedGraph.cpp | 23 ++++++++++++++++++++- 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.h b/compiler/exo/src/Dialect/IR/TFLNodes.h index 93b482e..d438241 100644 --- a/compiler/exo/src/Dialect/IR/TFLNodes.h +++ b/compiler/exo/src/Dialect/IR/TFLNodes.h @@ -426,9 +426,41 @@ private: Shape _new_shape; }; +class TFLRsqrt final : public FixedArityNode<1, TFLNodeImpl> +{ +public: + TFLRsqrt() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } +}; + // TODO TFLSoftmax -// TODO TFLSqrt +class TFLSqrt final : public FixedArityNode<1, TFLNodeImpl> +{ +public: + TFLSqrt() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } +}; + +class TFLSquaredDifference final + : public FixedArityNode<2, TFLNodeImpl> +{ +public: + TFLSquaredDifference() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; /** * @brief SUB in TensorFlow Lite diff --git a/compiler/exo/src/Dialect/IR/TFLNodes.lst b/compiler/exo/src/Dialect/IR/TFLNodes.lst index fa85181..b04b093 100644 --- a/compiler/exo/src/Dialect/IR/TFLNodes.lst +++ b/compiler/exo/src/Dialect/IR/TFLNodes.lst @@ -17,8 +17,10 @@ TFL_NODE(MUL, locoex::TFLMul) TFL_NODE(RELU, locoex::TFLRelu) TFL_NODE(RELU6, locoex::TFLRelu6) TFL_NODE(RESHAPE, locoex::TFLReshape) +TFL_NODE(RSQRT, locoex::TFLRsqrt) // TODO TFLSoftmax -// TODO TFLSqrt +TFL_NODE(SQRT, locoex::TFLSqrt) +TFL_NODE(SQUARED_DIFFERENCE, locoex::TFLSquaredDifference) TFL_NODE(SUB, locoex::TFLSub) // TODO TFLTanh TFL_NODE(TRANSPOSE, locoex::TFLTranspose) diff --git a/compiler/exo/src/TFLFormattedGraph.cpp b/compiler/exo/src/TFLFormattedGraph.cpp index bcafc2d..b733a5b 100644 --- a/compiler/exo/src/TFLFormattedGraph.cpp +++ b/compiler/exo/src/TFLFormattedGraph.cpp @@ -297,9 +297,30 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLReshape *node, locop::NodeS return true; } +bool TFLNodeSummaryBuilder::summary(const locoex::TFLRsqrt *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + // TODO TFLSoftmax -// TODO TFLSqrt +bool TFLNodeSummaryBuilder::summary(const locoex::TFLSqrt *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.state(locop::NodeSummary::State::Complete); + return true; +} + +bool TFLNodeSummaryBuilder::summary(const locoex::TFLSquaredDifference *node, + locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} bool TFLNodeSummaryBuilder::summary(const locoex::TFLSub *node, locop::NodeSummary &s) const { -- 2.7.4