From b1d2687501f87d7158289a90a864ddf32b843d49 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 17 Jul 2023 16:08:08 +0200 Subject: [PATCH] [mlir][IR] Remove duplicate `isLastMemrefDimUnitStride` functions This function is duplicated in various dialects. Differential Revision: https://reviews.llvm.org/D155462 --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.h | 4 ---- mlir/include/mlir/IR/BuiltinTypes.h | 6 +++++- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 8 +++----- mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 8 -------- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 11 ----------- mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 11 ----------- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 --------- mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 4 ++-- mlir/lib/IR/BuiltinTypes.cpp | 8 +++++++- 9 files changed, 17 insertions(+), 52 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 49a2351..4a624bd 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -90,10 +90,6 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef values); Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector); -/// Return true if the last dimension of the MemRefType has unit stride. Also -/// return true for memrefs with no strides. -bool isLastMemrefDimUnitStride(MemRefType type); - /// Build the default minor identity map suitable for a vector transfer. This /// also handles the case memref<... x vector<...>> -> vector<...> in which the /// rank of the identity map must take the vector element type into account. diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index c4a3c3e..de363fc 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -534,9 +534,13 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef sizes, AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef sizes, MLIRContext *context); -/// Return true if the layout for `t` is compatible with strided semantics. +/// Return "true" if the layout for `t` is compatible with strided semantics. bool isStrided(MemRefType t); +/// Return "true" if the last dimension of the given type has a static unit +/// stride. Also return "true" for types with no strides. +bool isLastMemrefDimUnitStride(MemRefType type); + } // namespace mlir #endif // MLIR_IR_BUILTINTYPES_H diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9901385..d0c0d8f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -92,13 +92,11 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, // Check if the last stride is non-unit or the memory space is not zero. static LogicalResult isMemRefTypeSupported(MemRefType memRefType, LLVMTypeConverter &converter) { - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(memRefType, strides, offset); + if (!isLastMemrefDimUnitStride(memRefType)) + return failure(); FailureOr addressSpace = converter.getMemRefAddressSpace(memRefType); - if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) || - *addressSpace != 0) + if (failed(addressSpace) || *addressSpace != 0) return failure(); return success(); } diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 6936613..fc274c9 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1185,14 +1185,6 @@ struct Strategy1d { } }; -/// Return true if the last dimension of the MemRefType has unit stride. -static bool isLastMemrefDimUnitStride(MemRefType type) { - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); - return succeeded(successStrides) && (strides.empty() || strides.back() == 1); -} - /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is /// necessary in cases where a 1D vector transfer op cannot be lowered into /// vector load/stores due to non-unit strides or broadcasts: diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 414b54f..f809a96 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1546,17 +1546,6 @@ void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results, // GPU_SubgroupMmaLoadMatrixOp //===----------------------------------------------------------------------===// -/// Return true if the last dimension of the MemRefType has unit stride. Also -/// return true for memrefs with no strides. -static bool isLastMemrefDimUnitStride(MemRefType type) { - int64_t offset; - SmallVector strides; - if (failed(getStridesAndOffset(type, strides, offset))) { - return false; - } - return strides.back() == 1; -} - LogicalResult SubgroupMmaLoadMatrixOp::verify() { auto srcType = getSrcMemref().getType(); auto resType = getRes().getType(); diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index c3a62f4..2868660 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -53,17 +53,6 @@ bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) { // NVGPU_DeviceAsyncCopyOp //===----------------------------------------------------------------------===// -/// Return true if the last dimension of the MemRefType has unit stride. Also -/// return true for memrefs with no strides. -static bool isLastMemrefDimUnitStride(MemRefType type) { - int64_t offset; - SmallVector strides; - if (failed(getStridesAndOffset(type, strides, offset))) { - return false; - } - return strides.back() == 1; -} - LogicalResult DeviceAsyncCopyOp::verify() { auto srcMemref = llvm::cast(getSrc().getType()); auto dstMemref = llvm::cast(getDst().getType()); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index b64aec1..e4cf54c 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -130,15 +130,6 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind, return false; } -/// Return true if the last dimension of the MemRefType has unit stride. Also -/// return true for memrefs with no strides. -bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) { - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); - return succeeded(successStrides) && (strides.empty() || strides.back() == 1); -} - AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType) { int64_t elementVectorRank = 0; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index c880eb4..9589482 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -419,7 +419,7 @@ struct TransferReadToVectorLoadLowering return rewriter.notifyMatchFailure(read, "not a memref source"); // Non-unit strides are handled by VectorToSCF. - if (!vector::isLastMemrefDimUnitStride(memRefType)) + if (!isLastMemrefDimUnitStride(memRefType)) return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF"); // If there is broadcasting involved then we first load the unbroadcasted @@ -567,7 +567,7 @@ struct TransferWriteToVectorStoreLowering }); // Non-unit strides are handled by VectorToSCF. - if (!vector::isLastMemrefDimUnitStride(memRefType)) + if (!isLastMemrefDimUnitStride(memRefType)) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "most minor stride is not 1: " << write; }); diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index b5ebaa0..60cff9d 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -956,10 +956,16 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef sizes, return makeCanonicalStridedLayoutExpr(sizes, exprs, context); } -/// Return true if the layout for `t` is compatible with strided semantics. bool mlir::isStrided(MemRefType t) { int64_t offset; SmallVector strides; auto res = getStridesAndOffset(t, strides, offset); return succeeded(res); } + +bool mlir::isLastMemrefDimUnitStride(MemRefType type) { + int64_t offset; + SmallVector strides; + auto successStrides = getStridesAndOffset(type, strides, offset); + return succeeded(successStrides) && (strides.empty() || strides.back() == 1); +} -- 2.7.4