[moco-tf] Graph build Concat as TF dialect (#8131)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 14 Oct 2019 22:20:22 +0000 (07:20 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 14 Oct 2019 22:20:22 +0000 (07:20 +0900)
This will remove Graph builder fro Concat as Canonical dialect and build only TF dialect

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Op/Concat.cpp
compiler/moco-tf/src/Op/Concat.h
compiler/moco-tf/src/Op/Concat.test.cpp

index 33788cd..eab5390 100644 (file)
@@ -40,21 +40,6 @@ namespace
 
 using namespace moco::tf;
 
-class ConcatV2GraphUpdate final : public GraphUpdate
-{
-public:
-  ConcatV2GraphUpdate(std::vector<loco::TensorConcat *> nodes, std::vector<TensorName> names)
-      : _nodes(nodes), _names(names)
-  {
-  }
-
-  void input(const SymbolTable *) const override;
-
-private:
-  std::vector<loco::TensorConcat *> _nodes;
-  std::vector<TensorName> _names;
-};
-
 class TFConcatV2GraphUpdate final : public GraphUpdate
 {
 public:
@@ -70,25 +55,6 @@ private:
   std::vector<TensorName> _names;
 };
 
-void ConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
-{
-  int num_inputs = _names.size();
-  assert(num_inputs >= 2);
-  assert(num_inputs == _nodes.size());
-
-  loco::Node *target;
-  // do "%0.lhs : %in[0].name" connection
-  target = tensor_names->node(_names[0]);
-  _nodes[0]->lhs(target);
-
-  for (int i = 1; i < num_inputs; ++i)
-  {
-    // do "%i.rhs : %in[i].name" connections
-    target = tensor_names->node(_names[i]);
-    _nodes[i]->rhs(target);
-  }
-}
-
 void TFConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
 {
   uint32_t num_values = _names.size() - 1; // exclude axis
@@ -112,7 +78,7 @@ namespace moco
 namespace tf
 {
 
-bool ConcatV2GraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+bool ConcatV2GraphBuilder::validate(const tensorflow::NodeDef &node) const
 {
   if (!plier::tf::has_attrs(node, {"T", "N", "Tidx"}))
     return false;
@@ -124,126 +90,11 @@ bool ConcatV2GraphBuilderBase::validate(const tensorflow::NodeDef &node) const
   return (num_inputs >= 2) && (num_inputs == plier::tf::get_int_attr(node, "N"));
 }
 
-/**
- * @brief GraphBuilder for Concat node of Tensor
- */
-class ConcatV2GraphBuilder final : public ConcatV2GraphBuilderBase
-{
-public:
-  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
-};
-
 void ConcatV2GraphBuilder::build(const tensorflow::NodeDef &node,
                                  GraphBuilderContext *context) const
 {
   assert(context != nullptr);
 
-  if (moco::tf::get<moco::tf::Knob::ImportAsTFConcatV2>())
-  {
-    ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow> builder;
-    return builder.build(node, context);
-  }
-  else
-  {
-    ConcatV2GraphBuilderImpl<ImportTarget::Canonical> builder;
-    return builder.build(node, context);
-  }
-}
-
-void ConcatV2GraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
-                                                              GraphBuilderContext *context) const
-{
-  loco::Graph *graph = context->graph();
-  NodeDefTable *nodedef = context->nodedef();
-  SymbolTable *tensor_names = context->tensor_names();
-  UpdateQueue *updates = context->updates();
-
-  // Concat has 2 or more inputs and loco TensorConcat is fixed to 2 inputs
-  // for arbitrary N inputs (beginning from 0), TensorConcat will be created
-  // as follows;
-  // %0 = TensorConcat(%in[0], %in[1])
-  // %1 = %0 --> this is to match index of input name
-  // %2 = TensorConcat(%1, %in[2])
-  // ...
-  // %(N-1) = TensorConcat(%(N-2), %in[N-1]))
-  // %N = TensorConcat(%(N-1), %in[N]))
-  //
-  // Output of this sub graph will be set to %N with node.name()
-  //
-  // As we know that each input exist, one of input(lhs) can be linked while creating
-  // %2.lhs = %1
-  // %3.lhs = %2
-  // ...
-  // %(N-1).lhs = %(N-2)
-  // %N.lhs = %(N-1)
-
-  const int num_inputs = node.input_size() - 1;
-
-  std::vector<loco::TensorConcat *> concat_nodes;
-  std::vector<TensorName> input_names;
-
-  auto concat_node = graph->nodes()->create<loco::TensorConcat>();
-  loco::TensorConcat *last_concat = concat_node;
-
-  // Queue node input update
-  concat_nodes.push_back(concat_node);              // used for LHS of connection -> %0
-  concat_nodes.push_back(concat_node);              // used for RHS of connection -> %1
-  input_names.push_back(TensorName(node.input(0))); // for first concat (%0) LHS
-  input_names.push_back(TensorName(node.input(1))); // for first concat (%1) RHS
-
-  for (int ni = 2; ni < num_inputs; ++ni)
-  {
-    auto concat_node_next = graph->nodes()->create<loco::TensorConcat>();
-
-    concat_nodes.push_back(concat_node_next);
-    input_names.push_back(TensorName(node.input(ni)));
-
-    // connect LHS as we know the nodes
-    concat_node_next->lhs(last_concat);
-
-    // update last concat node
-    last_concat = concat_node_next;
-  }
-
-  // register string-name to the last node as output of concat(s)
-  TensorName output_name(node.name(), 0);
-  tensor_names->enroll(output_name, last_concat);
-
-  // Find axis tensorflow::NodeDef and get the axis number
-  std::string axis_name = node.input(num_inputs);
-  const tensorflow::NodeDef *tfnode = nodedef->node(axis_name);
-  // assume data type is int32
-  assert(plier::tf::get_datatype_attr(*tfnode, "dtype") == tensorflow::DataType::DT_INT32);
-  const auto &tensor = plier::tf::get_tensor_attr(*tfnode, "value");
-  assert(tensor.int_val_size() == 1);
-  auto axis_value_read = tensor.int_val(0);
-
-  // set axis for all concat(s) as temporary data
-  // as the first and the second items are actually the same one, skip it.
-  std::vector<loco::TensorConcat *>::iterator iter = concat_nodes.begin();
-  for (++iter; iter != concat_nodes.end(); ++iter)
-  {
-    auto concat_node = *iter;
-    auto concat_data = stdex::make_unique<ConcatData>(axis_value_read);
-
-    concat_node->annot(std::move(concat_data));
-  }
-
-  // Input name queue is created like this in 'concat_nodes' and 'input_names'
-  // %0.lhs : %in[0].name
-  // %1.rhs : %in[1].name (as %0 == %1)
-  // %2.rhs : %in[2].name
-  // %3.rhs : %in[3].name
-  // ...
-  // %(N-2).rhs : %in[N-2].name
-  // %(N-1).rhs : %in[N-1].name
-  auto update = stdex::make_unique<ConcatV2GraphUpdate>(concat_nodes, input_names);
-  updates->enroll(std::move(update));
-}
-
-void ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
-                                                               GraphBuilderContext *context) const
-{
   loco::Graph *graph = context->graph();
   NodeDefTable *nodedef = context->nodedef();
   SymbolTable *tensor_names = context->tensor_names();
index 6a5a857..0b4cfee 100644 (file)
 #define __OP_CONCAT_H__
 
 #include "GraphBuilder.h"
-#include "ImportTarget.h"
 
 namespace moco
 {
 namespace tf
 {
 
-struct ConcatV2GraphBuilderBase : public GraphBuilder
+class ConcatV2GraphBuilder : public GraphBuilder
 {
-  virtual ~ConcatV2GraphBuilderBase() = default;
-
+public:
   bool validate(const tensorflow::NodeDef &) const final;
-};
-
-template <ImportTarget T> class ConcatV2GraphBuilderImpl;
-
-template <>
-struct ConcatV2GraphBuilderImpl<ImportTarget::Canonical> final : public ConcatV2GraphBuilderBase
-{
-  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
-};
-
-template <>
-struct ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow> final : public ConcatV2GraphBuilderBase
-{
   void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
 };
 
index e9ddd6b..ec3adaf 100644 (file)
@@ -162,49 +162,13 @@ TEST(TensorFlowImport, concat_01)
   EXPECT_TRUE(plier::tf::parse_graphdef(concat_01_pbtxtdata, graph_def));
   std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
 
-  // Test "ConcatV2GraphBuilderImpl<ImportTarget::Canonical>"
-  {
-    // TODO fix indent
-    // clang-format off
-
-  // what to test:
-  // - there should exist TensorConcat
-  // - lhs() should not be nullptr
-  // - rhs() should not be nullptr
-  // - axis() should match
-
-    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::Canonical>;
-
-    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
-    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
-    moco::tf::Importer importer{&r};
-
-    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
-
-  loco::TensorConcat *concat_node =
-      moco::tf::test::find_first_node_bytype<loco::TensorConcat>(graph.get());
-
-  ASSERT_NE(concat_node, nullptr);
-  ASSERT_NE(concat_node->lhs(), nullptr);
-  ASSERT_NE(concat_node->rhs(), nullptr);
-  ASSERT_EQ(concat_node->axis(), 0);
-
-    // clang-format on
-  }
-
-  // Test "ConcatV2GraphBuilderImpl<ImportTarget::Tensorflow>"
   {
     // what to test:
     // - there should exist TFConcatV2
     // - there should be two values
     // - values(idx) should not be nullptr
     // - axis() should not be nullptr
-
-    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
-
-    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
-    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
-    moco::tf::Importer importer{&r};
+    moco::tf::Importer importer;
 
     std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
 
@@ -382,58 +346,12 @@ TEST(TensorFlowImport, concat_02)
   tensorflow::GraphDef graph_def;
   EXPECT_TRUE(plier::tf::parse_graphdef(concat_02_pbtxtdata, graph_def));
 
-  // Test "ConcatV2GraphBuilderImpl<ImportTarget::Canonical>"
-  {
-    // TODO fix indent
-    // clang-format off
-
-  // what to test: Concat has 3 inputs --> Importer creates 2 TensorConcat
-  // - there should exist two TensorConcat
-  // - lhs() of #1 should not be nullptr
-  // - rhs() of #1 should not be nullptr
-  // - lhs() of #2 should be #1
-  // - rhs() of #2 should not be nullptr
-  // - axis() should match
-
-    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::Canonical>;
-
-    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
-    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
-    moco::tf::Importer importer{&r};
-
-    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
-
-  std::vector<loco::TensorConcat *> concat_nodes =
-      moco::tf::test::find_nodes_bytype<loco::TensorConcat>(graph.get());
-  ASSERT_EQ(concat_nodes.size(), 2);
-  loco::TensorConcat *concat_node0 = concat_nodes.at(0);
-  loco::TensorConcat *concat_node1 = concat_nodes.at(1);
-
-  ASSERT_NE(concat_node0, nullptr);
-  ASSERT_NE(concat_node1, nullptr);
-  ASSERT_NE(concat_node0->lhs(), nullptr);
-  ASSERT_NE(concat_node0->rhs(), nullptr);
-  ASSERT_NE(concat_node1->lhs(), nullptr);
-  ASSERT_NE(concat_node1->rhs(), nullptr);
-  ASSERT_TRUE(concat_node0->lhs() == concat_node1 || concat_node1->lhs() == concat_node0);
-  ASSERT_EQ(concat_node0->axis(), 0);
-  ASSERT_EQ(concat_node1->axis(), 0);
-
-    // clang-format on
-  }
-
-  // Test "ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>"
   {
     // what to test: TFConcatV2 has 3 inputs
     // - there should exist TFConcatV2
     // - values(idx) should not be nullptr
     // - axis() should not be nullptr
-
-    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
-
-    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
-    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
-    moco::tf::Importer importer{&r};
+    moco::tf::Importer importer;
 
     std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);