[moco/tf] Implement fix_shape for TFConst (#4274)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 15 Jul 2019 10:00:02 +0000 (19:00 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 15 Jul 2019 10:00:02 +0000 (19:00 +0900)
This will implement fix_shape() for TFConst node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco-tf/src/Transforms/FixShapeTransform.cpp

index 026c2a0..b9efe1b 100644 (file)
@@ -581,8 +581,19 @@ bool fix_shape(moco::tf::TFBiasAdd *node)
 
 bool fix_shape(moco::tf::TFConst *node)
 {
-  // TODO implement this
-  throw std::runtime_error("NYI fix_shape for TFConst");
+  auto shapedata = node->annot<ShapeInferenceData>();
+  if (shapedata != nullptr)
+  {
+    // shape inference is already done for TFConst
+    return false;
+  }
+
+  // TFConst itself has shape information, copy them
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+  copy_shape_values(node, shape_data.get());
+  node->annot(std::move(shape_data));
+
+  return true;
 }
 
 bool fix_shape(moco::tf::TFConv2D *node)