[moco/tf] Add Concat node converter (#3671)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 4 Jun 2019 08:11:21 +0000 (17:11 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 4 Jun 2019 08:11:21 +0000 (17:11 +0900)
* [moco/tf] Add Concat node converter

This will add Concat node converter for TF loading

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* rename

* update comments

contrib/moco/lib/frontend/tf/src/Op/Concat.cpp [new file with mode: 0644]
contrib/moco/lib/frontend/tf/src/Op/Concat.test.cpp [new file with mode: 0644]

diff --git a/contrib/moco/lib/frontend/tf/src/Op/Concat.cpp b/contrib/moco/lib/frontend/tf/src/Op/Concat.cpp
new file mode 100644 (file)
index 0000000..90e92ac
--- /dev/null
@@ -0,0 +1,182 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Convert.h"
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include "Annotations/ConcatData.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Concat node of Tensor
+ */
+class ConcatV2GraphBuilder final : public GraphBuilder
+{
+public:
+  bool validate(const tensorflow::NodeDef &) const override;
+  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+class ConcatV2GraphUpdate final : public GraphUpdate
+{
+public:
+  ConcatV2GraphUpdate(std::vector<loco::TensorConcat *> nodes, std::vector<std::string> names)
+      : _nodes(nodes), _names(names)
+  {
+  }
+
+  void input(const SymbolTable *) const override;
+
+private:
+  std::vector<loco::TensorConcat *> _nodes;
+  std::vector<std::string> _names;
+};
+
+bool ConcatV2GraphBuilder::validate(const tensorflow::NodeDef &node) const { return true; }
+
+void ConcatV2GraphBuilder::build(const tensorflow::NodeDef &node,
+                                 GraphBuilderContext *context) const
+{
+  assert(context != nullptr);
+
+  loco::Graph *graph = context->graph();
+  NodeDefTable *nodedef = context->nodedef();
+  SymbolTable *nodes = context->nodes();
+  UpdateQueue *updates = context->updates();
+
+  // Concat has 2 or more inputs and loco TensorConcat is fixed to 2 inputs
+  // for arbitrary N inputs (beginning from 0), TensorConcat will be created
+  // as follows;
+  // %0 = TensorConcat(%in[0], %in[1])
+  // %1 = %0 --> this is to match index of input name
+  // %2 = TensorConcat(%1, %in[2])
+  // ...
+  // %(N-1) = TensorConcat(%(N-2), %in[N-1]))
+  // %N = TensorConcat(%(N-1), %in[N]))
+  //
+  // Output of this sub graph will be set to %N with node.name()
+  //
+  // As we know that each input exist, one of input(lhs) can be linked while creating
+  // %2.lhs = %1
+  // %3.lhs = %2
+  // ...
+  // %(N-1).lhs = %(N-2)
+  // %N.lhs = %(N-1)
+
+  // Queue node input update
+  // Concat node SHOULD have 3 or more inputs, that is 2 + axis
+  const int num_inputs = node.input_size() - 1;
+  assert(num_inputs >= 2);
+  assert(num_inputs == get_int_attr(node, "N"));
+
+  std::vector<loco::TensorConcat *> concat_nodes;
+  std::vector<std::string> input_names;
+
+  auto concat_node = graph->nodes()->create<loco::TensorConcat>();
+  loco::TensorConcat *last_concat = concat_node;
+
+  concat_nodes.push_back(concat_node);  // used for LHS of connection -> %0
+  concat_nodes.push_back(concat_node);  // used for RHS of connection -> %1
+  input_names.push_back(node.input(0)); // for first concat (%0) LHS
+  input_names.push_back(node.input(1)); // for first concat (%1) RHS
+
+  for (int ni = 2; ni < num_inputs; ++ni)
+  {
+    auto concat_node_next = graph->nodes()->create<loco::TensorConcat>();
+
+    concat_nodes.push_back(concat_node_next);
+    input_names.push_back(node.input(ni));
+
+    // connect LHS as we know the nodes
+    concat_node_next->lhs(last_concat);
+
+    // update last concat node
+    last_concat = concat_node_next;
+  }
+
+  // register string-name to the last node as output of concat(s)
+  nodes->enroll(node.name(), last_concat);
+
+  // Find axis tensorflow::NodeDef and get the axis number
+  std::string axis_name = node.input(num_inputs);
+  const tensorflow::NodeDef *tfnode = nodedef->node(axis_name);
+  // assume data type is int32
+  assert(get_datatype_attr(*tfnode, "dtype") == tensorflow::DataType::DT_INT32);
+  const auto &tensor = get_tensor_attr(*tfnode, "value");
+  assert(tensor.int_val_size() == 1);
+  auto axis_value_read = tensor.int_val(0);
+
+  // set axis for all concat(s) as temporary data
+  // as the first and the second items are actually the same one, skip it.
+  std::vector<loco::TensorConcat *>::iterator iter = concat_nodes.begin();
+  for (++iter; iter != concat_nodes.end(); ++iter)
+  {
+    auto concat_node = *iter;
+    auto concat_data = stdex::make_unique<ConcatData>(axis_value_read);
+
+    concat_node->annot(std::move(concat_data));
+  }
+
+  // Input name queue is created like this in 'concat_nodes' and 'input_names'
+  // %0.lhs : %in[0].name
+  // %1.rhs : %in[1].name (as %0 == %1)
+  // %2.rhs : %in[2].name
+  // %3.rhs : %in[3].name
+  // ...
+  // %(N-2).rhs : %in[N-2].name
+  // %(N-1).rhs : %in[N-1].name
+  auto update = stdex::make_unique<ConcatV2GraphUpdate>(concat_nodes, input_names);
+  updates->enroll(std::move(update));
+}
+
+void ConcatV2GraphUpdate::input(const SymbolTable *table) const
+{
+  int num_inputs = _names.size();
+  assert(num_inputs >= 2);
+  assert(num_inputs == _nodes.size());
+
+  loco::Node *target;
+  // do "%0.lhs : %in[0].name" connection
+  target = table->node(_names[0]);
+  _nodes[0]->lhs(target);
+
+  for (int i = 1; i < num_inputs; ++i)
+  {
+    // do "%i.rhs : %in[i].name" connections
+    target = table->node(_names[i]);
+    _nodes[i]->rhs(target);
+  }
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(ConcatV2, ConcatV2GraphBuilder)
diff --git a/contrib/moco/lib/frontend/tf/src/Op/Concat.test.cpp b/contrib/moco/lib/frontend/tf/src/Op/Concat.test.cpp
new file mode 100644 (file)
index 0000000..da11bab
--- /dev/null
@@ -0,0 +1,495 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include <moco/tf/Frontend.h>
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+#include <cstring>
+
+using namespace moco::tf::test;
+
+namespace
+{
+
+// clang-format off
+const char *concat_01_pbtxtdata = STRING_CONTENT(
+node {
+  name: "Input01"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 2
+          }
+          dim {
+            size: 3
+          }
+        }
+        float_val: 1
+        float_val: 2
+        float_val: 3
+        float_val: 4
+        float_val: 5
+        float_val: 6
+      }
+    }
+  }
+}
+node {
+  name: "Input02"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 2
+          }
+          dim {
+            size: 3
+          }
+        }
+        float_val: 7
+        float_val: 8
+        float_val: 9
+        float_val: 10
+        float_val: 11
+        float_val: 12
+      }
+    }
+  }
+}
+node {
+  name: "Axis"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_INT32
+        tensor_shape {
+        }
+        int_val: 0
+      }
+    }
+  }
+}
+node {
+  name: "Concat"
+  op: "ConcatV2"
+  input: "Input01"
+  input: "Input02"
+  input: "Axis"
+  attr {
+    key: "N"
+    value {
+      i: 2
+    }
+  }
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "Tidx"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowFrontend, concat_01)
+{
+  moco::tf::Frontend frontend;
+  moco::tf::ModelSignature signature;
+
+  imemstream mempb(concat_01_pbtxtdata, std::strlen(concat_01_pbtxtdata));
+
+  signature.add_output("Concat");
+
+  std::unique_ptr<loco::Graph> graph =
+      frontend.load(signature, &mempb, moco::tf::Frontend::FileType::Text);
+
+  // check all nodes are created
+  // - 0 input, 3 const, 1 concat, 1 output
+  loco::Graph::NodeContext *nodes = graph->nodes();
+  ASSERT_EQ(nodes->size(), 5);
+
+  loco::Graph::InputContext *inputs = graph->inputs();
+  ASSERT_EQ(inputs->size(), 0);
+
+  loco::Graph::OutputContext *outputs = graph->outputs();
+  ASSERT_EQ(outputs->size(), 1);
+
+  loco::TensorConcat *concat_node = nullptr;
+  int const_count = 0;
+  for (int i = 0; i < nodes->size(); ++i)
+  {
+    auto node_1 = dynamic_cast<loco::TensorConcat *>(nodes->at(i));
+    if (node_1 != nullptr)
+    {
+      ASSERT_EQ(concat_node, nullptr);
+      concat_node = node_1;
+    }
+    auto node_2 = dynamic_cast<loco::ConstGen *>(nodes->at(i));
+    const_count += (node_2 != nullptr) ? 1 : 0;
+  }
+  ASSERT_NE(concat_node, nullptr);
+  ASSERT_EQ(const_count, 3);
+}
+
+namespace
+{
+
+// clang-format off
+const char *concat_02_pbtxtdata = STRING_CONTENT(
+node {
+  name: "Input01"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 2
+          }
+          dim {
+            size: 3
+          }
+        }
+        float_val: 1
+        float_val: 2
+        float_val: 3
+        float_val: 4
+        float_val: 5
+        float_val: 6
+      }
+    }
+  }
+}
+node {
+  name: "Input02"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 2
+          }
+          dim {
+            size: 3
+          }
+        }
+        float_val: 7
+        float_val: 8
+        float_val: 9
+        float_val: 10
+        float_val: 11
+        float_val: 12
+      }
+    }
+  }
+}
+node {
+  name: "Input03"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 2
+          }
+          dim {
+            size: 3
+          }
+        }
+        float_val: 13
+        float_val: 14
+        float_val: 15
+        float_val: 16
+        float_val: 17
+        float_val: 18
+      }
+    }
+  }
+}
+node {
+  name: "Axis"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_INT32
+        tensor_shape {
+        }
+        int_val: 0
+      }
+    }
+  }
+}
+node {
+  name: "Concat"
+  op: "ConcatV2"
+  input: "Input01"
+  input: "Input02"
+  input: "Input03"
+  input: "Axis"
+  attr {
+    key: "N"
+    value {
+      i: 3
+    }
+  }
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "Tidx"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowFrontend, concat_02)
+{
+  moco::tf::Frontend frontend;
+  moco::tf::ModelSignature signature;
+
+  imemstream mempb(concat_02_pbtxtdata, std::strlen(concat_02_pbtxtdata));
+
+  signature.add_output("Concat");
+
+  std::unique_ptr<loco::Graph> graph =
+      frontend.load(signature, &mempb, moco::tf::Frontend::FileType::Text);
+
+  // check all nodes are created
+  // - 0 input, 4 const, 2 concat, 1 output
+  loco::Graph::NodeContext *nodes = graph->nodes();
+  ASSERT_EQ(nodes->size(), 7);
+
+  loco::Graph::InputContext *inputs = graph->inputs();
+  ASSERT_EQ(inputs->size(), 0);
+
+  loco::Graph::OutputContext *outputs = graph->outputs();
+  ASSERT_EQ(outputs->size(), 1);
+
+  loco::TensorConcat *concat_node[2] = {nullptr, nullptr};
+  int concat_count = 0;
+  int const_count = 0;
+  for (int i = 0; i < nodes->size(); ++i)
+  {
+    auto node_1 = dynamic_cast<loco::TensorConcat *>(nodes->at(i));
+    if (node_1 != nullptr)
+    {
+      ASSERT_EQ(concat_node[concat_count], nullptr);
+      concat_node[concat_count] = node_1;
+      concat_count++;
+    }
+    auto node_2 = dynamic_cast<loco::ConstGen *>(nodes->at(i));
+    const_count += (node_2 != nullptr) ? 1 : 0;
+  }
+  ASSERT_EQ(concat_count, 2);
+  ASSERT_NE(concat_node[0], nullptr);
+  ASSERT_NE(concat_node[1], nullptr);
+  ASSERT_EQ(const_count, 4);
+
+  EXPECT_TRUE(concat_node[1]->lhs() == concat_node[0] || concat_node[0]->lhs() == concat_node[1]);
+}
+
+namespace
+{
+
+// clang-format off
+const char *concat_03_pbtxtdata = STRING_CONTENT(
+node {
+  name: "Input01"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        dim {
+          size: 2
+        }
+        dim {
+          size: 3
+        }
+      }
+    }
+  }
+}
+node {
+  name: "Input02"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        dim {
+          size: 2
+        }
+        dim {
+          size: 3
+        }
+      }
+    }
+  }
+}
+node {
+  name: "Axis"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_INT32
+        tensor_shape {
+        }
+        int_val: -1
+      }
+    }
+  }
+}
+node {
+  name: "Concat"
+  op: "ConcatV2"
+  input: "Input01"
+  input: "Input02"
+  input: "Axis"
+  attr {
+    key: "N"
+    value {
+      i: 2
+    }
+  }
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "Tidx"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowFrontend, concat_03)
+{
+  /**
+   * TODO Minus axis is not supported yet. Implement this when ready.
+   */
+}