loco::DataType visit(const TFAdd *node) { return dtype_get(node->x()); }
loco::DataType visit(const TFAvgPool *node) { return dtype_get(node->value()); }
loco::DataType visit(const TFBiasAdd *node) { return dtype_get(node->value()); }
- loco::DataType visit(const TFConcatV2 *node) { return dtype_get(node->lhs()); }
+ loco::DataType visit(const TFConcatV2 *node) { return dtype_get(node->values(0)); }
loco::DataType visit(const TFConst *node) { return node->dtype(); }
#include "Dialect/TFNodeDecl.h"
+#include "Dialect/VariadicArityNode.h"
+
namespace moco
{
namespace tf
}
*/
-/**
- * @note As there is no VariableArityNode for now, we will import ConcatV2
- * as cascading of multiple TFConcatV2 nodes like loco::TensorConcat
- */
-
-class TFConcatV2 final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::ConcatV2>>
+class TFConcatV2 final : public VariadicArityNode<TFNodeImpl<TFOpcode::ConcatV2>>
{
public:
- TFConcatV2() = default;
+ TFConcatV2(uint32_t arity) : VariadicArityNode<TFNodeImpl<TFOpcode::ConcatV2>>(arity + 1)
+ {
+ // we add +1 for axis of VariadicArityNode ctor
+ // at least one value is required
+ assert(arity >= 1);
+ }
public:
- Node *lhs(void) const { return at(0)->node(); }
- void lhs(Node *node) { at(0)->node(node); }
-
- Node *rhs(void) const { return at(1)->node(); }
- void rhs(Node *node) { at(1)->node(node); }
+ uint32_t num_values(void) const
+ {
+ // last one is for axis
+ return arity() - 1;
+ }
public:
- uint32_t axis(void) const { return _axis; }
- void axis(uint32_t val) { _axis = val; }
+ Node *values(uint32_t index) const
+ {
+ assert(index < num_values());
+ return at(index)->node();
+ }
+ void values(uint32_t index, Node *node)
+ {
+ assert(index < num_values());
+ at(index)->node(node);
+ }
-private:
- // Axis
- uint32_t _axis{0};
+ Node *axis(void) const { return at(num_values())->node(); }
+ void axis(Node *node) { at(num_values())->node(node); }
};
} // namespace tf
TEST(TFConcatV2Test, constructor)
{
- moco::tf::TFConcatV2 concatv2_node;
+ moco::tf::TFConcatV2 concatv2_node(3); // num of values
ASSERT_EQ(concatv2_node.dialect(), moco::tf::TFDialect::get());
ASSERT_EQ(concatv2_node.opcode(), moco::tf::TFOpcode::ConcatV2);
- ASSERT_EQ(concatv2_node.lhs(), nullptr);
- ASSERT_EQ(concatv2_node.rhs(), nullptr);
- ASSERT_EQ(concatv2_node.axis(), 0);
+ ASSERT_EQ(concatv2_node.num_values(), 3);
+ ASSERT_EQ(concatv2_node.values(0), nullptr);
+ ASSERT_EQ(concatv2_node.values(1), nullptr);
+ ASSERT_EQ(concatv2_node.values(2), nullptr);
+ ASSERT_EQ(concatv2_node.axis(), nullptr);
}