[moco-tf] Import as TFConcatV2 with a knob (#6348)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 8 Aug 2019 01:42:30 +0000 (10:42 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 8 Aug 2019 01:42:30 +0000 (10:42 +0900)
This will change import of ConcatV2 node as Canonical or TF-dialect with a knob.

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Knob.lst
compiler/moco-tf/src/Op/Concat.cpp
compiler/moco-tf/src/Op/Concat.h [new file with mode: 0644]
compiler/moco-tf/src/Op/Concat.test.cpp

index f778ebc..2cda2e6 100644 (file)
@@ -7,6 +7,7 @@
 // Imports
 KNOB_BOOL(ImportAsTFAvgPool, true, Import AvgPool2D node as TFAvgPool node)
 KNOB_BOOL(ImportAsTFBiasAdd, true, Import BiasAdd node as TFBiasAdd node)
+KNOB_BOOL(ImportAsTFConcatV2, false, Import ConcatV2 node as TFConcatV2 node)
 KNOB_BOOL(ImportAsTFConst, true, Import Const node as TFConst node)
 KNOB_BOOL(ImportAsTFConv2D, true, Import Conv2D node as TFConv2D node)
 KNOB_BOOL(ImportAsTFIdentity, true, Import Identity node as TFIdentity node)
index d608c78..bbdfb45 100644 (file)
  * limitations under the License.
  */
 
+#include "Concat.h"
+
 #include "GraphBuilder.h"
 #include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFConcatV2.h"
 
 #include "Annotations/ConcatData.h"
 
 #include <cassert>
 #include <stdexcept>
 
-namespace moco
-{
-namespace tf
+namespace
 {
 
-/**
- * @brief GraphBuilder for Concat node of Tensor
- */
-class ConcatV2GraphBuilder final : public GraphBuilder
-{
-public:
-  bool validate(const tensorflow::NodeDef &) const override;
-  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
-};
+using namespace moco::tf;
 
 class ConcatV2GraphUpdate final : public GraphUpdate
 {
@@ -60,7 +55,29 @@ private:
   std::vector<TensorName> _names;
 };
 
-bool ConcatV2GraphBuilder::validate(const tensorflow::NodeDef &node) const
+class TFConcatV2GraphUpdate final : public GraphUpdate
+{
+public:
+  TFConcatV2GraphUpdate(std::vector<moco::tf::TFConcatV2 *> nodes, std::vector<TensorName> names)
+      : _nodes(nodes), _names(names)
+  {
+  }
+
+  void input(const SymbolTable *) const override;
+
+private:
+  std::vector<moco::tf::TFConcatV2 *> _nodes;
+  std::vector<TensorName> _names;
+};
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool ConcatV2GraphBuilderBase::validate(const tensorflow::NodeDef &node) const
 {
   // Concat node SHOULD have 3 or more inputs, that is 2 + axis
   const int num_inputs = node.input_size() - 1;
@@ -69,11 +86,35 @@ bool ConcatV2GraphBuilder::validate(const tensorflow::NodeDef &node) const
   return (num_inputs >= 2) && (num_inputs == plier::tf::get_int_attr(node, "N"));
 }
 
+/**
+ * @brief GraphBuilder for Concat node of Tensor
+ */
+class ConcatV2GraphBuilder final : public ConcatV2GraphBuilderBase
+{
+public:
+  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
 void ConcatV2GraphBuilder::build(const tensorflow::NodeDef &node,
                                  GraphBuilderContext *context) const
 {
   assert(context != nullptr);
 
+  if (moco::tf::get<moco::tf::Knob::ImportAsTFConcatV2>())
+  {
+    ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow> builder;
+    return builder.build(node, context);
+  }
+  else
+  {
+    ConcatV2GraphBuilderImpl<ImportTarget::Canonical> builder;
+    return builder.build(node, context);
+  }
+}
+
+void ConcatV2GraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+                                                              GraphBuilderContext *context) const
+{
   loco::Graph *graph = context->graph();
   NodeDefTable *nodedef = context->nodedef();
   SymbolTable *tensor_names = context->tensor_names();
@@ -162,6 +203,109 @@ void ConcatV2GraphBuilder::build(const tensorflow::NodeDef &node,
   updates->enroll(std::move(update));
 }
 
+void ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+                                                               GraphBuilderContext *context) const
+{
+  /**
+   * @note  This implementation may change when TFConcatV2 inherits from
+   *        something like VariableArity.
+   */
+
+  loco::Graph *graph = context->graph();
+  NodeDefTable *nodedef = context->nodedef();
+  SymbolTable *tensor_names = context->tensor_names();
+  UpdateQueue *updates = context->updates();
+
+  // Concat has 2 or more inputs and current TFConcatV2 is fixed to 2 inputs
+  // for arbitrary N inputs (beginning from 0), TFConcatV2 will be created
+  // as follows;
+  // %0 = TFConcatV2(%in[0], %in[1])
+  // %1 = %0 --> this is to match index of input name
+  // %2 = TFConcatV2(%1, %in[2])
+  // ...
+  // %(N-1) = TFConcatV2(%(N-2), %in[N-1]))
+  // %N = TFConcatV2(%(N-1), %in[N]))
+  //
+  // Output of this sub graph will be set to %N with node.name()
+  //
+  // As we know that each input exist, one of input(lhs) can be linked while creating
+  // %2.lhs = %1
+  // %3.lhs = %2
+  // ...
+  // %(N-1).lhs = %(N-2)
+  // %N.lhs = %(N-1)
+
+  const int num_inputs = node.input_size() - 1;
+
+  std::vector<TFConcatV2 *> concat_nodes;
+  std::vector<TensorName> input_names;
+
+  auto concat_node = graph->nodes()->create<TFConcatV2>();
+  TFConcatV2 *last_concat = concat_node;
+
+  // Queue node input update
+  concat_nodes.push_back(concat_node);              // used for LHS of connection -> %0
+  concat_nodes.push_back(concat_node);              // used for RHS of connection -> %1
+  input_names.push_back(TensorName(node.input(0))); // for first concat (%0) LHS
+  input_names.push_back(TensorName(node.input(1))); // for first concat (%1) RHS
+
+  for (int ni = 2; ni < num_inputs; ++ni)
+  {
+    auto concat_node_next = graph->nodes()->create<TFConcatV2>();
+
+    concat_nodes.push_back(concat_node_next);
+    input_names.push_back(TensorName(node.input(ni)));
+
+    // connect LHS as we know the nodes
+    concat_node_next->lhs(last_concat);
+
+    // update last concat node
+    last_concat = concat_node_next;
+  }
+
+  // register string-name to the last node as output of concat(s)
+  TensorName output_name(node.name(), 0);
+  tensor_names->enroll(output_name, last_concat);
+
+  // Find axis tensorflow::NodeDef and get the axis number
+  std::string axis_name = node.input(num_inputs);
+  const tensorflow::NodeDef *tfnode = nodedef->node(axis_name);
+  // assume data type is int32
+  assert(plier::tf::get_datatype_attr(*tfnode, "dtype") == tensorflow::DataType::DT_INT32);
+  const auto &tensor = plier::tf::get_tensor_attr(*tfnode, "value");
+  assert(tensor.int_val_size() == 1);
+  auto axis_value_read = tensor.int_val(0);
+
+  // set axis for all concat(s) as temporary data
+  // as the first and the second items are actually the same one, skip it.
+  std::vector<TFConcatV2 *>::iterator iter = concat_nodes.begin();
+  for (++iter; iter != concat_nodes.end(); ++iter)
+  {
+    auto concat_node = *iter;
+    auto concat_data = stdex::make_unique<ConcatData>(axis_value_read);
+
+    concat_node->annot(std::move(concat_data));
+  }
+
+  // Input name queue is created like this in 'concat_nodes' and 'input_names'
+  // %0.lhs : %in[0].name
+  // %1.rhs : %in[1].name (as %0 == %1)
+  // %2.rhs : %in[2].name
+  // %3.rhs : %in[3].name
+  // ...
+  // %(N-2).rhs : %in[N-2].name
+  // %(N-1).rhs : %in[N-1].name
+  auto update = stdex::make_unique<TFConcatV2GraphUpdate>(concat_nodes, input_names);
+  updates->enroll(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
+
+// TODO move this block to upperside
+namespace
+{
+
 void ConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
 {
   int num_inputs = _names.size();
@@ -181,8 +325,26 @@ void ConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
   }
 }
 
-} // namespace tf
-} // namespace moco
+void TFConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
+{
+  int num_inputs = _names.size();
+  assert(num_inputs >= 2);
+  assert(num_inputs == _nodes.size());
+
+  loco::Node *target;
+  // do "%0.lhs : %in[0].name" connection
+  target = tensor_names->node(_names[0]);
+  _nodes[0]->lhs(target);
+
+  for (int i = 1; i < num_inputs; ++i)
+  {
+    // do "%i.rhs : %in[i].name" connections
+    target = tensor_names->node(_names[i]);
+    _nodes[i]->rhs(target);
+  }
+}
+
+} // namespace
 
 #include "GraphBuilderRegistry.h"
 
diff --git a/compiler/moco-tf/src/Op/Concat.h b/compiler/moco-tf/src/Op/Concat.h
new file mode 100644 (file)
index 0000000..6a5a857
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_CONCAT_H__
+#define __OP_CONCAT_H__
+
+#include "GraphBuilder.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct ConcatV2GraphBuilderBase : public GraphBuilder
+{
+  virtual ~ConcatV2GraphBuilderBase() = default;
+
+  bool validate(const tensorflow::NodeDef &) const final;
+};
+
+template <ImportTarget T> class ConcatV2GraphBuilderImpl;
+
+template <>
+struct ConcatV2GraphBuilderImpl<ImportTarget::Canonical> final : public ConcatV2GraphBuilderBase
+{
+  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+template <>
+struct ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow> final : public ConcatV2GraphBuilderBase
+{
+  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_CONCAT_H__
index 5489994..339a8ca 100644 (file)
  * limitations under the License.
  */
 
+#include "Concat.h"
+
+#include "IR/TFConcatV2.h"
+
 #include "TestHelper.h"
 
 #include "Importer.h"
@@ -23,6 +27,7 @@
 
 #include <gtest/gtest.h>
 
+using namespace moco::tf;
 using namespace moco::tf::test;
 
 namespace
@@ -157,12 +162,25 @@ TEST(TensorFlowImport, concat_01)
   EXPECT_TRUE(plier::tf::parse_graphdef(concat_01_pbtxtdata, graph_def));
   std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
 
+  // Test "ConcatV2GraphBuilderImpl<ImportTarget::Canonical>"
+  {
+    // TODO fix indent
+    // clang-format off
+
   // what to test:
   // - there should exist TensorConcat
   // - lhs() should not be nullptr
   // - rhs() should not be nullptr
   // - axis() should match
 
+    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::Canonical>;
+
+    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+    moco::tf::Importer importer{&r};
+
+    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
   loco::TensorConcat *concat_node =
       moco::tf::test::find_first_node_bytype<loco::TensorConcat>(graph.get());
 
@@ -170,6 +188,33 @@ TEST(TensorFlowImport, concat_01)
   ASSERT_NE(concat_node->lhs(), nullptr);
   ASSERT_NE(concat_node->rhs(), nullptr);
   ASSERT_EQ(concat_node->axis(), 0);
+
+    // clang-format on
+  }
+
+  // Test "ConcatV2GraphBuilderImpl<ImportTarget::Tensorflow>"
+  {
+    // what to test:
+    // - there should exist TensorConcat
+    // - lhs() should not be nullptr
+    // - rhs() should not be nullptr
+    // - axis() should match
+
+    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
+
+    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+    moco::tf::Importer importer{&r};
+
+    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+    auto concat_node = moco::tf::test::find_first_node_bytype<moco::tf::TFConcatV2>(graph.get());
+
+    ASSERT_NE(concat_node, nullptr);
+    ASSERT_NE(concat_node->lhs(), nullptr);
+    ASSERT_NE(concat_node->rhs(), nullptr);
+    ASSERT_EQ(concat_node->axis(), 0);
+  }
 }
 
 namespace
@@ -335,7 +380,11 @@ TEST(TensorFlowImport, concat_02)
 
   tensorflow::GraphDef graph_def;
   EXPECT_TRUE(plier::tf::parse_graphdef(concat_02_pbtxtdata, graph_def));
-  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+  // Test "ConcatV2GraphBuilderImpl<ImportTarget::Canonical>"
+  {
+    // TODO fix indent
+    // clang-format off
 
   // what to test: Concat has 3 inputs --> Importer creates 2 TensorConcat
   // - there should exist two TensorConcat
@@ -345,6 +394,14 @@ TEST(TensorFlowImport, concat_02)
   // - rhs() of #2 should not be nullptr
   // - axis() should match
 
+    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::Canonical>;
+
+    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+    moco::tf::Importer importer{&r};
+
+    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
   std::vector<loco::TensorConcat *> concat_nodes =
       moco::tf::test::find_nodes_bytype<loco::TensorConcat>(graph.get());
   ASSERT_EQ(concat_nodes.size(), 2);
@@ -360,6 +417,44 @@ TEST(TensorFlowImport, concat_02)
   ASSERT_TRUE(concat_node0->lhs() == concat_node1 || concat_node1->lhs() == concat_node0);
   ASSERT_EQ(concat_node0->axis(), 0);
   ASSERT_EQ(concat_node1->axis(), 0);
+
+    // clang-format on
+  }
+
+  // Test "ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>"
+  {
+    // what to test: Concat has 3 inputs --> Importer creates 2 TFConcatV2
+    // - there should exist two TFConcatV2
+    // - lhs() of #1 should not be nullptr
+    // - rhs() of #1 should not be nullptr
+    // - lhs() of #2 should be #1
+    // - rhs() of #2 should not be nullptr
+    // - axis() should match
+
+    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
+
+    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+    moco::tf::Importer importer{&r};
+
+    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+    std::vector<moco::tf::TFConcatV2 *> concat_nodes =
+        moco::tf::test::find_nodes_bytype<moco::tf::TFConcatV2>(graph.get());
+    ASSERT_EQ(concat_nodes.size(), 2);
+    moco::tf::TFConcatV2 *concat_node0 = concat_nodes.at(0);
+    moco::tf::TFConcatV2 *concat_node1 = concat_nodes.at(1);
+
+    ASSERT_NE(concat_node0, nullptr);
+    ASSERT_NE(concat_node1, nullptr);
+    ASSERT_NE(concat_node0->lhs(), nullptr);
+    ASSERT_NE(concat_node0->rhs(), nullptr);
+    ASSERT_NE(concat_node1->lhs(), nullptr);
+    ASSERT_NE(concat_node1->rhs(), nullptr);
+    ASSERT_TRUE(concat_node0->lhs() == concat_node1 || concat_node1->lhs() == concat_node0);
+    ASSERT_EQ(concat_node0->axis(), 0);
+    ASSERT_EQ(concat_node1->axis(), 0);
+  }
 }
 
 namespace
@@ -475,7 +570,10 @@ TEST(TensorFlowImport, concat_03)
 
   tensorflow::GraphDef graph_def;
   EXPECT_TRUE(plier::tf::parse_graphdef(concat_03_pbtxtdata, graph_def));
-  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+  // Test "ConcatV2GraphBuilderImpl<ImportTarget::Canonical>"
+  {
+    // TODO fix indent
+    // clang-format off
 
   // what to test: minus axis value validation
   // - there should exist a TensorConcat
@@ -483,6 +581,14 @@ TEST(TensorFlowImport, concat_03)
   // - rhs() should not be nullptr
   // - axis() should match 2 + (-1), where 2 came from rank of input(s)
 
+    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::Canonical>;
+
+    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+    moco::tf::Importer importer{&r};
+
+    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
   loco::TensorConcat *concat_node =
       moco::tf::test::find_first_node_bytype<loco::TensorConcat>(graph.get());
 
@@ -490,4 +596,31 @@ TEST(TensorFlowImport, concat_03)
   ASSERT_NE(concat_node->lhs(), nullptr);
   ASSERT_NE(concat_node->rhs(), nullptr);
   ASSERT_EQ(concat_node->axis(), (2 + (-1)));
+
+    // clang-format on
+  }
+
+  // Test "ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>"
+  {
+    // what to test: minus axis value validation
+    // - there should exist a TFConcatV2
+    // - lhs() should not be nullptr
+    // - rhs() should not be nullptr
+    // - axis() should match 2 + (-1), where 2 came from rank of input(s)
+
+    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
+
+    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+    moco::tf::Importer importer{&r};
+
+    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+    auto concat_node = moco::tf::test::find_first_node_bytype<moco::tf::TFConcatV2>(graph.get());
+
+    ASSERT_NE(concat_node, nullptr);
+    ASSERT_NE(concat_node->lhs(), nullptr);
+    ASSERT_NE(concat_node->rhs(), nullptr);
+    ASSERT_EQ(concat_node->axis(), (2 + (-1)));
+  }
 }