[moco/tf] Implement fix_shape for ConstGen (#3692)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 5 Jun 2019 04:32:56 +0000 (13:32 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 5 Jun 2019 04:32:56 +0000 (13:32 +0900)
This will implement fix_shape in shape inference for ConstGen node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp

index 62bd5cc..2e83ac7 100644 (file)
@@ -72,8 +72,27 @@ bool copy_shapedata(const loco::Node *src, loco::Node *dst)
 
 bool fix_shape(loco::ConstGen *node)
 {
-  // TODO fill this
-  return false;
+  const ShapeInferenceData *shapedata = node->annot<ShapeInferenceData>();
+  if (shapedata != nullptr)
+  {
+    // shape inference is already done for ConstGen
+    return false;
+  }
+
+  // ConstGen itself has shape information, copy them
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+  uint32_t rank = node->rank();
+  shape_data->rank(rank);
+  for (uint32_t index = 0; index < rank; ++index)
+  {
+    if (node->dim(index).known())
+      shape_data->dim(index) = loco::make_dimension(node->dim(index).value());
+    else
+      shape_data->dim(index).unset();
+  }
+  node->annot(std::move(shape_data));
+
+  return true;
 }
 
 bool fix_shape(loco::FeatureDecode *node)