[exo-tflite] Introduce TFLAdd (#7113)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 4 Sep 2019 01:43:39 +0000 (10:43 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 4 Sep 2019 01:43:39 +0000 (10:43 +0900)
This will introduce TFLAdd IR and required summary builder

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/exo-tflite/src/Dialect/IR/TFLNodes.h
compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst
compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp
compiler/exo-tflite/src/TFLFormattedGraph.cpp

index e8fab4f..d0078e7 100644 (file)
@@ -112,7 +112,21 @@ private:
   int32_t _h;
 };
 
-// TODO TFLAdd
+/**
+ * @brief ADD in TensorFlow Lite
+ */
+class TFLAdd final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::ADD>>
+{
+public:
+  TFLAdd() = default;
+
+public:
+  loco::Node *x(void) const { return at(0)->node(); }
+  void x(loco::Node *node) { at(0)->node(node); }
+
+  loco::Node *y(void) const { return at(1)->node(); }
+  void y(loco::Node *node) { at(1)->node(node); }
+};
 
 /**
  * @brief AVERAGE_POOL_2D in TensorFlow Lite
index 4c9a527..60059d0 100644 (file)
@@ -5,7 +5,7 @@
 //
 // PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER
 //
-// TODO TFLAdd
+TFL_NODE(ADD, locoex::TFLAdd)
 TFL_NODE(AVERAGE_POOL_2D, locoex::TFLAveragePool2D)
 // TODO TFLConcatenation
 // TODO TFLConv2D
index 6f0abc8..542fcf3 100644 (file)
 
 #include <gtest/gtest.h>
 
-// TODO TFLAdd
+TEST(TFLAddTest, constructor)
+{
+  locoex::TFLAdd add_node;
+
+  ASSERT_EQ(add_node.dialect(), locoex::TFLDialect::get());
+  ASSERT_EQ(add_node.opcode(), locoex::TFLOpcode::ADD);
+
+  ASSERT_EQ(add_node.x(), nullptr);
+  ASSERT_EQ(add_node.y(), nullptr);
+}
 
 // TODO TFLAveragePool2D
 
index 7be3c2a..31a6d99 100644 (file)
@@ -90,7 +90,14 @@ bool TFLNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary
   return false;
 }
 
-// TODO TFLAdd
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLAdd *node, locop::NodeSummary &s) const
+{
+  s.opname("TFL.ADD");
+  s.args().append("x", tbl()->lookup(node->x()));
+  s.args().append("y", tbl()->lookup(node->y()));
+  s.state(locop::NodeSummary::State::Complete);
+  return true;
+}
 
 bool TFLNodeSummaryBuilder::summary(const locoex::TFLAveragePool2D *node,
                                     locop::NodeSummary &s) const