[moco-tf] Clear ConcatV2 fix shape (#6523)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 13 Aug 2019 04:52:01 +0000 (13:52 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 13 Aug 2019 04:52:01 +0000 (13:52 +0900)
This will erase ConcatV2 fix shape to reduce complexity of diff for migration to using VariadicArityNode

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index d3f45b2..fc89e3d 100644 (file)
@@ -790,66 +790,11 @@ bool fix_shape(moco::tf::TFBiasAdd *node)
 
 bool fix_shape(moco::tf::TFConcatV2 *node)
 {
-  auto concat_data = node->annot<ConcatData>();
-  if (concat_data == nullptr)
-  {
-    // shape inference is already done for TFConcatV2
-    assert(node->annot<ShapeInferenceData>() != nullptr);
-    return false;
-  }
-  assert(node->annot<ShapeInferenceData>() == nullptr);
+  (void)node;
 
-  auto lhs = node->lhs();
-  auto rhs = node->rhs();
-  auto lhs_shapedata = lhs->annot<ShapeInferenceData>();
-  auto rhs_shapedata = rhs->annot<ShapeInferenceData>();
-  if (lhs_shapedata == nullptr || rhs_shapedata == nullptr)
-  {
-    // postpone as previous input node(s) hasn't been processed.
-    // this will return false as there was nothing changed, but the input of this
-    // node should be changed and from that this method should be called again.
-    // if not, the network may have some problem and the final output may not have
-    // the right shape value and we can identify with some validation at final stage.
-    return false;
-  }
-
-  uint32_t lhs_rank = lhs_shapedata->rank();
-  uint32_t rhs_rank = rhs_shapedata->rank();
-  assert(lhs_rank == rhs_rank);
+  throw std::runtime_error("NYI fix_shape TFConcatV2");
 
-  int32_t axis_tf = concat_data->axis();
-  if (axis_tf < 0)
-  {
-    axis_tf = static_cast<int32_t>(lhs_rank) + axis_tf;
-  }
-  assert(0 <= axis_tf && axis_tf < static_cast<int32_t>(lhs_rank));
-  // clear annotation ConcatData
-  node->annot<ConcatData>(nullptr);
-
-  uint32_t axis_loco = static_cast<uint32_t>(axis_tf);
-  node->axis(axis_loco);
-
-  // Set ShapeInferenceData for TensorConcat
-  auto shape_data = stdex::make_unique<ShapeInferenceData>();
-  shape_data->rank(lhs_rank);
-  for (uint32_t index = 0; index < lhs_rank; ++index)
-  {
-    uint32_t lhs_dim = lhs_shapedata->dim(index).value();
-    uint32_t rhs_dim = rhs_shapedata->dim(index).value();
-    // "lhs_dim == rhs_dim" should hold when "index != axis_loco"
-    // or doesn't care when "index == axis_loco"
-    assert(index == axis_loco || lhs_dim == rhs_dim);
-
-    uint32_t new_dim = (index == axis_loco) ? lhs_dim + rhs_dim : lhs_dim;
-
-    if (lhs_shapedata->dim(index).known())
-      shape_data->dim(index) = new_dim;
-    else
-      shape_data->dim(index).unset();
-  }
-  node->annot(std::move(shape_data));
-
-  return true;
+  return false;
 }
 
 bool fix_shape(moco::tf::TFConst *node)