[mlir][Interfaces] ReifyRankedShapedTypeOpInterface returns OpFoldResults
author: Matthias Springer <me@m-sp.org>
Fri, 3 Mar 2023 16:56:39 +0000 (17:56 +0100)
committer: Matthias Springer <me@m-sp.org>
Mon, 6 Mar 2023 07:41:28 +0000 (08:41 +0100)
`reifyResultShapes` now returns `OpFoldResult`s instead of `Value`s. This is often more efficient because many transformations immediately attempt to extract a constant from the reified values.

Differential Revision: https://reviews.llvm.org/D145250

17 files changed:
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.td
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp

index 4c0ffd0..42f5ec4 100644 (file)
@@ -26,7 +26,7 @@
 namespace mlir {
 
 class ShapedTypeComponents;
-using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
+using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
 
 /// Adaptor class to abstract the differences between whether value is from
 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
index 052e590..9f71181 100644 (file)
@@ -211,15 +211,16 @@ def ReifyRankedShapedTypeOpInterface :
   let methods = [
     InterfaceMethod<
       /*desc=*/[{
-        Reify the shape of the result of an operation (typically in
-        terms of shape of its operands)
-
-        Insert operations using the given `OpBuilder` that computes
-        the result shape. The `reifiedReturnShapes` is expected to be
-        populated with as many vectors as the number of results of the
-        op. Each of these vectors is expected to be of size equal to
-        rank of the corresponding result. If the shape of a particular
-        result cannot be computed it must be empty.
+        Reify the shape of the result of an operation (typically in terms of the
+        shape of its operands).
+
+        `reifiedReturnShapes` is populated with one vector per op result. Each
+        of those vectors contains an OpFoldResult for each dimension of the
+        shaped type. In case a dimension in the type is static, the
+        corresponding entry is an IntegerAttr. Otherwise, it is a Value. The
+        given builder may be used to insert ops that compute result shapes.
+
+        If the shape of a particular result cannot be computed it must be empty.
       }],
       /*retTy=*/"::mlir::LogicalResult",
       /*methodName=*/"reifyResultShapes",
index 43c8b1d..0f119d5 100644 (file)
@@ -147,7 +147,7 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
               resultDims[shapedValue.cast<OpResult>().getResultNumber()];
           for (const auto &dim : enumerate(tensorType.getShape()))
             if (ShapedType::isDynamic(dim.value()))
-              dynamicSizes.push_back(shape[dim.index()]);
+              dynamicSizes.push_back(shape[dim.index()].get<Value>());
         }
       }
     }
index 66e1807..a25ca38 100644 (file)
@@ -369,13 +369,13 @@ void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 LogicalResult AllocTensorOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  auto shapes = llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
-        if (isDynamicDim(dim))
-          return getDynamicSize(builder, dim);
-        return builder.create<arith::ConstantIndexOp>(getLoc(),
-                                                      getStaticSize(dim));
-      }));
+  auto shapes = llvm::to_vector<4>(
+      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
+                      [&](int64_t dim) -> OpFoldResult {
+                        if (isDynamicDim(dim))
+                          return getDynamicSize(builder, dim);
+                        return builder.getIndexAttr(getStaticSize(dim));
+                      }));
   reifiedReturnShapes.emplace_back(std::move(shapes));
   return success();
 }
index a5c6dc6..6844b68 100644 (file)
@@ -642,13 +642,12 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
   int64_t pos = 0;
   ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
   for (OpOperand *opOperand : getDpsInitOperands()) {
-    SmallVector<Value> shapes;
+    SmallVector<OpFoldResult> shapes;
     for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
       if (checkDimExpr.visit(shapeExprs[pos]))
         shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
       else
-        shapes.push_back(
-            getValueOrCreateConstantIndexOp(b, loc, allResultDimValues[pos]));
+        shapes.push_back(allResultDimValues[pos]);
       pos++;
     }
     reifiedReturnShapes.emplace_back(std::move(shapes));
index dc0b39a..b33d989 100644 (file)
@@ -488,10 +488,9 @@ struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
       return rewriter.notifyMatchFailure(
           padOp, "failed to reify tensor.pad op result shape");
 
-    SmallVector<OpFoldResult> newShape =
-        getAsOpFoldResult(reifiedShape.front());
     auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        padOp.getLoc(), newShape, padOp.getResultType().getElementType());
+        padOp.getLoc(), reifiedShape.front(),
+        padOp.getResultType().getElementType());
     Value replacement =
         rewriter
             .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
index 16baebb..700f873 100644 (file)
@@ -14,6 +14,7 @@
 //===----------------------------------------------------------------------===//
 //
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -132,7 +133,8 @@ static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
       for (int64_t i = 0; i < tensorType.getRank(); ++i) {
         if (tensorType.isDynamicDim(i))
           dynSizes.push_back(
-              reifiedShape[value.cast<OpResult>().getResultNumber()][i]);
+              reifiedShape[value.cast<OpResult>().getResultNumber()][i]
+                  .get<Value>());
       }
       return dynSizes;
     }
@@ -298,7 +300,7 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
   SmallVector<Value> dynamicSizes;
   for (int64_t i = 0; i < resultType.getRank(); ++i)
     if (resultType.isDynamicDim(i))
-      dynamicSizes.push_back(reifiedShape[0][i]);
+      dynamicSizes.push_back(reifiedShape[0][i].get<Value>());
 
   // If the `padOp` has a nofold attribute and all paddings are known to be 0,
   // explicitly insert a `linalg.copy`.
index 4736165..760b14e 100644 (file)
@@ -75,7 +75,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
 
     // Create the tensor of same size as output of the pad op.
     RankedTensorType padResultType = padOp.getResultType();
-    auto resultSizes = getAsOpFoldResult(resultShape[0]);
+    auto resultSizes = resultShape[0];
     auto emptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, resultSizes, padResultType.getElementType());
 
index 2c1c56d..01f2c17 100644 (file)
@@ -204,7 +204,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
     newOperands.push_back(*paddedOperand);
   }
 
-  SmallVector<SmallVector<Value>> reifiedResultShapes;
+  ReifiedRankedShapedTypeDims reifiedResultShapes;
   if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                  .reifyResultShapes(rewriter, reifiedResultShapes))) {
     LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
@@ -231,11 +231,10 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
     int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> sizes;
-    for (Value v : reifiedResultShapes[resultNumber])
-      sizes.push_back(getAsOpFoldResult(v));
     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
     paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
-        loc, paddedResult, offsets, sizes, strides));
+        loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+        strides));
   }
   return paddedSubtensorResults;
 }
index 650d71e..2a18c55 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -92,7 +93,7 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     if (!dimIndex)
       return failure();
 
-    SmallVector<SmallVector<Value>> reifiedResultShapes;
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
     if (failed(
             rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes)))
       return failure();
@@ -106,7 +107,10 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
         static_cast<size_t>(sourceType.getRank()))
       return failure();
 
-    rewriter.replaceOp(dimOp, reifiedResultShapes[resultNumber][*dimIndex]);
+    rewriter.replaceOp(dimOp,
+                       getValueOrCreateConstantIndexOp(
+                           rewriter, dimOp.getLoc(),
+                           reifiedResultShapes[resultNumber][*dimIndex]));
     return success();
   }
 };
index 84fbae3..eb73a2c 100644 (file)
@@ -38,10 +38,12 @@ getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
 /// terms of shape of the `src`, when the reshape op is a collapsing
 /// operation. It is the product of the shape of the collapsed dimensions of the
 /// `src`.
-static OpFoldResult
-getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
-                                    int64_t dimIndex, Value src,
-                                    ArrayRef<AffineMap> reassociationMap) {
+static OpFoldResult getCollapsedOutputDimFromInputShape(
+    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) {
+  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
+    return builder.getIndexAttr(dstStaticShape[dimIndex]);
+  }
   AffineMap map = reassociationMap[dimIndex];
   unsigned startPos =
       map.getResults().front().cast<AffineDimExpr>().getPosition();
@@ -65,8 +67,8 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
   return llvm::to_vector<4>(llvm::map_range(
       llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
-        return getCollapsedOutputDimFromInputShape(builder, loc, dim, src,
-                                                   reassociation);
+        return getCollapsedOutputDimFromInputShape(
+            builder, loc, dim, src, dstStaticShape, reassociation);
       }));
 }
 
@@ -77,7 +79,7 @@ static OpFoldResult getExpandedOutputDimFromInputShape(
     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
     llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
   if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
-    return builder.getI64IntegerAttr(dstStaticShape[dimIndex]);
+    return builder.getIndexAttr(dstStaticShape[dimIndex]);
   }
   unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
   unsigned startPos = reassociation[sourceDimPos]
@@ -144,11 +146,9 @@ struct ReifyExpandOrCollapseShapeOp
                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
     auto reshapeOp = cast<OpTy>(op);
-    auto resultShape = getReshapeOutputShapeFromInputShape(
+    reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
         b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
-        reshapeOp.getReassociationMaps());
-    reifiedReturnShapes.push_back(
-        getValueOrCreateConstantIndexOp(b, loc, resultShape));
+        reshapeOp.getReassociationMaps()));
     return success();
   }
 };
@@ -165,8 +165,13 @@ struct ReifyPadOp
     Location loc = padOp.getLoc();
     auto lowPad = padOp.getMixedLowPad();
     auto highPad = padOp.getMixedHighPad();
-    SmallVector<Value> shapes;
+    SmallVector<OpFoldResult> shapes;
     for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) {
+      if (!padOp.getResultType().isDynamicDim(dim)) {
+        shapes.push_back(b.getIndexAttr(padOp.getResultType().getDimSize(dim)));
+        continue;
+      }
+
       // Shape along each dimension is source dim + low pad + high pad.
       SmallVector<Value> mapOperands;
       mapOperands.push_back(
index 2e9fcb1..2253133 100644 (file)
@@ -87,7 +87,7 @@ FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
       return failure();
     if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
       return failure();
-    mixedSizes = getAsOpFoldResult(reifiedShapes[opResult.getResultNumber()]);
+    mixedSizes = reifiedShapes[opResult.getResultNumber()];
   } else {
     // Static shape: Take static sizes directly.
     for (int64_t sz : tensorType.getShape())
@@ -523,14 +523,13 @@ LogicalResult EmptyOp::verify() {
 LogicalResult
 EmptyOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
   unsigned ctr = 0;
   for (int64_t i = 0; i < getType().getRank(); ++i) {
     if (getType().isDynamicDim(i)) {
       reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
     } else {
-      reifiedReturnShapes[0][i] =
-          builder.create<arith::ConstantIndexOp>(getLoc(), i);
+      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
     }
   }
   return success();
@@ -1004,14 +1003,14 @@ void GenerateOp::getAsmResultNames(
 
 LogicalResult GenerateOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
   int idx = 0;
   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
     if (getType().isDynamicDim(dim)) {
       reifiedReturnShapes[0][dim] = getOperand(idx++);
     } else {
-      reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
-          getLoc(), getType().getDimSize(dim));
+      reifiedReturnShapes[0][dim] =
+          builder.getIndexAttr(getType().getDimSize(dim));
     }
   }
   return success();
@@ -1787,16 +1786,10 @@ LogicalResult ExtractSliceOp::reifyResultShapes(
   reifiedReturnShapes[0].reserve(getType().getRank());
   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
   llvm::SmallBitVector droppedDims = getDroppedDims();
-  Location loc = getLoc();
   for (const auto &size : enumerate(mixedSizes)) {
     if (droppedDims.test(size.index()))
       continue;
-    if (auto attr = size.value().dyn_cast<Attribute>()) {
-      reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
-          loc, attr.cast<IntegerAttr>().getInt()));
-      continue;
-    }
-    reifiedReturnShapes[0].push_back(size.value().get<Value>());
+    reifiedReturnShapes[0].push_back(size.value());
   }
   return success();
 }
@@ -2210,7 +2203,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
 
 LogicalResult InsertSliceOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
     reifiedReturnShapes[0][dim] =
         builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
@@ -3160,7 +3153,7 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                 "applies to only pack or unpack operations");
   int64_t destRank = op.getDestRank();
-  reifiedReturnShapes.resize(1, SmallVector<Value>(destRank));
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
   for (auto dim : llvm::seq<int64_t>(0, destRank)) {
     reifiedReturnShapes[0][dim] =
         builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
index 33e698f..457f261 100644 (file)
@@ -219,7 +219,7 @@ struct PackOpTiling
     (void)packOp.reifyResultShapes(b, outputShape);
     resultSizes.assign(sizes.begin(), sizes.end());
     for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
-      resultSizes.push_back(getAsOpFoldResult(outputShape[0][dataTileDim]));
+      resultSizes.push_back(outputShape[0][dataTileDim]);
 
     return success();
   }
index 66cbd64..f9512fd 100644 (file)
@@ -33,9 +33,8 @@ struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
         !llvm::hasSingleElement(resultShapes))
       return failure();
     // TODO: Do not drop tensor type encoding.
-    Value emptyTensor =
-        rewriter.create<EmptyOp>(loc, getAsOpFoldResult(resultShapes[0]),
-                                 reshapeOp.getResultType().getElementType());
+    Value emptyTensor = rewriter.create<EmptyOp>(
+        loc, resultShapes[0], reshapeOp.getResultType().getElementType());
     if (emptyTensor.getType() != reshapeOp.getResultType()) {
       rewriter.replaceOpWithNewOp<tensor::CastOp>(
           reshapeOp, reshapeOp.getResultType(), emptyTensor);
index fbae33e..0ef8729 100644 (file)
@@ -116,8 +116,7 @@ tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
       dyn_cast<ReifyRankedShapedTypeOpInterface>(op.getOperation());
   if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
     return failure();
-  SmallVector<OpFoldResult> collapseShapeOutputShape =
-      getAsOpFoldResult(reifiedShapes[0]);
+  SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
   SmallVector<ReassociationIndices> reassociationIndices =
       op.getReassociationIndices();
 
index f9037fe..be375e4 100644 (file)
@@ -193,7 +193,7 @@ struct RewriteExtractSliceFromCollapseShapeBase
 
     // Create the destination tensor using the above values.
     Type elementType = op.getSourceType().getElementType();
-    SmallVector<OpFoldResult> outputShape = getAsOpFoldResult(reifiedShapes[0]);
+    SmallVector<OpFoldResult> outputShape = reifiedShapes[0];
     Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape,
                                                   elementType);
 
index d273497..5bafede 100644 (file)
@@ -1244,7 +1244,7 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
     auto currShape = llvm::to_vector<4>(llvm::map_range(
         llvm::seq<int64_t>(
             0, operand.getType().cast<RankedTensorType>().getRank()),
-        [&](int64_t dim) -> Value {
+        [&](int64_t dim) -> OpFoldResult {
           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
         }));
     shapes.emplace_back(std::move(currShape));