[moco-tf] Apply shape_inference_done (#7304)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 9 Sep 2019 10:14:25 +0000 (19:14 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 9 Sep 2019 10:14:25 +0000 (19:14 +0900)
This will apply to use shape_inference_done() for fix_shape nodes

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index c7cab16..9364391 100644 (file)
@@ -337,12 +337,8 @@ template <class T> void update_window_data(T *node)
 
 bool fix_shape(loco::Pull *node)
 {
-  auto shapedata = node->annot<ShapeInferenceData>();
-  if (shapedata != nullptr)
-  {
-    // shape inference is already done for Pull
+  if (shape_inference_done(node))
     return false;
-  }
 
   // Pull itself has shape information, copy them
   auto shape_data = make_shape_inference_data(node);
@@ -379,12 +375,9 @@ bool fix_shape(moco::tf::TFAvgPool *node)
 {
   LOGGER(l);
 
-  auto shapedata = node->annot<ShapeInferenceData>();
-  if (shapedata != nullptr)
-  {
-    // shape inference is already done for TFAvgPool
+  if (shape_inference_done(node))
     return false;
-  }
+
   auto value = node->value();
   loco::NodeShape value_shape;
   if (!node_shape(value, value_shape))
@@ -513,9 +506,8 @@ bool fix_shape(moco::tf::TFConcatV2 *node)
 {
   LOGGER(l);
 
-  if (node->annot<ShapeInferenceData>() != nullptr)
+  if (shape_inference_done(node))
   {
-    // shape inference is already done for TFConcatV2
     INFO(l) << "Fix shape TFConcatV2 already done";
     return false;
   }
@@ -638,12 +630,8 @@ bool fix_shape(moco::tf::TFConcatV2 *node)
 
 bool fix_shape(moco::tf::TFConst *node)
 {
-  auto shapedata = node->annot<ShapeInferenceData>();
-  if (shapedata != nullptr)
-  {
-    // shape inference is already done for TFConst
+  if (shape_inference_done(node))
     return false;
-  }
 
   // TFConst itself has shape information, copy them
   auto shape_data = make_shape_inference_data(node);
@@ -663,12 +651,9 @@ bool fix_shape(moco::tf::TFConv2D *node)
 {
   LOGGER(l);
 
-  auto shapedata = node->annot<ShapeInferenceData>();
-  if (shapedata != nullptr)
-  {
-    // shape inference is already done
+  if (shape_inference_done(node))
     return false;
-  }
+
   auto ifm = node->ifm();
   loco::NodeShape ifm_shape;
   if (!node_shape(ifm, ifm_shape))
@@ -755,12 +740,9 @@ bool fix_shape(moco::tf::TFDepthwiseConv2dNative *node)
 {
   LOGGER(l);
 
-  auto shapedata = node->annot<ShapeInferenceData>();
-  if (shapedata != nullptr)
-  {
-    // shape inference is already done
+  if (shape_inference_done(node))
     return false;
-  }
+
   auto ifm = node->ifm();
   loco::NodeShape ifm_shape;
   if (!node_shape(ifm, ifm_shape))
@@ -858,12 +840,9 @@ bool fix_shape(moco::tf::TFMaxPool *node)
 {
   LOGGER(l);
 
-  auto shapedata = node->annot<ShapeInferenceData>();
-  if (shapedata != nullptr)
-  {
-    // shape inference is already done for TFMaxPool
+  if (shape_inference_done(node))
     return false;
-  }
+
   auto value = node->value();
   loco::NodeShape value_shape;
   if (!node_shape(value, value_shape))
@@ -982,12 +961,8 @@ bool fix_shape(moco::tf::TFRelu6 *node)
 
 bool fix_shape(moco::tf::TFReshape *node)
 {
-  auto shapedata = node->annot<ShapeInferenceData>();
-  if (shapedata != nullptr)
-  {
-    // shape inference is already done for TFReshape
+  if (shape_inference_done(node))
     return false;
-  }
 
   // For now, we only consider Fixed Reshape, i.e. Reshape with determined
   //      'shape' input. So here we only support case when 'shape' input of
@@ -1050,12 +1025,8 @@ bool fix_shape(moco::tf::TFRsqrt *node)
 
 bool fix_shape(moco::tf::TFShape *node)
 {
-  auto shapedata = node->annot<ShapeInferenceData>();
-  if (shapedata != nullptr)
-  {
-    // shape inference is already done for TFShape
+  if (shape_inference_done(node))
     return false;
-  }
 
   auto input = node->input();
   loco::NodeShape input_shape;
@@ -1105,12 +1076,8 @@ bool fix_shape(moco::tf::TFSquaredDifference *node)
 
 bool fix_shape(moco::tf::TFSqueeze *node)
 {
-  auto shapedata = node->annot<ShapeInferenceData>();
-  if (shapedata != nullptr)
-  {
-    // shape inference is already done for TFSqueeze
+  if (shape_inference_done(node))
     return false;
-  }
 
   auto input = node->input();
   loco::NodeShape input_shape;