From: 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Mon, 19 Aug 2019 04:51:14 +0000 (+0900) Subject: [moco-tf] Update TFConcatV2 IR to have variadic arity (#6663) X-Git-Tag: accepted/tizen/unified/20190903.052428~341 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ecccb90d337114195a6c89f7be1a65fbca2b5923;p=platform%2Fcore%2Fml%2Fnnfw.git [moco-tf] Update TFConcatV2 IR to have variadic arity (#6663) This will update TFConcatV2 IR to have variadic arity Signed-off-by: SaeHie Park --- diff --git a/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp b/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp index feb5a77..d8e823e 100644 --- a/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp +++ b/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp @@ -34,7 +34,7 @@ struct TypeForwardAlgorithm final : public moco::tf::TFNodeVisitorx()); } loco::DataType visit(const TFAvgPool *node) { return dtype_get(node->value()); } loco::DataType visit(const TFBiasAdd *node) { return dtype_get(node->value()); } - loco::DataType visit(const TFConcatV2 *node) { return dtype_get(node->lhs()); } + loco::DataType visit(const TFConcatV2 *node) { return dtype_get(node->values(0)); } loco::DataType visit(const TFConst *node) { return node->dtype(); } diff --git a/compiler/moco-tf/src/IR/TFConcatV2.h b/compiler/moco-tf/src/IR/TFConcatV2.h index 2dd2214..1db44cd 100644 --- a/compiler/moco-tf/src/IR/TFConcatV2.h +++ b/compiler/moco-tf/src/IR/TFConcatV2.h @@ -19,6 +19,8 @@ #include "Dialect/TFNodeDecl.h" +#include "Dialect/VariadicArityNode.h" + namespace moco { namespace tf @@ -53,30 +55,37 @@ node { } */ -/** - * @note As there is no VariableArityNode for now, we will import ConcatV2 - * as cascading of multiple TFConcatV2 nodes like loco::TensorConcat - */ - -class TFConcatV2 final : public loco::FixedArityNode<2, TFNodeImpl> +class TFConcatV2 final : public VariadicArityNode> { public: - TFConcatV2() = default; + TFConcatV2(uint32_t arity) : VariadicArityNode>(arity + 1) + { + // we add +1 for axis of VariadicArityNode ctor + // at least one value is required + assert(arity >= 1); + } public: - Node *lhs(void) const { return at(0)->node(); } - void lhs(Node *node) { at(0)->node(node); } - - Node *rhs(void) const { return at(1)->node(); } - void rhs(Node *node) { at(1)->node(node); } + uint32_t num_values(void) const + { + // last one is for axis + return arity() - 1; + } public: - uint32_t axis(void) const { return _axis; } - void axis(uint32_t val) { _axis = val; } + Node *values(uint32_t index) const + { + assert(index < num_values()); + return at(index)->node(); + } + void values(uint32_t index, Node *node) + { + assert(index < num_values()); + at(index)->node(node); + } -private: - // Axis - uint32_t _axis{0}; + Node *axis(void) const { return at(num_values())->node(); } + void axis(Node *node) { at(num_values())->node(node); } }; } // namespace tf diff --git a/compiler/moco-tf/src/IR/TFConcatV2.test.cpp b/compiler/moco-tf/src/IR/TFConcatV2.test.cpp index 85ea264..89eac1b 100644 --- a/compiler/moco-tf/src/IR/TFConcatV2.test.cpp +++ b/compiler/moco-tf/src/IR/TFConcatV2.test.cpp @@ -22,12 +22,14 @@ TEST(TFConcatV2Test, constructor) { - moco::tf::TFConcatV2 concatv2_node; + moco::tf::TFConcatV2 concatv2_node(3); // num of values ASSERT_EQ(concatv2_node.dialect(), moco::tf::TFDialect::get()); ASSERT_EQ(concatv2_node.opcode(), moco::tf::TFOpcode::ConcatV2); - ASSERT_EQ(concatv2_node.lhs(), nullptr); - ASSERT_EQ(concatv2_node.rhs(), nullptr); - ASSERT_EQ(concatv2_node.axis(), 0); + ASSERT_EQ(concatv2_node.num_values(), 3); + ASSERT_EQ(concatv2_node.values(0), nullptr); + ASSERT_EQ(concatv2_node.values(1), nullptr); + ASSERT_EQ(concatv2_node.values(2), nullptr); + ASSERT_EQ(concatv2_node.axis(), nullptr); }