From 758329dc7cd3b0da835a4f865b89003263050080 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 10 Mar 2023 11:25:15 +0100 Subject: [PATCH] [mlir][NFC] reifyResultShapes: Add extra error checking This change adds a new helper function `mlir::reifyResultShapes` that calls the corresponding interface method and also checks the result produced by the implementation when running in debug mode. Bugs due to incorrect interface implementations can be difficult to debug. This helper function also reduces the amount of code needed at call sites: the cast to `ReifyRankedShapedTypeOpInterface` is done in the helper function. Differential Revision: https://reviews.llvm.org/D145777 --- .../include/mlir/Interfaces/InferTypeOpInterface.h | 6 +++ .../Bufferization/IR/BufferizableOpInterface.cpp | 20 +++++----- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 +- .../Transforms/ConvertToDestinationStyle.cpp | 25 ++++++------- .../Transforms/FusePadOpWithLinalgProducer.cpp | 5 +-- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 3 +- .../Transforms/ResolveShapedTypeResultDims.cpp | 25 +++---------- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 6 +-- .../Tensor/IR/TensorTilingInterfaceImpl.cpp | 9 ++--- .../Dialect/Tensor/Transforms/EmptyOpPatterns.cpp | 5 +-- .../Transforms/ExtractSliceFromReshapeUtils.cpp | 4 +- mlir/lib/Interfaces/InferTypeOpInterface.cpp | 43 ++++++++++++++++++++++ .../lib/Dialect/Tensor/TestTensorTransforms.cpp | 2 +- mlir/test/lib/Dialect/Test/TestDialect.cpp | 11 ++++-- 14 files changed, 92 insertions(+), 76 deletions(-) diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h index 42f5ec4..b63d8b6 100644 --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -28,6 +28,12 @@ namespace mlir { class ShapedTypeComponents; using ReifiedRankedShapedTypeDims = SmallVector>; +/// Reify the shape of the result of an operation (typically in terms of the +/// shape of its operands). +LogicalResult +reifyResultShapes(OpBuilder &b, Operation *op, + ReifiedRankedShapedTypeDims &reifiedReturnShapes); + /// Adaptor class to abstract the differences between whether value is from /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute. class ShapeAdaptor { diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 0f119d5..3b965cf 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -138,17 +138,15 @@ FailureOr bufferization::allocateTensorForShapedValue( bool reifiedShapes = false; if (shapedValue.getType().isa() && shapedValue.isa()) { - if (auto rankedOp = dyn_cast_or_null( - shapedValue.getDefiningOp())) { - ReifiedRankedShapedTypeDims resultDims; - if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) { - reifiedShapes = true; - auto &shape = - resultDims[shapedValue.cast().getResultNumber()]; - for (const auto &dim : enumerate(tensorType.getShape())) - if (ShapedType::isDynamic(dim.value())) - dynamicSizes.push_back(shape[dim.index()].get()); - } + ReifiedRankedShapedTypeDims resultDims; + if (succeeded( + reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) { + reifiedShapes = true; + auto &shape = + resultDims[shapedValue.cast().getResultNumber()]; + for (const auto &dim : enumerate(tensorType.getShape())) + if (ShapedType::isDynamic(dim.value())) + dynamicSizes.push_back(shape[dim.index()].get()); } } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index b33d989..f6a5879 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -482,9 +482,7 @@ struct FoldFillWithPad final : public OpRewritePattern { return failure(); ReifiedRankedShapedTypeDims reifiedShape; - ReifyRankedShapedTypeOpInterface interface = - cast(padOp.getOperation()); - if (failed(interface.reifyResultShapes(rewriter, reifiedShape))) + if (failed(reifyResultShapes(rewriter, padOp, reifiedShape))) return rewriter.notifyMatchFailure( padOp, "failed to reify tensor.pad op result shape"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 700f873..3ec5094 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -125,19 +125,17 @@ static SmallVector reifyOrComputeDynamicSizes(OpBuilder &b, return {}; // Try to reify dynamic sizes. - if (auto reifiableOp = - value.getDefiningOp()) { - ReifiedRankedShapedTypeDims reifiedShape; - if (succeeded(reifiableOp.reifyResultShapes(b, reifiedShape))) { - SmallVector dynSizes; - for (int64_t i = 0; i < tensorType.getRank(); ++i) { - if (tensorType.isDynamicDim(i)) - dynSizes.push_back( - reifiedShape[value.cast().getResultNumber()][i] - .get()); - } - return dynSizes; + ReifiedRankedShapedTypeDims reifiedShape; + if (value.isa() && + succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) { + SmallVector dynSizes; + for (int64_t i = 0; i < tensorType.getRank(); ++i) { + if (tensorType.isDynamicDim(i)) + dynSizes.push_back( + reifiedShape[value.cast().getResultNumber()][i] + .get()); } + return dynSizes; } // Create tensor.dim ops. @@ -293,8 +291,7 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, Location loc = padOp.getLoc(); RankedTensorType resultType = padOp.getResultType(); ReifiedRankedShapedTypeDims reifiedShape; - if (failed(cast(padOp.getOperation()) - .reifyResultShapes(rewriter, reifiedShape))) + if (failed(reifyResultShapes(rewriter, padOp, reifiedShape))) return rewriter.notifyMatchFailure( padOp, "failed to reify tensor.pad op result shape"); SmallVector dynamicSizes; diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp index 760b14e..b6e2ffc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -62,10 +62,7 @@ struct FusePadOp : OpRewritePattern { padOp, "only supported for ops with all parallel iterator types"); } ReifiedRankedShapedTypeDims resultShape; - ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = - dyn_cast(padOp.getOperation()); - if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter, - resultShape)) || + if (failed(reifyResultShapes(rewriter, padOp, resultShape)) || resultShape.size() != 1) { return rewriter.notifyMatchFailure( padOp, "failed to get shape of pad op result"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 9de0f76..2ba1562 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -205,8 +205,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, } ReifiedRankedShapedTypeDims reifiedResultShapes; - if (failed(cast(opToPad.getOperation()) - .reifyResultShapes(rewriter, reifiedResultShapes))) { + if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) { LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n"); return rewriter.notifyMatchFailure(opToPad, "failed to reify result shapes"); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 2a18c55..50ac04d 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -84,33 +84,18 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern { OpResult dimValue = dimOp.getSource().template dyn_cast(); if (!dimValue) return failure(); - auto rankedShapeTypeOp = - dyn_cast(dimValue.getOwner()); - if (!rankedShapeTypeOp) - return failure(); - std::optional dimIndex = dimOp.getConstantIndex(); if (!dimIndex) return failure(); ReifiedRankedShapedTypeDims reifiedResultShapes; - if (failed( - rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes))) - return failure(); - - if (reifiedResultShapes.size() != rankedShapeTypeOp->getNumResults()) + if (failed(reifyResultShapes(rewriter, dimValue.getOwner(), + reifiedResultShapes))) return failure(); - unsigned resultNumber = dimValue.getResultNumber(); - auto sourceType = dimValue.getType().dyn_cast(); - if (reifiedResultShapes[resultNumber].size() != - static_cast(sourceType.getRank())) - return failure(); - - rewriter.replaceOp(dimOp, - getValueOrCreateConstantIndexOp( - rewriter, dimOp.getLoc(), - reifiedResultShapes[resultNumber][*dimIndex])); + Value replacement = getValueOrCreateConstantIndexOp( + rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]); + rewriter.replaceOp(dimOp, replacement); return success(); } }; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index e1bf889..5755ddf 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -81,11 +81,7 @@ FailureOr tensor::getOrCreateDestination(OpBuilder &b, Location loc, if (!tensorType.hasStaticShape()) { // Dynamic shape: Query ReifyRankedShapedTypeOpInterface. ReifiedRankedShapedTypeDims reifiedShapes; - ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = - dyn_cast(opResult.getDefiningOp()); - if (!reifyShapedTypeInterface) - return failure(); - if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes))) + if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes))) return failure(); mixedSizes = reifiedShapes[opResult.getResultNumber()]; } else { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 457f261..ecebf21 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -34,10 +34,7 @@ struct PadOpTiling : public TilingInterface::ExternalModel { SmallVector getIterationDomain(Operation *op, OpBuilder &b) const { ReifiedRankedShapedTypeDims reifiedShapes; - ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = - dyn_cast(op); - (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes); - + (void)reifyResultShapes(b, op, reifiedShapes); Location loc = op->getLoc(); Value zero = b.create(loc, 0); Value one = b.create(loc, 1); @@ -84,7 +81,7 @@ static SmallVector getPackUnPackIterationDomain(OpTy op, Value zero = builder.create(loc, 0); Value one = builder.create(loc, 1); ReifiedRankedShapedTypeDims resultShape; - (void)op.reifyResultShapes(builder, resultShape); + (void)reifyResultShapes(builder, op, resultShape); SmallVector loopBounds(rank); for (auto dim : llvm::seq(0, rank)) { loopBounds[dim].offset = zero; @@ -216,7 +213,7 @@ struct PackOpTiling resultOffsets.append(outputRank - inputRank, zeroAttr); ReifiedRankedShapedTypeDims outputShape; - (void)packOp.reifyResultShapes(b, outputShape); + (void)reifyResultShapes(b, packOp, outputShape); resultSizes.assign(sizes.begin(), sizes.end()); for (auto dataTileDim : llvm::seq(inputRank, outputRank)) resultSizes.push_back(outputShape[0][dataTileDim]); diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp index f9512fd..99679bc 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp @@ -26,10 +26,7 @@ struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern { return failure(); Location loc = reshapeOp.getLoc(); ReifiedRankedShapedTypeDims resultShapes; - ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = - cast(reshapeOp.getOperation()); - if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter, - resultShapes)) || + if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) || !llvm::hasSingleElement(resultShapes)) return failure(); // TODO: Do not drop tensor type encoding. diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp index 0ef8729..f1ad357 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp @@ -112,9 +112,7 @@ tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b, // Materialize the output shape of the collapse_shape operation. This will // create IR describing the output shape in terms of the input shape. ReifiedRankedShapedTypeDims reifiedShapes; - ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = - dyn_cast(op.getOperation()); - if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes))) + if (failed(reifyResultShapes(b, op, reifiedShapes))) return failure(); SmallVector &collapseShapeOutputShape = reifiedShapes[0]; SmallVector reassociationIndices = diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp index b76f236..7d464af 100644 --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -22,6 +22,49 @@ namespace mlir { #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc" } // namespace mlir +LogicalResult +mlir::reifyResultShapes(OpBuilder &b, Operation *op, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + auto reifiableOp = dyn_cast(op); + if (!reifiableOp) + return failure(); + LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes); +#ifndef NDEBUG + if (failed(status)) + return failure(); + // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced + // a correct result. + int64_t resultIdx = 0; + for (OpResult result : op->getResults()) { + auto shapedType = result.getType().dyn_cast(); + if (!shapedType) + continue; + if (!shapedType.hasRank()) { + // Nothing to check for unranked shaped values. + ++resultIdx; + continue; + } + // Assert one OpFoldResult per dimension. + assert(shapedType.getRank() == + static_cast(reifiedReturnShapes[resultIdx].size()) && + "incorrect implementation of ReifyRankedShapedTypeOpInterface"); + for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) { + // reifyResultShapes must return: + // * Attribute for static dimensions + // * Value for dynamic dimensions + assert(shapedType.isDynamicDim(dim) == + reifiedReturnShapes[resultIdx][dim].is() && + "incorrect implementation of ReifyRankedShapedTypeOpInterface"); + } + ++resultIdx; + } + // Assert that every shaped value result was reified. + assert(resultIdx == static_cast(reifiedReturnShapes.size()) && + "incorrect implementation of ReifyRankedShapedTypeOpInterface"); +#endif // NDEBUG + return status; +} + bool ShapeAdaptor::hasRank() const { if (val.isNull()) return false; diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index be375e4..3898892 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -188,7 +188,7 @@ struct RewriteExtractSliceFromCollapseShapeBase // Materialize the output shape values of the slice operation. ReifiedRankedShapedTypeDims reifiedShapes; - if (failed(op.reifyResultShapes(rewriter, reifiedShapes))) + if (failed(reifyResultShapes(rewriter, op, reifiedShapes))) return rewriter.notifyMatchFailure(op, "failed to reify result shapes"); // Create the destination tensor using the above values. diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 5bafede..09f2fcc 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1241,11 +1241,16 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( Location loc = getLoc(); shapes.reserve(getNumOperands()); for (Value operand : llvm::reverse(getOperands())) { + auto tensorType = operand.getType().cast(); auto currShape = llvm::to_vector<4>(llvm::map_range( - llvm::seq( - 0, operand.getType().cast().getRank()), + llvm::seq(0, tensorType.getRank()), [&](int64_t dim) -> OpFoldResult { - return builder.createOrFold(loc, operand, dim); + return tensorType.isDynamicDim(dim) + ? static_cast( + builder.createOrFold(loc, operand, + dim)) + : static_cast( + builder.getIndexAttr(tensorType.getDimSize(dim))); })); shapes.emplace_back(std::move(currShape)); } -- 2.7.4