[moco-tf] Revise shape inf for TFShape and rest (#8053)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 11 Oct 2019 06:55:48 +0000 (15:55 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 11 Oct 2019 06:55:48 +0000 (15:55 +0900)
This will revise shape inference for TFShape and reset of the nodes to be done in TFShapeInferenceRule

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 1cbde70..9997420 100644 (file)
@@ -478,6 +478,134 @@ public:
     return node_shape_with_check(node->x());
   }
 
+  loco::NodeShape visit(const moco::tf::TFShape *node) final
+  {
+    auto input_shape = node_shape(node->input());
+    auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+    loco::TensorShape output_shape;
+
+    // Note that input shape becomes node(TFShape)'s value
+    output_shape.rank(1);
+    output_shape.dim(0) = input_tensor_shape.rank();
+
+    return loco::NodeShape(output_shape);
+  }
+
+  loco::NodeShape visit(const moco::tf::TFSoftmax *node) final
+  {
+    return node_shape_with_check(node->logits());
+  }
+
+  loco::NodeShape visit(const moco::tf::TFSqrt *node) final
+  {
+    return node_shape_with_check(node->x());
+  }
+
+  loco::NodeShape visit(const moco::tf::TFSquaredDifference *node) final
+  {
+    return binary_node_shape(node);
+  }
+
+  loco::NodeShape visit(const moco::tf::TFSqueeze *node) final
+  {
+    auto input_shape = node_shape(node->input());
+
+    // TODO Not sure Squeeze only get input as Tensor
+    // Note that tensor_shape() has assertion in it
+    auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+    auto squeeze_dims_vec = node->squeeze_dims();
+    std::set<int64_t> squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend());
+
+    loco::TensorShape output_shape;
+    uint32_t output_rank = 0;
+
+    if (squeeze_dims.empty())
+    {
+      // Remove all dimensions whose value is 1
+      for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
+      {
+        assert(input_tensor_shape.dim(axis).known());
+        auto dim = input_tensor_shape.dim(axis).value();
+        if (dim != 1)
+        {
+          assert(dim > 1);
+          output_shape.rank(++output_rank);
+          output_shape.dim(output_rank - 1) = dim;
+        }
+      }
+    }
+    else
+    {
+      uint32_t input_rank = input_tensor_shape.rank();
+
+      // Sanity check for 'squeeze_dims'
+      auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
+        if (!(squeeze_dims.size() < input_rank))
+          return false;
+        for (auto squeeze_dim : squeeze_dims)
+        {
+          if (!(squeeze_dim >= -(int64_t)input_rank))
+            return false;
+          if (!(squeeze_dim < (int64_t)input_rank))
+            return false;
+        }
+        return true;
+      };
+
+      if (!is_valid_squeeze_dims())
+      {
+        throw std::runtime_error("Fix shape for TFSqueeze: invalid squeeze dimension");
+      }
+
+      // Resolve negative squeeze dimension
+      std::set<int64_t> resolved_squeeze_dims;
+      for (auto squeeze_dim : squeeze_dims)
+      {
+        if (squeeze_dim < 0)
+          resolved_squeeze_dims.insert(squeeze_dim + (int64_t)input_rank);
+        else
+          resolved_squeeze_dims.insert(squeeze_dim);
+      }
+
+      // Remove squeeze dimensions only
+      for (uint32_t axis = 0; axis < input_rank; ++axis)
+      {
+        assert(input_tensor_shape.dim(axis).known());
+        auto dim = input_tensor_shape.dim(axis).value();
+        if (resolved_squeeze_dims.find((int64_t)axis) == resolved_squeeze_dims.cend())
+        {
+          // Not squeeze dim
+          output_shape.rank(++output_rank);
+          output_shape.dim(output_rank - 1) = dim;
+        }
+        else
+        {
+          // Is squeeze dim
+          assert(dim == 1);
+          // DO NOTHING
+        }
+      }
+    }
+
+    assert(output_shape.rank() > 0);
+
+    return loco::NodeShape(output_shape);
+  }
+
+  loco::NodeShape visit(const moco::tf::TFStopGradient *node) final
+  {
+    return node_shape_with_check(node->input());
+  }
+
+  loco::NodeShape visit(const moco::tf::TFSub *node) final { return binary_node_shape(node); }
+
+  loco::NodeShape visit(const moco::tf::TFTanh *node) final
+  {
+    return node_shape_with_check(node->x());
+  }
+
 public:
   loco::NodeShape visit(const moco::tf::TFNode *node) final
   {
index 635d5f4..9d714f7 100644 (file)
@@ -469,189 +469,21 @@ bool fix_shape(moco::tf::TFReshape *node) { return false; }
 
 bool fix_shape(moco::tf::TFRsqrt *node) { return false; }
 
-bool fix_shape(moco::tf::TFShape *node)
-{
-  if (shape_inference_done(node))
-    return false;
-
-  auto input = node->input();
-  loco::NodeShape input_shape;
-  if (!node_shape(input, input_shape))
-  {
-    // Input shape is required for TFShape shape inference
-    return false;
-  }
-  loco::TensorShape input_tensor_shape = input_shape.as<loco::TensorShape>();
-
-  loco::TensorShape node_shape;
-
-  // Note that input shape becomes node(TFShape)'s value
-  node_shape.rank(1);
-  node_shape.dim(0) = input_tensor_shape.rank();
-
-  auto shape_annot = stdex::make_unique<ShapeInferenceData>();
-  shape_annot->tensor_shape(node_shape);
-  node->annot(std::move(shape_annot));
-
-  LOGGER(l);
-  INFO(l) << "Fix TFShape shape = " << node_shape;
-
-  return true;
-}
-
-bool fix_shape(moco::tf::TFSqrt *node)
-{
-  // Output shape is same as the input x
-  auto x = node->x();
-  return copy_shapedata(x, node);
-}
-
-bool fix_shape(moco::tf::TFSoftmax *node)
-{
-  // Output shape is same as the input x
-  auto logits = node->logits();
-  return copy_shapedata(logits, node);
-}
-
-bool fix_shape(moco::tf::TFSquaredDifference *node)
-{
-  auto x = node->x();
-  auto y = node->y();
-  return copy_shapedata(x, y, node);
-}
-
-bool fix_shape(moco::tf::TFSqueeze *node)
-{
-  if (shape_inference_done(node))
-    return false;
-
-  auto input = node->input();
-  loco::NodeShape input_shape;
-  if (!node_shape(input, input_shape))
-  {
-    // Input shape is required for TFSqueeze shape inference
-    return false;
-  }
-
-  // TODO Not sure Squeeze only get input as Tensor
-  // Note that tensor_shape() has assertion in it
-  auto input_tensor_shape = input_shape.as<loco::TensorShape>();
-
-  auto squeeze_dims_vec = node->squeeze_dims();
-  std::set<int64_t> squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend());
-
-  loco::TensorShape node_shape;
-  uint32_t node_rank = 0;
-
-  if (squeeze_dims.empty())
-  {
-    // Remove all dimensions whose value is 1
-    for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
-    {
-      assert(input_tensor_shape.dim(axis).known());
-      auto dim = input_tensor_shape.dim(axis).value();
-      if (dim != 1)
-      {
-        assert(dim > 1);
-        node_shape.rank(++node_rank);
-        node_shape.dim(node_rank - 1) = dim;
-      }
-    }
-  }
-  else
-  {
-    uint32_t input_rank = input_tensor_shape.rank();
-
-    // Sanity check for 'squeeze_dims'
-    auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
-      if (!(squeeze_dims.size() < input_rank))
-        return false;
-      for (auto squeeze_dim : squeeze_dims)
-      {
-        if (!(squeeze_dim >= -(int64_t)input_rank))
-          return false;
-        if (!(squeeze_dim < (int64_t)input_rank))
-          return false;
-      }
-      return true;
-    };
-
-    if (!is_valid_squeeze_dims())
-    {
-      throw std::runtime_error("Fix shape for TFSqueeze: invalid squeeze dimension");
-    }
-
-    // Resolve negative squeeze dimension
-    std::set<int64_t> resolved_squeeze_dims;
-    for (auto squeeze_dim : squeeze_dims)
-    {
-      if (squeeze_dim < 0)
-        resolved_squeeze_dims.insert(squeeze_dim + (int64_t)input_rank);
-      else
-        resolved_squeeze_dims.insert(squeeze_dim);
-    }
-
-    // Remove squeeze dimensions only
-    for (uint32_t axis = 0; axis < input_rank; ++axis)
-    {
-      assert(input_tensor_shape.dim(axis).known());
-      auto dim = input_tensor_shape.dim(axis).value();
-      if (resolved_squeeze_dims.find((int64_t)axis) == resolved_squeeze_dims.cend())
-      {
-        // Not squeeze dim
-        node_shape.rank(++node_rank);
-        node_shape.dim(node_rank - 1) = dim;
-      }
-      else
-      {
-        // Is squeeze dim
-        assert(dim == 1);
-        // DO NOTHING
-      }
-    }
-  }
-
-  assert(node_shape.rank() > 0);
+bool fix_shape(moco::tf::TFShape *node) { return false; }
 
-  auto shape_annot = stdex::make_unique<ShapeInferenceData>();
-  shape_annot->tensor_shape(node_shape);
-  node->annot(std::move(shape_annot));
+bool fix_shape(moco::tf::TFSqrt *node) { return false; }
 
-  LOGGER(l);
-  INFO(l) << "Fix TFSqueeze shape = " << node_shape;
-
-  return true;
-}
+bool fix_shape(moco::tf::TFSoftmax *node) { return false; }
 
-bool fix_shape(moco::tf::TFStopGradient *node)
-{
-  // Output shape is same as the input
-  auto input = node->input();
-  return copy_shapedata(input, node);
-}
+bool fix_shape(moco::tf::TFSquaredDifference *node) { return false; }
 
-bool fix_shape(moco::tf::TFSub *node)
-{
-  auto x = node->x();
-  auto y = node->y();
-  loco::NodeShape x_shape;
-  loco::NodeShape y_shape;
+bool fix_shape(moco::tf::TFSqueeze *node) { return false; }
 
-  if (!node_shape(x, x_shape))
-    return false;
-  if (!node_shape(y, y_shape))
-    return false;
+bool fix_shape(moco::tf::TFStopGradient *node) { return false; }
 
-  // Output shape is same as the input
-  return copy_shapedata(x, y, node);
-}
+bool fix_shape(moco::tf::TFSub *node) { return false; }
 
-bool fix_shape(moco::tf::TFTanh *node)
-{
-  // Output shape is same as the input
-  auto x = node->x();
-  return copy_shapedata(x, node);
-}
+bool fix_shape(moco::tf::TFTanh *node) { return false; }
 
 bool fix_shape(locoex::COpCall *node)
 {