From b3d041950855f5f3d2884012ab004394e430ff0d Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 17 Sep 2019 16:35:46 +0900 Subject: [PATCH] [exo-tflite] Introduce Div and Sub (#7514) This will introduce IR for Div and Sub Signed-off-by: SaeHie Park --- compiler/exo-tflite/src/Dialect/IR/TFLNodes.h | 32 ++++++++++++++++++++-- compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst | 4 +-- .../exo-tflite/src/Dialect/IR/TFLNodes.test.cpp | 22 +++++++++++++-- compiler/exo-tflite/src/TFLFormattedGraph.cpp | 16 +++++++++-- 4 files changed, 66 insertions(+), 8 deletions(-) diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h index f9ff222..42de748 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h @@ -166,7 +166,21 @@ private: // TODO TFLDepthwiseConv2D -// TODO TFLDiv +/** + * @brief DIV in TensorFlow Lite + */ +class TFLDiv final : public FixedArityNode<2, TFLNodeImpl> +{ +public: + TFLDiv() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; /** * @brief MAX_POOL_2D in TensorFlow Lite @@ -234,7 +248,21 @@ public: // TODO TFLSqrt -// TODO TFLSub +/** + * @brief SUB in TensorFlow Lite + */ +class TFLSub final : public FixedArityNode<2, TFLNodeImpl> +{ +public: + TFLSub() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; // TODO TFLTanh diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst index e870282..a77ac80 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst @@ -10,7 +10,7 @@ TFL_NODE(AVERAGE_POOL_2D, locoex::TFLAveragePool2D) // TODO TFLConcatenation // TODO TFLConv2D // TODO TFLDepthwiseConv2D -// TODO TFLDiv +TFL_NODE(DIV, locoex::TFLDiv) TFL_NODE(MAX_POOL_2D, locoex::TFLMaxPool2D) TFL_NODE(MUL, locoex::TFLMul) TFL_NODE(RELU, locoex::TFLRelu) @@ -18,6 +18,6 @@ TFL_NODE(RELU, locoex::TFLRelu) // TODO TFLReshape // TODO TFLSoftmax // TODO TFLSqrt -// TODO TFLSub +TFL_NODE(SUB, locoex::TFLSub) // TODO TFLTanh // TODO TFLTranspose diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp index efe23f1..b52b452 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp @@ -40,7 +40,16 @@ TEST(TFLAddTest, constructor) // TODO TFLDepthwiseConv2D -// TODO TFLDiv +TEST(TFLDivTest, constructor) +{ + locoex::TFLDiv div_node; + + ASSERT_EQ(div_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(div_node.opcode(), locoex::TFLOpcode::DIV); + + ASSERT_EQ(div_node.x(), nullptr); + ASSERT_EQ(div_node.y(), nullptr); +} // TODO TFLMaxPool2D @@ -73,7 +82,16 @@ TEST(TFLReluTest, constructor) // TODO TFLSqrt -// TODO TFLSub +TEST(TFLSubTest, constructor) +{ + locoex::TFLSub sub_node; + + ASSERT_EQ(sub_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(sub_node.opcode(), locoex::TFLOpcode::SUB); + + ASSERT_EQ(sub_node.x(), nullptr); + ASSERT_EQ(sub_node.y(), nullptr); +} // TODO TFLTanh diff --git a/compiler/exo-tflite/src/TFLFormattedGraph.cpp b/compiler/exo-tflite/src/TFLFormattedGraph.cpp index 8b1cfde..52ce0ff 100644 --- a/compiler/exo-tflite/src/TFLFormattedGraph.cpp +++ b/compiler/exo-tflite/src/TFLFormattedGraph.cpp @@ -131,7 +131,13 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLAveragePool2D *node, // TODO TFLDepthwiseConv2D -// TODO TFLDiv +bool TFLNodeSummaryBuilder::summary(const locoex::TFLDiv *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaxPool2D *node, locop::NodeSummary &s) const { @@ -163,7 +169,13 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu *node, locop::NodeSumm // TODO TFLSqrt -// TODO TFLSub +bool TFLNodeSummaryBuilder::summary(const locoex::TFLSub *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} // TODO TFLTanh -- 2.7.4