[moco/tf] Implement fix_shape for TensorConcat (#3694)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 10 Jun 2019 23:43:47 +0000 (08:43 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 10 Jun 2019 23:43:47 +0000 (08:43 +0900)
* [moco/tf] Implement fix_shape for TensorConcat

This will implement fix_shape in shape inference for TensorConcat node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* add assert and comment

contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp

index 2e83ac7..3d1a5b4 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "FixShapeTransform.h"
 
+#include "Annotations/ConcatData.h"
 #include "Annotations/ShapeInferenceData.h"
 
 #include <loco.h>
@@ -164,8 +165,66 @@ bool fix_shape(loco::ReLU *node)
 
 bool fix_shape(loco::TensorConcat *node)
 {
-  // TODO fill this
-  return false;
+  const ConcatData *concat_data = node->annot<ConcatData>();
+  if (concat_data == nullptr)
+  {
+    // shape inference is already done for TensorConcat
+    assert(node->annot<ShapeInferenceData>() != nullptr);
+    return false;
+  }
+  assert(node->annot<ShapeInferenceData>() == nullptr);
+
+  loco::Node *lhs = node->lhs();
+  loco::Node *rhs = node->rhs();
+  const ShapeInferenceData *lhs_shapedata = lhs->annot<ShapeInferenceData>();
+  const ShapeInferenceData *rhs_shapedata = rhs->annot<ShapeInferenceData>();
+  if (lhs_shapedata == nullptr || rhs_shapedata == nullptr)
+  {
+    // postpone as previous input node(s) hasn't been processed.
+    // this will return false as there was nothing changed, but the input of this
+    // node should be changed and from that this method should be called again.
+    // if not, the network may have some problem and the final output may not have
+    // the right shape value and we can identify with some validation at final stage.
+    return false;
+  }
+
+  uint32_t lhs_rank = lhs_shapedata->rank();
+  uint32_t rhs_rank = rhs_shapedata->rank();
+  assert(lhs_rank == rhs_rank);
+
+  int32_t axis_tf = concat_data->axis();
+  if (axis_tf < 0)
+  {
+    axis_tf = static_cast<int32_t>(lhs_rank) + axis_tf;
+  }
+  assert(0 <= axis_tf && axis_tf < static_cast<int32_t>(lhs_rank));
+  // clear annotation ConcatData
+  node->annot<ConcatData>(nullptr);
+
+  uint32_t axis_loco = static_cast<uint32_t>(axis_tf);
+  node->axis(axis_loco);
+
+  // Set ShapeInferenceData for TensorConcat
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+  shape_data->rank(lhs_rank);
+  for (uint32_t index = 0; index < lhs_rank; ++index)
+  {
+    uint32_t lhs_dim = lhs_shapedata->dim(index).value();
+    uint32_t rhs_dim = rhs_shapedata->dim(index).value();
+    // "lhs_dim == rhs_dim" should hold when "index != axis_loco"
+    // or doesn't care when "index == axis_loco"
+    assert(index == axis_loco || lhs_dim == rhs_dim);
+
+    uint32_t new_dim = (index == axis_loco) ? lhs_dim + rhs_dim : lhs_dim;
+
+    if (lhs_shapedata->dim(index).known())
+      shape_data->dim(index) = loco::make_dimension(new_dim);
+    else
+      shape_data->dim(index).unset();
+  }
+  node->annot(std::move(shape_data));
+
+  return true;
 }
 
 } // namespace