[MLIR][Shape] Expose extent tensor type builder
authorFrederik Gossen <frgossen@google.com>
Wed, 5 Aug 2020 09:41:42 +0000 (09:41 +0000)
committerFrederik Gossen <frgossen@google.com>
Wed, 5 Aug 2020 09:42:57 +0000 (09:42 +0000)
The extent tensor type is a `tensor<?xindex>` that is used in the shape dialect.
To facilitate the use of this type when working with the shape dialect, we
expose the helper function for its construction.

Differential Revision: https://reviews.llvm.org/D85121

mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/lib/Dialect/Shape/IR/Shape.cpp

index 62e4e0c..ca1e066 100644 (file)
@@ -26,6 +26,9 @@ class PatternRewriter;
 
 namespace shape {
 
+/// Alias type for extent tensors.
+RankedTensorType getExtentTensorType(MLIRContext *ctx);
+
 namespace ShapeTypes {
 enum Kind {
   Component = Type::FIRST_SHAPE_TYPE,
index 02fe7b8..be4c3c7 100644 (file)
@@ -24,7 +24,7 @@ namespace {
 #include "ShapeCanonicalization.inc"
 }
 
-static RankedTensorType getExtentTensorType(MLIRContext *ctx) {
+RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }
 
@@ -713,12 +713,9 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
 }
 
 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
-  if (arg.getType().isa<ShapedType>()) {
-    auto type = RankedTensorType::get({ShapedType::kDynamicSize},
-                                      builder.getIndexType());
-    return ShapeOfOp::build(builder, result, type, arg);
-  }
-  auto type = ShapeType::get(builder.getContext());
+  Type type = arg.getType().isa<ShapedType>()
+                  ? (Type)getExtentTensorType(builder.getContext())
+                  : (Type)builder.getType<ShapeType>();
   return ShapeOfOp::build(builder, result, type, arg);
 }