[moco-tf] Update TFConcatV2 IR to have variadic arity (#6663)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 19 Aug 2019 04:51:14 +0000 (13:51 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 19 Aug 2019 04:51:14 +0000 (13:51 +0900)
This will update TFConcatV2 IR to have variadic arity

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp
compiler/moco-tf/src/IR/TFConcatV2.h
compiler/moco-tf/src/IR/TFConcatV2.test.cpp

index feb5a77..d8e823e 100644 (file)
@@ -34,7 +34,7 @@ struct TypeForwardAlgorithm final : public moco::tf::TFNodeVisitor<loco::DataTyp
   loco::DataType visit(const TFAdd *node) { return dtype_get(node->x()); }
   loco::DataType visit(const TFAvgPool *node) { return dtype_get(node->value()); }
   loco::DataType visit(const TFBiasAdd *node) { return dtype_get(node->value()); }
-  loco::DataType visit(const TFConcatV2 *node) { return dtype_get(node->lhs()); }
+  loco::DataType visit(const TFConcatV2 *node) { return dtype_get(node->values(0)); }
 
   loco::DataType visit(const TFConst *node) { return node->dtype(); }
 
index 2dd2214..1db44cd 100644 (file)
@@ -19,6 +19,8 @@
 
 #include "Dialect/TFNodeDecl.h"
 
+#include "Dialect/VariadicArityNode.h"
+
 namespace moco
 {
 namespace tf
@@ -53,30 +55,37 @@ node {
 }
 */
 
-/**
- * @note  As there is no VariableArityNode for now, we will import ConcatV2
- *        as cascading of multiple TFConcatV2 nodes like loco::TensorConcat
- */
-
-class TFConcatV2 final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::ConcatV2>>
+class TFConcatV2 final : public VariadicArityNode<TFNodeImpl<TFOpcode::ConcatV2>>
 {
 public:
-  TFConcatV2() = default;
+  TFConcatV2(uint32_t arity) : VariadicArityNode<TFNodeImpl<TFOpcode::ConcatV2>>(arity + 1)
+  {
+    // we add +1 for axis of VariadicArityNode ctor
+    // at least one value is required
+    assert(arity >= 1);
+  }
 
 public:
-  Node *lhs(void) const { return at(0)->node(); }
-  void lhs(Node *node) { at(0)->node(node); }
-
-  Node *rhs(void) const { return at(1)->node(); }
-  void rhs(Node *node) { at(1)->node(node); }
+  uint32_t num_values(void) const
+  {
+    // last one is for axis
+    return arity() - 1;
+  }
 
 public:
-  uint32_t axis(void) const { return _axis; }
-  void axis(uint32_t val) { _axis = val; }
+  Node *values(uint32_t index) const
+  {
+    assert(index < num_values());
+    return at(index)->node();
+  }
+  void values(uint32_t index, Node *node)
+  {
+    assert(index < num_values());
+    at(index)->node(node);
+  }
 
-private:
-  // Axis
-  uint32_t _axis{0};
+  Node *axis(void) const { return at(num_values())->node(); }
+  void axis(Node *node) { at(num_values())->node(node); }
 };
 
 } // namespace tf
index 85ea264..89eac1b 100644 (file)
 
 TEST(TFConcatV2Test, constructor)
 {
-  moco::tf::TFConcatV2 concatv2_node;
+  moco::tf::TFConcatV2 concatv2_node(3); // num of values
 
   ASSERT_EQ(concatv2_node.dialect(), moco::tf::TFDialect::get());
   ASSERT_EQ(concatv2_node.opcode(), moco::tf::TFOpcode::ConcatV2);
 
-  ASSERT_EQ(concatv2_node.lhs(), nullptr);
-  ASSERT_EQ(concatv2_node.rhs(), nullptr);
-  ASSERT_EQ(concatv2_node.axis(), 0);
+  ASSERT_EQ(concatv2_node.num_values(), 3);
+  ASSERT_EQ(concatv2_node.values(0), nullptr);
+  ASSERT_EQ(concatv2_node.values(1), nullptr);
+  ASSERT_EQ(concatv2_node.values(2), nullptr);
+  ASSERT_EQ(concatv2_node.axis(), nullptr);
 }