[moco-tf] Revise ResolveConstantShape (#7842)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 30 Sep 2019 22:21:06 +0000 (07:21 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 30 Sep 2019 22:21:06 +0000 (07:21 +0900)
This will revise ResolveConstantShape to use NodeShape instead of ShapeInferenceData

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/ResolveConstantShape.cpp

index 017aa66..4eaa679 100644 (file)
  */
 
 #include "ResolveConstantShape.h"
+#include "TFShapeInferenceHelper.h"
 
 #include "IR/TFShape.h"
 #include "IR/TFConst.h"
-#include "Annotations/ShapeInferenceData.h"
 
 #include <loco.h>
 
@@ -44,28 +44,27 @@ namespace
  */
 bool resolve_constant_shape(loco::Graph *graph, moco::tf::TFShape *shape_node)
 {
-  using moco::tf::ShapeInferenceData;
-
-  auto input_shape = shape_node->input()->annot<ShapeInferenceData>();
+  auto input_shape = moco::tf::node_shape(shape_node->input());
 
   // Check condition
-  if (!input_shape)
+  if (input_shape.domain() == loco::Domain::Unknown)
   {
     // Cannot resolve without known input_shape
     return false;
   }
-  auto shape_rank = input_shape->rank();
+
+  auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+  auto shape_rank = input_tensor_shape.rank();
   for (uint32_t axis = 0; axis < shape_rank; ++axis)
   {
-    if (!input_shape->dim(axis).known())
+    if (!input_tensor_shape.dim(axis).known())
     {
       // Cannot resolve with unknown dimension
       return false;
     }
   }
 
-  auto input_tensor_shape = input_shape->tensor_shape();
-
   // Make TFConst to replace TFShape
   auto const_node = graph->nodes()->create<moco::tf::TFConst>();