[moco-tf] TFSqueeze shape inference (#6117)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Fri, 2 Aug 2019 08:11:07 +0000 (17:11 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 2 Aug 2019 08:11:07 +0000 (17:11 +0900)
* [moco-tf] TFSqueeze shape inference

This commit introduces shape inference for TFSqueeze

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
* Review fix: use set, use lambda for assertion

compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 11e330d..9fa72ae 100644 (file)
@@ -1148,8 +1148,96 @@ bool fix_shape(moco::tf::TFRsqrt *node)
 
 bool fix_shape(moco::tf::TFSqueeze *node)
 {
-  // TODO implement
-  throw std::runtime_error("NYI fix_shape TFSqueeze");
+  auto shapedata = node->annot<ShapeInferenceData>();
+  if (shapedata != nullptr)
+  {
+    // shape inference is already done for TFSqueeze
+    return false;
+  }
+
+  auto input = node->input();
+  auto input_shape = input->annot<ShapeInferenceData>();
+  if (input_shape == nullptr)
+  {
+    // Input shape is required for TFSqueeze shape inference
+    return false;
+  }
+
+  // TODO Not sure Squeeze only get input as Tensor
+  // Note that tensor_shape() has assertion in it
+  auto input_tensor_shape = input_shape->tensor_shape();
+
+  auto squeeze_dims_vec = node->squeeze_dims();
+  const std::set<int64_t> squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend());
+
+  loco::TensorShape node_shape;
+  uint32_t node_rank = 0;
+
+  if (squeeze_dims.empty())
+  {
+    // Remove all dimensions whose value is 1
+    for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
+    {
+      assert(input_tensor_shape.dim(axis).known());
+      auto dim = input_tensor_shape.dim(axis).value();
+      if (dim != 1)
+      {
+        assert(dim > 1);
+        node_shape.rank(++node_rank);
+        node_shape.dim(node_rank - 1) = dim;
+      }
+    }
+  }
+  else
+  {
+    uint32_t input_rank = input_tensor_shape.rank();
+
+    auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
+      if (squeeze_dims.size() >= input_rank)
+        return false;
+      for (auto squeeze_dim : squeeze_dims)
+      {
+        // Negative squeeze dimensions should be resolve before
+        if (squeeze_dim < 0)
+          return false;
+        if (squeeze_dim >= (int64_t)input_rank)
+          return false;
+      }
+      return true;
+    };
+
+    assert(is_valid_squeeze_dims());
+
+    // Remove squeeze dimensions only
+    for (uint32_t axis = 0; axis < input_rank; ++axis)
+    {
+      assert(input_tensor_shape.dim(axis).known());
+      auto dim = input_tensor_shape.dim(axis).value();
+      if (squeeze_dims.find((int64_t)axis) == squeeze_dims.cend())
+      {
+        // Not squeeze dim
+        node_shape.rank(++node_rank);
+        node_shape.dim(node_rank - 1) = dim;
+      }
+      else
+      {
+        // Is squeeze dim
+        assert(dim == 1);
+        // DO NOTHING
+      }
+    }
+  }
+
+  assert(node_shape.rank() > 0);
+
+  auto shape_annot = stdex::make_unique<ShapeInferenceData>();
+  shape_annot->tensor_shape(node_shape);
+  node->annot(std::move(shape_annot));
+
+  LOGGER(l);
+  INFO(l) << "Fix TFSqueeze shape = " << node_shape;
+
+  return true;
 }
 
 } // namespace