[moco-tf] GraphBuilder for TFConcatV2 (#6673)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 20 Aug 2019 00:41:20 +0000 (09:41 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 20 Aug 2019 00:41:20 +0000 (09:41 +0900)
This will fill GraphBuilder for TFConcatV2 and tests. Also ShapeInference implementation that is required by the test code.

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Op/Concat.cpp
compiler/moco-tf/src/Op/Concat.test.cpp
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 2b18b07..33788cd 100644 (file)
@@ -55,6 +55,21 @@ private:
   std::vector<TensorName> _names;
 };
 
+class TFConcatV2GraphUpdate final : public GraphUpdate
+{
+public:
+  TFConcatV2GraphUpdate(moco::tf::TFConcatV2 *node, std::vector<TensorName> names)
+      : _node(node), _names(names)
+  {
+  }
+
+  void input(const SymbolTable *) const override;
+
+private:
+  moco::tf::TFConcatV2 *_node;
+  std::vector<TensorName> _names;
+};
+
 void ConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
 {
   int num_inputs = _names.size();
@@ -74,6 +89,22 @@ void ConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
   }
 }
 
+void TFConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
+{
+  uint32_t num_values = _names.size() - 1; // exclude axis
+  assert(num_values >= 1);
+
+  for (uint32_t i = 0; i < num_values; ++i)
+  {
+    auto input_node = tensor_names->node(_names[i]);
+    assert(input_node != nullptr);
+    _node->values(i, input_node);
+  }
+  auto axis_node = tensor_names->node(_names[num_values]);
+  assert(axis_node != nullptr);
+  _node->axis(axis_node);
+}
+
 } // namespace
 
 namespace moco
@@ -213,18 +244,28 @@ void ConcatV2GraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::
 void ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
                                                                GraphBuilderContext *context) const
 {
-  /**
-   * @note  This implementation may change when TFConcatV2 inherits from
-   *        something like VariableArity.
-   */
-
   loco::Graph *graph = context->graph();
   NodeDefTable *nodedef = context->nodedef();
   SymbolTable *tensor_names = context->tensor_names();
   UpdateQueue *updates = context->updates();
 
-  // TODO implement
-  throw std::runtime_error("NYI ConcatV2GraphBuilderImpl");
+  const int num_inputs = node.input_size() - 1;
+  std::vector<TensorName> input_names;
+  auto concat_node = graph->nodes()->create<TFConcatV2>(num_inputs);
+
+  for (int ni = 0; ni < num_inputs; ++ni)
+  {
+    input_names.push_back(TensorName(node.input(ni)));
+  }
+  // last one is the axis
+  input_names.push_back(TensorName(node.input(num_inputs)));
+
+  // register string-name to the last node as output of concat(s)
+  TensorName output_name(node.name(), 0);
+  tensor_names->enroll(output_name, concat_node);
+
+  auto update = stdex::make_unique<TFConcatV2GraphUpdate>(concat_node, input_names);
+  updates->enroll(std::move(update));
 }
 
 } // namespace tf
index 53daa63..c7dab10 100644 (file)
@@ -193,7 +193,29 @@ TEST(TensorFlowImport, concat_01)
   }
 
   // Test "ConcatV2GraphBuilderImpl<ImportTarget::Tensorflow>"
-  // TODO implement test
+  {
+    // what to test:
+    // - there should exist TFConcatV2
+    // - there should be two values
+    // - values(idx) should not be nullptr
+    // - axis() should not be nullptr
+
+    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
+
+    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+    moco::tf::Importer importer{&r};
+
+    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+    auto concat_node = moco::tf::test::find_first_node_bytype<moco::tf::TFConcatV2>(graph.get());
+
+    ASSERT_NE(concat_node, nullptr);
+    ASSERT_EQ(concat_node->num_values(), 2);
+    ASSERT_NE(concat_node->values(0), nullptr);
+    ASSERT_NE(concat_node->values(1), nullptr);
+    ASSERT_NE(concat_node->axis(), nullptr);
+  }
 }
 
 namespace
@@ -401,7 +423,29 @@ TEST(TensorFlowImport, concat_02)
   }
 
   // Test "ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>"
-  // TODO implement test
+  {
+    // what to test: TFConcatV2 has 3 inputs
+    // - there should exist TFConcatV2
+    // - values(idx) should not be nullptr
+    // - axis() should not be nullptr
+
+    using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
+
+    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+    r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+    moco::tf::Importer importer{&r};
+
+    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+    auto concat_node = moco::tf::test::find_first_node_bytype<moco::tf::TFConcatV2>(graph.get());
+
+    ASSERT_NE(concat_node, nullptr);
+    ASSERT_EQ(concat_node->num_values(), 3);
+    ASSERT_NE(concat_node->values(0), nullptr);
+    ASSERT_NE(concat_node->values(1), nullptr);
+    ASSERT_NE(concat_node->values(2), nullptr);
+    ASSERT_NE(concat_node->axis(), nullptr);
+  }
 }
 
 namespace
@@ -547,6 +591,5 @@ TEST(TensorFlowImport, concat_03)
     // clang-format on
   }
 
-  // Test "ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>"
-  // TODO implement test
+  // Validation of axis value is skipped for TF
 }
index 59314e1..280e3e7 100644 (file)
@@ -788,13 +788,160 @@ bool fix_shape(moco::tf::TFBiasAdd *node)
   return copy_shapedata(value, node);
 }
 
+template <class CONST_CLASS> bool valid_scala_value(CONST_CLASS *node)
+{
+  LOGGER(l);
+
+  auto shapedata = node->template annot<ShapeInferenceData>();
+  assert(shapedata != nullptr);
+
+  if (node->dtype() != loco::DataType::S32)
+  {
+    INFO(l) << "valid_scala_value not S32";
+    return false;
+  }
+
+  auto tensor_shape = shapedata->tensor_shape();
+  if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1))
+  {
+    INFO(l) << "valid_scala_value rank not 0/1 : " << tensor_shape.rank();
+    return false;
+  }
+
+  return true;
+}
+
+template <class CONST_CLASS> int32_t scala_value(CONST_CLASS *node)
+{
+  auto shapedata = node->template annot<ShapeInferenceData>();
+  assert(shapedata != nullptr);
+
+  assert(node->dtype() == loco::DataType::S32);
+
+  auto tensor_shape = shapedata->tensor_shape();
+  assert(tensor_shape.rank() == 0 || tensor_shape.rank() == 1);
+
+  return node->template at<loco::DataType::S32>(0);
+}
+
 bool fix_shape(moco::tf::TFConcatV2 *node)
 {
-  (void)node;
+  LOGGER(l);
+
+  if (node->annot<ShapeInferenceData>() != nullptr)
+  {
+    // shape inference is already done for TFConcatV2
+    INFO(l) << "Fix shape TFConcatV2 already done";
+    return false;
+  }
+  // ConcatData should be null
+  assert(node->annot<ConcatData>() == nullptr);
+
+  // Check shape inference data are all ready
+  // Check shape rank are all same
+  auto value_a = node->values(0);
+  auto value_a_shapedata = value_a->annot<ShapeInferenceData>();
+  if (value_a_shapedata == nullptr)
+  {
+    // shape inference is not ready for this value
+    INFO(l) << "Fix shape TFConcatV2 value 0 shape_data not ready";
+    return false;
+  }
+  uint32_t a_rank = value_a_shapedata->rank();
 
-  throw std::runtime_error("NYI fix_shape TFConcatV2");
+  uint32_t num_values = node->num_values();
+  for (uint32_t ni = 1; ni < num_values; ++ni)
+  {
+    auto value_b = node->values(ni);
+    auto value_b_shapedata = value_b->annot<ShapeInferenceData>();
+    if (value_b_shapedata == nullptr)
+    {
+      // shape inference is not ready for this value
+      INFO(l) << "Fix shape TFConcatV2 value " << ni << " shape_data not ready";
+      return false;
+    }
+    uint32_t b_rank = value_b_shapedata->rank();
+    assert(a_rank == b_rank);
+  }
 
-  return false;
+  // check for axis
+  auto axis_node = node->axis();
+  auto axis_shapedata = axis_node->annot<ShapeInferenceData>();
+  if (axis_shapedata == nullptr)
+  {
+    // shape inference is not ready for axis_node
+    INFO(l) << "Fix shape TFConcatV2 axis shape_data not ready";
+    return false;
+  }
+
+  int32_t axis_value = 0;
+  bool axis_available = false;
+  {
+    // check for axis is TFConst
+    auto tfconst = dynamic_cast<moco::tf::TFConst *>(axis_node);
+    if (tfconst != nullptr)
+    {
+      if (valid_scala_value(tfconst))
+      {
+        axis_value = scala_value(tfconst);
+        axis_available = true;
+      }
+    }
+  }
+  {
+    // check for axis is ConstGen
+    auto constgen = dynamic_cast<loco::ConstGen *>(axis_node);
+    if (constgen != nullptr)
+    {
+      if (valid_scala_value(constgen))
+      {
+        axis_value = scala_value(constgen);
+        axis_available = true;
+      }
+    }
+  }
+  if (!axis_available)
+  {
+    // we cannot find a valid axis value
+    INFO(l) << "Fix shape TFConcatV2 axis_available false";
+    return false;
+  }
+
+  auto concat_data = stdex::make_unique<ConcatData>(axis_value);
+  node->annot(std::move(concat_data));
+
+  uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)a_rank + axis_value;
+
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+  shape_data->rank(a_rank);
+
+  for (uint32_t index = 0; index < a_rank; ++index)
+  {
+    if (value_a_shapedata->dim(index).known())
+    {
+      uint32_t dim = value_a_shapedata->dim(index).value();
+      if (index == axis_absolute)
+      {
+        uint32_t dim_acc = dim;
+        for (uint32_t ni = 1; ni < num_values; ++ni)
+        {
+          auto value_b = node->values(ni);
+          auto value_b_shapedata = value_b->annot<ShapeInferenceData>();
+          assert(value_b_shapedata->dim(index).known());
+          dim_acc += value_b_shapedata->dim(index).value();
+        }
+        dim = dim_acc;
+      }
+      shape_data->dim(index) = dim;
+    }
+    else
+      shape_data->dim(index).unset();
+  }
+  node->annot(std::move(shape_data));
+
+  INFO(l) << "Fix TFConcat shape = " << node->annot<ShapeInferenceData>();
+
+  return true;
 }
 
 bool fix_shape(moco::tf::TFConst *node)