[moco-tf] Remove ShapeInferenceData in Canonicalizers (#7810)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 30 Sep 2019 03:09:33 +0000 (12:09 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 30 Sep 2019 03:09:33 +0000 (12:09 +0900)
This will revise Canonicalizers not to use ShapeInferenceData and use node_shape instead

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp
compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp
compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp
compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp

index b764ba8..73cfaa3 100644 (file)
@@ -18,7 +18,6 @@
 
 #include "Annotations/PadData.h"
 #include "Annotations/StrideData.h"
-#include "Annotations/ShapeInferenceData.h"
 #include "Annotations/WindowData.h"
 
 #include "Dialect/TFDialect.h"
index f5cf1ac..0125c69 100644 (file)
 #include "LogHelper.h"
 
 #include "Annotations/ConcatData.h"
-#include "Annotations/ShapeInferenceData.h"
 
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
 #include "Dialect/TFNodeVisitor.h"
 #include "Dialect/TFNodeImpl.h"
+#include "Dialect/TFShapeInferenceHelper.h"
 
 #include <moco/Log.h>
 
@@ -76,8 +76,9 @@ bool canonicalize_concat(loco::Graph *graph, moco::tf::TFConcatV2 *node)
   assert(concat_data != nullptr);
   auto axis_value = concat_data->axis();
 
-  auto shapedata = node->annot<ShapeInferenceData>();
-  auto node_rank = shapedata->rank();
+  auto nodeshape = moco::tf::node_shape(node);
+  auto tensorshape = nodeshape.as<loco::TensorShape>();
+  auto node_rank = tensorshape.rank();
 
   uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)node_rank + axis_value;
 
index 75907f7..561e04e 100644 (file)
@@ -17,7 +17,6 @@
 #include "DepthwiseConv2dNativeCanonicalizer.h"
 
 #include "Annotations/PadData.h"
-#include "Annotations/ShapeInferenceData.h"
 #include "Annotations/StrideData.h"
 
 #include "Dialect/TFDialect.h"
index 98352f8..aafa8cf 100644 (file)
 
 #include "RsqrtCanonicalizer.h"
 
-#include "Annotations/ShapeInferenceData.h"
-
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
 #include "Dialect/TFNodeVisitor.h"
 #include "Dialect/TFNodeImpl.h"
+#include "Dialect/TFShapeInferenceHelper.h"
 
 #include <moco/Log.h>
 
@@ -33,12 +32,11 @@ namespace
 {
 
 template <typename T>
-void prepare_const_gen(loco::ConstGen *const_node, const moco::tf::ShapeInferenceData *shapedata,
-                       T value);
+void prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShape &tensorshape, T value);
 
 template <>
-void prepare_const_gen<float>(loco::ConstGen *const_node,
-                              const moco::tf::ShapeInferenceData *shapedata, float value)
+void prepare_const_gen<float>(loco::ConstGen *const_node, const loco::TensorShape &tensorshape,
+                              float value)
 {
   LOGGER(l);
 
@@ -47,18 +45,18 @@ void prepare_const_gen<float>(loco::ConstGen *const_node,
   auto dtype = loco::DataType::FLOAT32;
   const_node->dtype(dtype);
 
-  auto rank = shapedata->rank();
+  auto rank = tensorshape.rank();
   const_node->rank(rank);
   for (uint32_t r = 0; r < rank; ++r)
   {
-    if (shapedata->dim(r).known())
-      const_node->dim(r) = shapedata->dim(r);
+    if (tensorshape.dim(r).known())
+      const_node->dim(r) = tensorshape.dim(r);
     else
       throw std::runtime_error("Cannot handle unknown shape");
 
-    assert(shapedata->dim(r).value() > 0);
+    assert(tensorshape.dim(r).value() > 0);
 
-    const_num_elements *= shapedata->dim(r).value();
+    const_num_elements *= tensorshape.dim(r).value();
   }
 
   INFO(l) << "prepare_const_gen : Elements = " << const_num_elements;
@@ -91,13 +89,14 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::tf::TFRsqrt *node)
    *                 TFRsqrt is converted to 1 / EltwiseSqrt
    */
 
-  auto rsqrt_shapedata = node->annot<moco::tf::ShapeInferenceData>();
-  if (rsqrt_shapedata == nullptr)
+  auto nodeshape = moco::tf::node_shape(node);
+  if (nodeshape.domain() == loco::Domain::Unknown)
   {
     // We need this shape information
     assert(false); // this shouldn't happen, let's add an alarm
     return false;
   }
+  auto tensorshape = nodeshape.as<loco::TensorShape>();
 
   if (!loco::dtype_known(node))
   {
@@ -114,7 +113,7 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::tf::TFRsqrt *node)
   switch (dtype)
   {
     case loco::DataType::FLOAT32:
-      prepare_const_gen<float>(const_node, rsqrt_shapedata, 1.0f);
+      prepare_const_gen<float>(const_node, tensorshape, 1.0f);
       break;
 
     default:
index 5bee38b..466ade2 100644 (file)
 
 #include "SoftmaxCanonicalizer.h"
 
-#include "Annotations/ShapeInferenceData.h"
-
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
 #include "Dialect/TFNodeVisitor.h"
 #include "Dialect/TFNodeImpl.h"
+#include "Dialect/TFShapeInferenceHelper.h"
 
 #include <moco/Log.h>
 
@@ -46,12 +45,11 @@ bool canonicalize_softmax(loco::Graph *graph, moco::tf::TFSoftmax *node)
   *           In ---- TensorSoftmax ----- Out(s)
   */
 
-  auto softmax_shape = node->annot<moco::tf::ShapeInferenceData>();
-
+  auto nodeshape = moco::tf::node_shape(node);
   // Canonicalization into TensorSoftmax is valid when softmax has shape info
-  assert(softmax_shape);
+  assert(nodeshape.domain() != loco::Domain::Unknown);
 
-  auto softmax_tensor_shape = softmax_shape->tensor_shape();
+  auto softmax_tensor_shape = nodeshape.as<loco::TensorShape>();
 
   // Create loco node to replace
   auto softmax = graph->nodes()->create<loco::TensorSoftmax>();
index 12793cb..f2752bf 100644 (file)
 
 #include "SqueezeCanonicalizer.h"
 
-#include "Annotations/ShapeInferenceData.h"
-
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
 #include "Dialect/TFNodeVisitor.h"
 #include "Dialect/TFNodeImpl.h"
+#include "Dialect/TFShapeInferenceHelper.h"
 
 #include <moco/Log.h>
 
@@ -46,12 +45,12 @@ bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::tf::TFSqueeze *no
    *           In ---- FixedReshape ----- Out(s)
    */
 
-  auto squeeze_shape = node->annot<moco::tf::ShapeInferenceData>();
+  auto nodeshape = moco::tf::node_shape(node);
   // canonicalize into FixedReshape is valid when squeeze has shape info
   // TODO Support general Squeeze case
-  assert(squeeze_shape);
+  assert(nodeshape.domain() != loco::Domain::Unknown);
 
-  auto squeeze_tensor_shape = squeeze_shape->tensor_shape();
+  auto squeeze_tensor_shape = nodeshape.as<loco::TensorShape>();
 
   // Create loco node to replace
   auto reshape = graph->nodes()->create<loco::FixedReshape>();