[moco-tf] TFShape shape inference (#6373)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Thu, 8 Aug 2019 07:42:02 +0000 (16:42 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 8 Aug 2019 07:42:02 +0000 (16:42 +0900)
This commit implements TFShape node's shape inference.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index a372109..d3f45b2 100644 (file)
@@ -1349,8 +1349,35 @@ bool fix_shape(moco::tf::TFRsqrt *node)
 
 bool fix_shape(moco::tf::TFShape *node)
 {
-  // TODO implement
-  throw std::runtime_error("TFShape shape inference NYI");
+  auto shapedata = node->annot<ShapeInferenceData>();
+  if (shapedata != nullptr)
+  {
+    // shape inference is already done for TFShape
+    return false;
+  }
+
+  auto input = node->input();
+  auto input_shape = input->annot<ShapeInferenceData>();
+  if (input_shape == nullptr)
+  {
+    // Input shape is required for TFShape shape inference
+    return false;
+  }
+
+  loco::TensorShape node_shape;
+
+  // Note that input shape becomes node(TFShape)'s value
+  node_shape.rank(1);
+  node_shape.dim(0) = input_shape->rank();
+
+  auto shape_annot = stdex::make_unique<ShapeInferenceData>();
+  shape_annot->tensor_shape(node_shape);
+  node->annot(std::move(shape_annot));
+
+  LOGGER(l);
+  INFO(l) << "Fix TFShape shape = " << node_shape;
+
+  return true;
 }
 
 bool fix_shape(moco::tf::TFSqueeze *node)