[moco-tf] Implement shape inference for TFReshape (#5914)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Sun, 28 Jul 2019 23:52:34 +0000 (08:52 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Sun, 28 Jul 2019 23:52:34 +0000 (08:52 +0900)
This commit implements shape inference for TFReshape on fix_shape().

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 87a2c9e..37cec53 100644 (file)
@@ -745,8 +745,56 @@ bool fix_shape(moco::tf::TFMul *node)
 
 bool fix_shape(moco::tf::TFReshape *node)
 {
-  // TODO implement
-  throw std::runtime_error("NYI fix_shape TFReshape");
+  auto shapedata = node->annot<ShapeInferenceData>();
+  if (shapedata != nullptr)
+  {
+    // shape inference is already done for TFReshape
+    return false;
+  }
+
+  // For now, we only consider Fixed Reshape, i.e. Reshape with determined
+  //      'shape' input. So here we only support case when 'shape' input of
+  //      TFReshape is TFConst. If 'shape' input is not TFConst, another
+  //      transform (e.g. constant folding) should be done beforehand to make
+  //      it TFConst.
+  // TODO Support dynamic Reshape
+  // Note that 'shape()' here is 'shape' input, not node's shape information
+  auto const_shape_input = dynamic_cast<moco::tf::TFConst *>(node->shape());
+  if (!const_shape_input)
+  {
+    // 'shape' input of TFReshape is not TFConst, try next time when it becomes TFConst
+    return false;
+  }
+
+  // 'Shape' input should be integer tensor of rank 1, e.g. [2, 3, 4] or [3, -1]
+  assert(const_shape_input->dtype() == loco::DataType::S32);
+  assert(const_shape_input->rank() == 1);
+
+  auto shape_rank = const_shape_input->dim(0).value();
+  assert(shape_rank > 0);
+
+  loco::TensorShape shape_data;
+  shape_data.rank(shape_rank);
+  for (uint32_t axis = 0; axis < shape_rank; ++axis)
+  {
+    shape_data.dim(axis) = const_shape_input->at<loco::DataType::S32>(axis);
+  }
+
+  // TODO Compare 'tensor' input and validate coherency?
+  // Not sure this is appropriate stage for this task.
+
+  auto shape_annot = stdex::make_unique<ShapeInferenceData>();
+  shape_annot->tensor_shape(shape_data);
+  node->annot(std::move(shape_annot));
+
+  {
+    LOGGER(l);
+    auto shapedata = node->annot<ShapeInferenceData>();
+    assert(shapedata != nullptr);
+    INFO(l) << "Fix TFReshape shape = " << shapedata->tensor_shape();
+  }
+
+  return true;
 }
 
 bool fix_shape(moco::tf::TFRsqrt *node)