From eaa67d71d14ab6dc11c9d32dbf97d6f6365ad085 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 15 Oct 2019 07:20:22 +0900 Subject: [PATCH] [moco-tf] Graph build Concat as TF dialect (#8131) This will remove Graph builder fro Concat as Canonical dialect and build only TF dialect Signed-off-by: SaeHie Park --- compiler/moco-tf/src/Op/Concat.cpp | 151 +------------------------------- compiler/moco-tf/src/Op/Concat.h | 19 +--- compiler/moco-tf/src/Op/Concat.test.cpp | 86 +----------------- 3 files changed, 5 insertions(+), 251 deletions(-) diff --git a/compiler/moco-tf/src/Op/Concat.cpp b/compiler/moco-tf/src/Op/Concat.cpp index 33788cd..eab5390 100644 --- a/compiler/moco-tf/src/Op/Concat.cpp +++ b/compiler/moco-tf/src/Op/Concat.cpp @@ -40,21 +40,6 @@ namespace using namespace moco::tf; -class ConcatV2GraphUpdate final : public GraphUpdate -{ -public: - ConcatV2GraphUpdate(std::vector nodes, std::vector names) - : _nodes(nodes), _names(names) - { - } - - void input(const SymbolTable *) const override; - -private: - std::vector _nodes; - std::vector _names; -}; - class TFConcatV2GraphUpdate final : public GraphUpdate { public: @@ -70,25 +55,6 @@ private: std::vector _names; }; -void ConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const -{ - int num_inputs = _names.size(); - assert(num_inputs >= 2); - assert(num_inputs == _nodes.size()); - - loco::Node *target; - // do "%0.lhs : %in[0].name" connection - target = tensor_names->node(_names[0]); - _nodes[0]->lhs(target); - - for (int i = 1; i < num_inputs; ++i) - { - // do "%i.rhs : %in[i].name" connections - target = tensor_names->node(_names[i]); - _nodes[i]->rhs(target); - } -} - void TFConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const { uint32_t num_values = _names.size() - 1; // exclude axis @@ -112,7 +78,7 @@ namespace moco namespace tf { -bool ConcatV2GraphBuilderBase::validate(const tensorflow::NodeDef &node) const +bool ConcatV2GraphBuilder::validate(const tensorflow::NodeDef &node) const { if (!plier::tf::has_attrs(node, {"T", "N", "Tidx"})) return false; @@ -124,126 +90,11 @@ bool ConcatV2GraphBuilderBase::validate(const tensorflow::NodeDef &node) const return (num_inputs >= 2) && (num_inputs == plier::tf::get_int_attr(node, "N")); } -/** - * @brief GraphBuilder for Concat node of Tensor - */ -class ConcatV2GraphBuilder final : public ConcatV2GraphBuilderBase -{ -public: - void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override; -}; - void ConcatV2GraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const { assert(context != nullptr); - if (moco::tf::get()) - { - ConcatV2GraphBuilderImpl builder; - return builder.build(node, context); - } - else - { - ConcatV2GraphBuilderImpl builder; - return builder.build(node, context); - } -} - -void ConcatV2GraphBuilderImpl::build(const tensorflow::NodeDef &node, - GraphBuilderContext *context) const -{ - loco::Graph *graph = context->graph(); - NodeDefTable *nodedef = context->nodedef(); - SymbolTable *tensor_names = context->tensor_names(); - UpdateQueue *updates = context->updates(); - - // Concat has 2 or more inputs and loco TensorConcat is fixed to 2 inputs - // for arbitrary N inputs (beginning from 0), TensorConcat will be created - // as follows; - // %0 = TensorConcat(%in[0], %in[1]) - // %1 = %0 --> this is to match index of input name - // %2 = TensorConcat(%1, %in[2]) - // ... - // %(N-1) = TensorConcat(%(N-2), %in[N-1])) - // %N = TensorConcat(%(N-1), %in[N])) - // - // Output of this sub graph will be set to %N with node.name() - // - // As we know that each input exist, one of input(lhs) can be linked while creating - // %2.lhs = %1 - // %3.lhs = %2 - // ... - // %(N-1).lhs = %(N-2) - // %N.lhs = %(N-1) - - const int num_inputs = node.input_size() - 1; - - std::vector concat_nodes; - std::vector input_names; - - auto concat_node = graph->nodes()->create(); - loco::TensorConcat *last_concat = concat_node; - - // Queue node input update - concat_nodes.push_back(concat_node); // used for LHS of connection -> %0 - concat_nodes.push_back(concat_node); // used for RHS of connection -> %1 - input_names.push_back(TensorName(node.input(0))); // for first concat (%0) LHS - input_names.push_back(TensorName(node.input(1))); // for first concat (%1) RHS - - for (int ni = 2; ni < num_inputs; ++ni) - { - auto concat_node_next = graph->nodes()->create(); - - concat_nodes.push_back(concat_node_next); - input_names.push_back(TensorName(node.input(ni))); - - // connect LHS as we know the nodes - concat_node_next->lhs(last_concat); - - // update last concat node - last_concat = concat_node_next; - } - - // register string-name to the last node as output of concat(s) - TensorName output_name(node.name(), 0); - tensor_names->enroll(output_name, last_concat); - - // Find axis tensorflow::NodeDef and get the axis number - std::string axis_name = node.input(num_inputs); - const tensorflow::NodeDef *tfnode = nodedef->node(axis_name); - // assume data type is int32 - assert(plier::tf::get_datatype_attr(*tfnode, "dtype") == tensorflow::DataType::DT_INT32); - const auto &tensor = plier::tf::get_tensor_attr(*tfnode, "value"); - assert(tensor.int_val_size() == 1); - auto axis_value_read = tensor.int_val(0); - - // set axis for all concat(s) as temporary data - // as the first and the second items are actually the same one, skip it. - std::vector::iterator iter = concat_nodes.begin(); - for (++iter; iter != concat_nodes.end(); ++iter) - { - auto concat_node = *iter; - auto concat_data = stdex::make_unique(axis_value_read); - - concat_node->annot(std::move(concat_data)); - } - - // Input name queue is created like this in 'concat_nodes' and 'input_names' - // %0.lhs : %in[0].name - // %1.rhs : %in[1].name (as %0 == %1) - // %2.rhs : %in[2].name - // %3.rhs : %in[3].name - // ... - // %(N-2).rhs : %in[N-2].name - // %(N-1).rhs : %in[N-1].name - auto update = stdex::make_unique(concat_nodes, input_names); - updates->enroll(std::move(update)); -} - -void ConcatV2GraphBuilderImpl::build(const tensorflow::NodeDef &node, - GraphBuilderContext *context) const -{ loco::Graph *graph = context->graph(); NodeDefTable *nodedef = context->nodedef(); SymbolTable *tensor_names = context->tensor_names(); diff --git a/compiler/moco-tf/src/Op/Concat.h b/compiler/moco-tf/src/Op/Concat.h index 6a5a857..0b4cfee 100644 --- a/compiler/moco-tf/src/Op/Concat.h +++ b/compiler/moco-tf/src/Op/Concat.h @@ -18,31 +18,16 @@ #define __OP_CONCAT_H__ #include "GraphBuilder.h" -#include "ImportTarget.h" namespace moco { namespace tf { -struct ConcatV2GraphBuilderBase : public GraphBuilder +class ConcatV2GraphBuilder : public GraphBuilder { - virtual ~ConcatV2GraphBuilderBase() = default; - +public: bool validate(const tensorflow::NodeDef &) const final; -}; - -template class ConcatV2GraphBuilderImpl; - -template <> -struct ConcatV2GraphBuilderImpl final : public ConcatV2GraphBuilderBase -{ - void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; -}; - -template <> -struct ConcatV2GraphBuilderImpl final : public ConcatV2GraphBuilderBase -{ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; }; diff --git a/compiler/moco-tf/src/Op/Concat.test.cpp b/compiler/moco-tf/src/Op/Concat.test.cpp index e9ddd6b..ec3adaf 100644 --- a/compiler/moco-tf/src/Op/Concat.test.cpp +++ b/compiler/moco-tf/src/Op/Concat.test.cpp @@ -162,49 +162,13 @@ TEST(TensorFlowImport, concat_01) EXPECT_TRUE(plier::tf::parse_graphdef(concat_01_pbtxtdata, graph_def)); std::unique_ptr graph = importer.import(signature, graph_def); - // Test "ConcatV2GraphBuilderImpl" - { - // TODO fix indent - // clang-format off - - // what to test: - // - there should exist TensorConcat - // - lhs() should not be nullptr - // - rhs() should not be nullptr - // - axis() should match - - using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl; - - moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; - r.add("ConcatV2", stdex::make_unique()); - moco::tf::Importer importer{&r}; - - std::unique_ptr graph = importer.import(signature, graph_def); - - loco::TensorConcat *concat_node = - moco::tf::test::find_first_node_bytype(graph.get()); - - ASSERT_NE(concat_node, nullptr); - ASSERT_NE(concat_node->lhs(), nullptr); - ASSERT_NE(concat_node->rhs(), nullptr); - ASSERT_EQ(concat_node->axis(), 0); - - // clang-format on - } - - // Test "ConcatV2GraphBuilderImpl" { // what to test: // - there should exist TFConcatV2 // - there should be two values // - values(idx) should not be nullptr // - axis() should not be nullptr - - using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl; - - moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; - r.add("ConcatV2", stdex::make_unique()); - moco::tf::Importer importer{&r}; + moco::tf::Importer importer; std::unique_ptr graph = importer.import(signature, graph_def); @@ -382,58 +346,12 @@ TEST(TensorFlowImport, concat_02) tensorflow::GraphDef graph_def; EXPECT_TRUE(plier::tf::parse_graphdef(concat_02_pbtxtdata, graph_def)); - // Test "ConcatV2GraphBuilderImpl" - { - // TODO fix indent - // clang-format off - - // what to test: Concat has 3 inputs --> Importer creates 2 TensorConcat - // - there should exist two TensorConcat - // - lhs() of #1 should not be nullptr - // - rhs() of #1 should not be nullptr - // - lhs() of #2 should be #1 - // - rhs() of #2 should not be nullptr - // - axis() should match - - using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl; - - moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; - r.add("ConcatV2", stdex::make_unique()); - moco::tf::Importer importer{&r}; - - std::unique_ptr graph = importer.import(signature, graph_def); - - std::vector concat_nodes = - moco::tf::test::find_nodes_bytype(graph.get()); - ASSERT_EQ(concat_nodes.size(), 2); - loco::TensorConcat *concat_node0 = concat_nodes.at(0); - loco::TensorConcat *concat_node1 = concat_nodes.at(1); - - ASSERT_NE(concat_node0, nullptr); - ASSERT_NE(concat_node1, nullptr); - ASSERT_NE(concat_node0->lhs(), nullptr); - ASSERT_NE(concat_node0->rhs(), nullptr); - ASSERT_NE(concat_node1->lhs(), nullptr); - ASSERT_NE(concat_node1->rhs(), nullptr); - ASSERT_TRUE(concat_node0->lhs() == concat_node1 || concat_node1->lhs() == concat_node0); - ASSERT_EQ(concat_node0->axis(), 0); - ASSERT_EQ(concat_node1->axis(), 0); - - // clang-format on - } - - // Test "ConcatV2GraphBuilderImpl" { // what to test: TFConcatV2 has 3 inputs // - there should exist TFConcatV2 // - values(idx) should not be nullptr // - axis() should not be nullptr - - using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl; - - moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; - r.add("ConcatV2", stdex::make_unique()); - moco::tf::Importer importer{&r}; + moco::tf::Importer importer; std::unique_ptr graph = importer.import(signature, graph_def); -- 2.7.4