From: 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Tue, 17 Sep 2019 07:35:46 +0000 (+0900) Subject: [exo-tflite] Introduce Div and Sub (#7514) X-Git-Tag: accepted/tizen/unified/20190918.102349~23 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b3d041950855f5f3d2884012ab004394e430ff0d;p=platform%2Fcore%2Fml%2Fnnfw.git [exo-tflite] Introduce Div and Sub (#7514) This will introduce IR for Div and Sub Signed-off-by: SaeHie Park --- diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h index f9ff222..42de748 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h @@ -166,7 +166,21 @@ private: // TODO TFLDepthwiseConv2D -// TODO TFLDiv +/** + * @brief DIV in TensorFlow Lite + */ +class TFLDiv final : public FixedArityNode<2, TFLNodeImpl> +{ +public: + TFLDiv() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; /** * @brief MAX_POOL_2D in TensorFlow Lite @@ -234,7 +248,21 @@ public: // TODO TFLSqrt -// TODO TFLSub +/** + * @brief SUB in TensorFlow Lite + */ +class TFLSub final : public FixedArityNode<2, TFLNodeImpl> +{ +public: + TFLSub() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; // TODO TFLTanh diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst index e870282..a77ac80 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst @@ -10,7 +10,7 @@ TFL_NODE(AVERAGE_POOL_2D, locoex::TFLAveragePool2D) // TODO TFLConcatenation // TODO TFLConv2D // TODO TFLDepthwiseConv2D -// TODO TFLDiv +TFL_NODE(DIV, locoex::TFLDiv) TFL_NODE(MAX_POOL_2D, locoex::TFLMaxPool2D) TFL_NODE(MUL, locoex::TFLMul) TFL_NODE(RELU, locoex::TFLRelu) @@ -18,6 +18,6 @@ TFL_NODE(RELU, locoex::TFLRelu) // TODO TFLReshape // TODO TFLSoftmax // TODO TFLSqrt -// TODO TFLSub +TFL_NODE(SUB, locoex::TFLSub) // TODO TFLTanh // TODO TFLTranspose diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp index efe23f1..b52b452 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp @@ -40,7 +40,16 @@ TEST(TFLAddTest, constructor) // TODO TFLDepthwiseConv2D -// TODO TFLDiv +TEST(TFLDivTest, constructor) +{ + locoex::TFLDiv div_node; + + ASSERT_EQ(div_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(div_node.opcode(), locoex::TFLOpcode::DIV); + + ASSERT_EQ(div_node.x(), nullptr); + ASSERT_EQ(div_node.y(), nullptr); +} // TODO TFLMaxPool2D @@ -73,7 +82,16 @@ TEST(TFLReluTest, constructor) // TODO TFLSqrt -// TODO TFLSub +TEST(TFLSubTest, constructor) +{ + locoex::TFLSub sub_node; + + ASSERT_EQ(sub_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(sub_node.opcode(), locoex::TFLOpcode::SUB); + + ASSERT_EQ(sub_node.x(), nullptr); + ASSERT_EQ(sub_node.y(), nullptr); +} // TODO TFLTanh diff --git a/compiler/exo-tflite/src/TFLFormattedGraph.cpp b/compiler/exo-tflite/src/TFLFormattedGraph.cpp index 8b1cfde..52ce0ff 100644 --- a/compiler/exo-tflite/src/TFLFormattedGraph.cpp +++ b/compiler/exo-tflite/src/TFLFormattedGraph.cpp @@ -131,7 +131,13 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLAveragePool2D *node, // TODO TFLDepthwiseConv2D -// TODO TFLDiv +bool TFLNodeSummaryBuilder::summary(const locoex::TFLDiv *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaxPool2D *node, locop::NodeSummary &s) const { @@ -163,7 +169,13 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu *node, locop::NodeSumm // TODO TFLSqrt -// TODO TFLSub +bool TFLNodeSummaryBuilder::summary(const locoex::TFLSub *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} // TODO TFLTanh