[mlir] Convert ConstShapeOp to a static tensor type.
authorAdrian Kuegel <akuegel@google.com>
Tue, 5 Oct 2021 09:10:42 +0000 (11:10 +0200)
committerAdrian Kuegel <akuegel@google.com>
Tue, 5 Oct 2021 10:14:43 +0000 (12:14 +0200)
ConstShapeOp knows its shape, so it should also have a static tensor type.

Differential Revision: https://reviews.llvm.org/D111127

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

index f622d5e..3aef4bb 100644 (file)
@@ -191,7 +191,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
   Type indexTy = rewriter.getIndexType();
   Value tensor =
       rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
-  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  Type resultTy = RankedTensorType::get({op.shape().size()}, indexTy);
   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
   return success();
 }
index bc551f8..ccc8d56 100644 (file)
@@ -89,29 +89,29 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
 
 // Lower `const_shape` to `tensor.from_elements`.
 // CHECK-LABEL: @const_shape
-// CHECK-SAME: () -> tensor<?xindex>
-func @const_shape() -> tensor<?xindex> {
+// CHECK-SAME: () -> tensor<3xindex>
+func @const_shape() -> tensor<3xindex> {
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[C2:.*]] = constant 2 : index
   // CHECK: %[[C3:.*]] = constant 3 : index
   // CHECK: %[[TENSOR3:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]]
-  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
-  // CHECK: return %[[RESULT]] : tensor<?xindex>
-  %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  return %shape : tensor<?xindex>
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<3xindex>
+  // CHECK: return %[[RESULT]] : tensor<3xindex>
+  %shape = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  return %shape : tensor<3xindex>
 }
 
 // -----
 
 // Lower `const_shape` in the case of rank 0.
 // CHECK-LABEL: func @const_shape_zero_elements
-// CHECK-SAME: () -> tensor<?xindex>
-func @const_shape_zero_elements() -> tensor<?xindex> {
+// CHECK-SAME: () -> tensor<0xindex>
+func @const_shape_zero_elements() -> tensor<0xindex> {
   // CHECK: %[[TENSOR:.*]] = tensor.from_elements : tensor<0xindex>
-  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
-  // CHECK: return %[[RESULT]] : tensor<?xindex>
-  %shape = shape.const_shape [] : tensor<?xindex>
-  return %shape : tensor<?xindex>
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<0xindex>
+  // CHECK: return %[[RESULT]] : tensor<0xindex>
+  %shape = shape.const_shape [] : tensor<0xindex>
+  return %shape : tensor<0xindex>
 }
 
 // -----