From ef976337f581dd8a80820a8b14b4bbd70670b7fc Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 24 Sep 2021 17:51:20 +0000 Subject: [PATCH] [mlir:OpConversion] Remove the remaing usages of the deprecated matchAndRewrite methods This commits updates the remaining usages of the ArrayRef based matchAndRewrite/rewrite methods in favor of the new OpAdaptor overload. Differential Revision: https://reviews.llvm.org/D110360 --- mlir/include/mlir/Conversion/LLVMCommon/Pattern.h | 6 +- .../mlir/Conversion/LLVMCommon/VectorPattern.h | 6 +- .../lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp | 48 ++++---- mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 4 +- mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h | 6 +- .../Conversion/GPUCommon/GPUToLLVMConversion.cpp | 79 ++++++------- .../GPUCommon/IndexIntrinsicsOpLowering.h | 2 +- .../Conversion/GPUCommon/OpToFuncCallLowering.h | 16 ++- .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 3 +- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 33 +++--- mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 7 +- mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 27 ++--- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 126 +++++++++------------ .../lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp | 6 +- mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp | 6 +- .../SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp | 4 +- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 110 ++++++++---------- .../lib/Conversion/VectorToROCDL/VectorToROCDL.cpp | 12 +- mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp | 11 +- .../StandardToLLVM/TestConvertCallOp.cpp | 2 +- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 15 +-- 21 files changed, 233 insertions(+), 296 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 81358dc..4ffd135 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -217,11 +217,11 @@ public: /// Converts the type of the result to an LLVM type, pass operands as is, /// preserve attributes. LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), - operands, *this->getTypeConverter(), - rewriter); + adaptor.getOperands(), + *this->getTypeConverter(), rewriter); } }; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index 383516a..7eba83c 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -70,14 +70,14 @@ public: using Super = VectorConvertToLLVMPattern; LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { static_assert( std::is_base_of, SourceOp>::value, "expected single result op"); return LLVM::detail::vectorOneToOneRewrite( - op, TargetOp::getOperationName(), operands, *this->getTypeConverter(), - rewriter); + op, TargetOp::getOperationName(), adaptor.getOperands(), + *this->getTypeConverter(), rewriter); } }; } // namespace mlir diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index a2ba4c5..129b9e6 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -58,12 +58,11 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(complex::AbsOp op, ArrayRef operands, + matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::AbsOp::Adaptor transformed(operands); auto loc = op.getLoc(); - ComplexStructBuilder complexStruct(transformed.complex()); + ComplexStructBuilder complexStruct(adaptor.complex()); Value real = complexStruct.real(rewriter, op.getLoc()); Value imag = complexStruct.imaginary(rewriter, op.getLoc()); @@ -81,16 +80,14 @@ struct CreateOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(complex::CreateOp complexOp, ArrayRef operands, + matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::CreateOp::Adaptor transformed(operands); - // Pack real and imaginary part in a complex number struct. auto loc = complexOp.getLoc(); auto structType = typeConverter->convertType(complexOp.getType()); auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); - complexStruct.setReal(rewriter, loc, transformed.real()); - complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); + complexStruct.setReal(rewriter, loc, adaptor.real()); + complexStruct.setImaginary(rewriter, loc, adaptor.imaginary()); rewriter.replaceOp(complexOp, {complexStruct}); return success(); @@ -101,12 +98,10 @@ struct ReOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(complex::ReOp op, ArrayRef operands, + matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::ReOp::Adaptor transformed(operands); - // Extract real part from the complex number struct. - ComplexStructBuilder complexStruct(transformed.complex()); + ComplexStructBuilder complexStruct(adaptor.complex()); Value real = complexStruct.real(rewriter, op.getLoc()); rewriter.replaceOp(op, real); @@ -118,12 +113,10 @@ struct ImOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(complex::ImOp op, ArrayRef operands, + matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::ImOp::Adaptor transformed(operands); - // Extract imaginary part from the complex number struct. - ComplexStructBuilder complexStruct(transformed.complex()); + ComplexStructBuilder complexStruct(adaptor.complex()); Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); rewriter.replaceOp(op, imaginary); @@ -138,17 +131,16 @@ struct BinaryComplexOperands { template BinaryComplexOperands -unpackBinaryComplexOperands(OpTy op, ArrayRef operands, +unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) { auto loc = op.getLoc(); - typename OpTy::Adaptor transformed(operands); // Extract real and imaginary values from operands. BinaryComplexOperands unpacked; - ComplexStructBuilder lhs(transformed.lhs()); + ComplexStructBuilder lhs(adaptor.lhs()); unpacked.lhs.real(lhs.real(rewriter, loc)); unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); - ComplexStructBuilder rhs(transformed.rhs()); + ComplexStructBuilder rhs(adaptor.rhs()); unpacked.rhs.real(rhs.real(rewriter, loc)); unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); @@ -159,11 +151,11 @@ struct AddOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(complex::AddOp op, ArrayRef operands, + matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); BinaryComplexOperands arg = - unpackBinaryComplexOperands(op, operands, rewriter); + unpackBinaryComplexOperands(op, adaptor, rewriter); // Initialize complex number struct for result. auto structType = typeConverter->convertType(op.getType()); @@ -187,11 +179,11 @@ struct DivOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(complex::DivOp op, ArrayRef operands, + matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); BinaryComplexOperands arg = - unpackBinaryComplexOperands(op, operands, rewriter); + unpackBinaryComplexOperands(op, adaptor, rewriter); // Initialize complex number struct for result. auto structType = typeConverter->convertType(op.getType()); @@ -232,11 +224,11 @@ struct MulOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(complex::MulOp op, ArrayRef operands, + matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); BinaryComplexOperands arg = - unpackBinaryComplexOperands(op, operands, rewriter); + unpackBinaryComplexOperands(op, adaptor, rewriter); // Initialize complex number struct for result. auto structType = typeConverter->convertType(op.getType()); @@ -269,11 +261,11 @@ struct SubOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(complex::SubOp op, ArrayRef operands, + matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); BinaryComplexOperands arg = - unpackBinaryComplexOperands(op, operands, rewriter); + unpackBinaryComplexOperands(op, adaptor, rewriter); // Initialize complex number struct for result. auto structType = typeConverter->convertType(op.getType()); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 98aa0e0..c8fc7b2 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -14,10 +14,8 @@ using namespace mlir; LogicalResult -GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, - ArrayRef operands, +GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - assert(operands.empty() && "func op is not expected to have operands"); Location loc = gpuFuncOp.getLoc(); SmallVector workgroupBuffers; diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h index 9d54001d..72805dc 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -22,7 +22,7 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern { kernelAttributeName(kernelAttributeName) {} LogicalResult - matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, ArrayRef operands, + matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; private: @@ -37,9 +37,9 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(gpu::ReturnOp op, ArrayRef operands, + matchAndRewrite(gpu::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 0b33ab4..40a4463 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -195,7 +195,7 @@ public: private: LogicalResult - matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef operands, + matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -209,7 +209,7 @@ public: private: LogicalResult - matchAndRewrite(gpu::AllocOp allocOp, ArrayRef operands, + matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -223,7 +223,7 @@ public: private: LogicalResult - matchAndRewrite(gpu::DeallocOp deallocOp, ArrayRef operands, + matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -235,7 +235,7 @@ public: private: LogicalResult - matchAndRewrite(async::YieldOp yieldOp, ArrayRef operands, + matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -249,7 +249,7 @@ public: private: LogicalResult - matchAndRewrite(gpu::WaitOp waitOp, ArrayRef operands, + matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -263,7 +263,7 @@ public: private: LogicalResult - matchAndRewrite(gpu::WaitOp waitOp, ArrayRef operands, + matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -289,13 +289,13 @@ public: gpuBinaryAnnotation(gpuBinaryAnnotation) {} private: - Value generateParamsArray(gpu::LaunchFuncOp launchOp, - ArrayRef operands, OpBuilder &builder) const; + Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, + OpBuilder &builder) const; Value generateKernelNameConstant(StringRef moduleName, StringRef name, Location loc, OpBuilder &builder) const; LogicalResult - matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef operands, + matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; llvm::SmallString<32> gpuBinaryAnnotation; @@ -323,7 +323,7 @@ public: private: LogicalResult - matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef operands, + matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -337,7 +337,7 @@ public: private: LogicalResult - matchAndRewrite(gpu::MemsetOp memsetOp, ArrayRef operands, + matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace @@ -398,10 +398,10 @@ isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, } LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::HostRegisterOp hostRegisterOp, ArrayRef operands, + gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto *op = hostRegisterOp.getOperation(); - if (failed(areAllLLVMTypes(op, operands, rewriter))) + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) return failure(); Location loc = op->getLoc(); @@ -410,8 +410,8 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( auto elementType = memRefType.cast().getElementType(); auto elementSize = getSizeInBytes(loc, elementType, rewriter); - auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(), - operands, rewriter); + auto arguments = getTypeConverter()->promoteOperands( + loc, op->getOperands(), adaptor.getOperands(), rewriter); arguments.push_back(elementSize); hostRegisterCallBuilder.create(loc, rewriter, arguments); @@ -420,17 +420,16 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( } LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::AllocOp allocOp, ArrayRef operands, + gpu::AllocOp allocOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MemRefType memRefType = allocOp.getType(); - if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) || + if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) || !isConvertibleAndHasIdentityMaps(memRefType) || failed(isAsyncWithOneDependency(rewriter, allocOp))) return failure(); auto loc = allocOp.getLoc(); - auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary()); // Get shape of the memref as values: static sizes are constant // values and dynamic sizes are passed to 'alloc' as operands. @@ -462,16 +461,14 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( } LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::DeallocOp deallocOp, ArrayRef operands, + gpu::DeallocOp deallocOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(deallocOp, operands, rewriter)) || + if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) || failed(isAsyncWithOneDependency(rewriter, deallocOp))) return failure(); Location loc = deallocOp.getLoc(); - auto adaptor = - gpu::DeallocOpAdaptor(operands, deallocOp->getAttrDictionary()); Value pointer = MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc); auto casted = rewriter.create(loc, llvmPointerType, pointer); @@ -491,19 +488,19 @@ static bool isGpuAsyncTokenType(Value value) { // are passed as events between them. For each !gpu.async.token operand, we // create an event and record it on the stream. LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite( - async::YieldOp yieldOp, ArrayRef operands, + async::YieldOp yieldOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType)) return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand"); Location loc = yieldOp.getLoc(); - SmallVector newOperands(operands.begin(), operands.end()); + SmallVector newOperands(adaptor.getOperands()); llvm::SmallDenseSet streams; for (auto &operand : yieldOp->getOpOperands()) { if (!isGpuAsyncTokenType(operand.get())) continue; auto idx = operand.getOperandNumber(); - auto stream = operands[idx]; + auto stream = adaptor.getOperands()[idx]; auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0); eventRecordCallBuilder.create(loc, rewriter, {event, stream}); newOperands[idx] = event; @@ -530,14 +527,14 @@ static bool isDefinedByCallTo(Value value, StringRef functionName) { // assumes that it is not used afterwards or elsewhere. Otherwise we will get a // runtime error. Eventually, we should guarantee this property. LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::WaitOp waitOp, ArrayRef operands, + gpu::WaitOp waitOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (waitOp.asyncToken()) return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op."); Location loc = waitOp.getLoc(); - for (auto operand : operands) { + for (auto operand : adaptor.getOperands()) { if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { // The converted operand's definition created a stream. streamSynchronizeCallBuilder.create(loc, rewriter, {operand}); @@ -560,7 +557,7 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite( // Otherwise we will get a runtime error. Eventually, we should guarantee this // property. LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::WaitOp waitOp, ArrayRef operands, + gpu::WaitOp waitOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (!waitOp.asyncToken()) return rewriter.notifyMatchFailure(waitOp, "Can only convert async op."); @@ -569,7 +566,8 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( auto insertionPoint = rewriter.saveInsertionPoint(); SmallVector events; - for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) { + for (auto pair : + llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) { auto operand = std::get<1>(pair); if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { // The converted operand's definition created a stream. Insert an event @@ -611,13 +609,12 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( // llvm.store %fieldPtr, %elementPtr // return %array Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray( - gpu::LaunchFuncOp launchOp, ArrayRef operands, - OpBuilder &builder) const { + gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const { auto loc = launchOp.getLoc(); auto numKernelOperands = launchOp.getNumKernelOperands(); auto arguments = getTypeConverter()->promoteOperands( loc, launchOp.getOperands().take_back(numKernelOperands), - operands.take_back(numKernelOperands), builder); + adaptor.getOperands().take_back(numKernelOperands), builder); auto numArguments = arguments.size(); SmallVector argumentTypes; argumentTypes.reserve(numArguments); @@ -693,9 +690,9 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant( // If the op is async, the stream corresponds to the (single) async dependency // as well as the async token the op produces. LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::LaunchFuncOp launchOp, ArrayRef operands, + gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(launchOp, operands, rewriter))) + if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter))) return failure(); if (launchOp.asyncDependencies().size() > 1) @@ -741,14 +738,12 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( loc, rewriter, {module.getResult(0), kernelName}); auto zero = rewriter.create(loc, llvmInt32Type, rewriter.getI32IntegerAttr(0)); - auto adaptor = - gpu::LaunchFuncOpAdaptor(operands, launchOp->getAttrDictionary()); Value stream = adaptor.asyncDependencies().empty() ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0) : adaptor.asyncDependencies().front(); // Create array of pointers to kernel arguments. - auto kernelParams = generateParamsArray(launchOp, operands, rewriter); + auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter); auto nullpointer = rewriter.create(loc, llvmPointerPointerType); launchKernelCallBuilder.create(loc, rewriter, {function.getResult(0), adaptor.gridSizeX(), @@ -775,17 +770,16 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( } LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::MemcpyOp memcpyOp, ArrayRef operands, + gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto memRefType = memcpyOp.src().getType().cast(); - if (failed(areAllLLVMTypes(memcpyOp, operands, rewriter)) || + if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) || !isConvertibleAndHasIdentityMaps(memRefType) || failed(isAsyncWithOneDependency(rewriter, memcpyOp))) return failure(); auto loc = memcpyOp.getLoc(); - auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary()); MemRefDescriptor srcDesc(adaptor.src()); Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); @@ -812,17 +806,16 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( } LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::MemsetOp memsetOp, ArrayRef operands, + gpu::MemsetOp memsetOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto memRefType = memsetOp.dst().getType().cast(); - if (failed(areAllLLVMTypes(memsetOp, operands, rewriter)) || + if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) || !isConvertibleAndHasIdentityMaps(memRefType) || failed(isAsyncWithOneDependency(rewriter, memsetOp))) return failure(); auto loc = memsetOp.getLoc(); - auto adaptor = gpu::MemsetOpAdaptor(operands, memsetOp->getAttrDictionary()); Type valueType = adaptor.value().getType(); if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) { diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index 1f80122..416964d 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -41,7 +41,7 @@ public: // Convert the kernel arguments to an LLVM type, preserve the rest. LogicalResult - matchAndRewrite(Op op, ArrayRef operands, + matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); MLIRContext *context = rewriter.getContext(); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index b8781fc..2c1c0e1 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -37,7 +37,7 @@ public: f64Func(f64Func) {} LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { using LLVM::LLVMFuncOp; @@ -50,7 +50,7 @@ public: "expected op with same operand and result types"); SmallVector castedOperands; - for (Value operand : operands) + for (Value operand : adaptor.getOperands()) castedOperands.push_back(maybeCast(operand, rewriter)); Type resultType = castedOperands.front().getType(); @@ -64,13 +64,14 @@ public: auto callOp = rewriter.create( op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands); - if (resultType == operands.front().getType()) { + if (resultType == adaptor.getOperands().front().getType()) { rewriter.replaceOp(op, {callOp.getResult(0)}); return success(); } Value truncated = rewriter.create( - op->getLoc(), operands.front().getType(), callOp.getResult(0)); + op->getLoc(), adaptor.getOperands().front().getType(), + callOp.getResult(0)); rewriter.replaceOp(op, {truncated}); return success(); } @@ -85,11 +86,8 @@ private: operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); } - Type getFunctionType(Type resultType, ArrayRef operands) const { - SmallVector operandTypes; - for (Value operand : operands) { - operandTypes.push_back(operand.getType()); - } + Type getFunctionType(Type resultType, ValueRange operands) const { + SmallVector operandTypes(operands.getTypes()); return LLVM::LLVMFunctionType::get(resultType, operandTypes); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 4c8ad15..69a9fea 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -57,10 +57,9 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { /// %shfl_pred = llvm.extractvalue %shfl[1 : index] : /// !llvm<"{ float, i1 }"> LogicalResult - matchAndRewrite(gpu::ShuffleOp op, ArrayRef operands, + matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - gpu::ShuffleOpAdaptor adaptor(operands); auto valueTy = adaptor.value().getType(); auto int32Type = IntegerType::get(rewriter.getContext(), 32); diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 3a86e2e..0296390 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -69,10 +69,10 @@ struct WmmaLoadOpToNVVMLowering LogicalResult matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp, - ArrayRef operands, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Operation *op = subgroupMmaLoadMatrixOp.getOperation(); - if (failed(areAllLLVMTypes(op, operands, rewriter))) + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) return failure(); unsigned indexTypeBitwidth = @@ -88,7 +88,6 @@ struct WmmaLoadOpToNVVMLowering auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr(); - gpu::SubgroupMmaLoadMatrixOpAdaptor adaptor(operands); // MemRefDescriptor to extract alignedPtr and offset. MemRefDescriptor promotedSrcOp(adaptor.srcMemref()); @@ -177,10 +176,10 @@ struct WmmaStoreOpToNVVMLowering LogicalResult matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp, - ArrayRef operands, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Operation *op = subgroupMmaStoreMatrixOp.getOperation(); - if (failed(areAllLLVMTypes(op, operands, rewriter))) + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) return failure(); unsigned indexTypeBitwidth = @@ -194,7 +193,6 @@ struct WmmaStoreOpToNVVMLowering Location loc = op->getLoc(); - gpu::SubgroupMmaStoreMatrixOpAdaptor adaptor(operands); // MemRefDescriptor to extract alignedPtr and offset. MemRefDescriptor promotedDstOp(adaptor.dstMemref()); @@ -282,10 +280,10 @@ struct WmmaMmaOpToNVVMLowering LogicalResult matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp, - ArrayRef operands, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Operation *op = subgroupMmaComputeOp.getOperation(); - if (failed(areAllLLVMTypes(op, operands, rewriter))) + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) return failure(); Location loc = op->getLoc(); @@ -317,17 +315,16 @@ struct WmmaMmaOpToNVVMLowering subgroupMmaComputeOp.opC().getType().cast(); ArrayRef cTypeShape = cType.getShape(); - gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands); - unpackOp(transformedOperands.opA()); - unpackOp(transformedOperands.opB()); - unpackOp(transformedOperands.opC()); + unpackOp(adaptor.opA()); + unpackOp(adaptor.opB()); + unpackOp(adaptor.opC()); if (cType.getElementType().isF16()) { if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 && bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) { // Create nvvm.wmma.mma op. rewriter.replaceOpWithNewOp( - op, transformedOperands.opC().getType(), unpackedOps); + op, adaptor.opC().getType(), unpackedOps); return success(); } @@ -338,7 +335,7 @@ struct WmmaMmaOpToNVVMLowering bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) { // Create nvvm.wmma.mma op. rewriter.replaceOpWithNewOp( - op, transformedOperands.opC().getType(), unpackedOps); + op, adaptor.opC().getType(), unpackedOps); return success(); } @@ -356,13 +353,13 @@ struct WmmaConstantOpToNVVMLowering LogicalResult matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp, - ArrayRef operands, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(), operands, - rewriter))) + if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(), + adaptor.getOperands(), rewriter))) return failure(); Location loc = subgroupMmaConstantOp.getLoc(); - Value cst = operands[0]; + Value cst = adaptor.getOperands()[0]; LLVM::LLVMStructType type = convertMMAToLLVMType( subgroupMmaConstantOp.getType().cast()); // If the element type is a vector create a vector from the operand. diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 205b2bf..7138904 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -73,7 +73,7 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(RangeOp rangeOp, ArrayRef operands, + matchAndRewrite(RangeOp rangeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto rangeDescriptorTy = convertRangeType( rangeOp.getType().cast(), *getTypeConverter()); @@ -81,7 +81,6 @@ public: ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter); // Fill in an aggregate value of the descriptor. - RangeOpAdaptor adaptor(operands); Value desc = b.create(rangeDescriptorTy); desc = b.create(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); @@ -101,9 +100,9 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(linalg::YieldOp op, ArrayRef operands, + matchAndRewrite(linalg::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 07364c0..3c476f2 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -34,10 +34,9 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(math::ExpM1Op op, ArrayRef operands, + matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - math::ExpM1Op::Adaptor transformed(operands); - auto operandType = transformed.operand().getType(); + auto operandType = adaptor.operand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return failure(); @@ -56,7 +55,7 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern { } else { one = rewriter.create(loc, operandType, floatOne); } - auto exp = rewriter.create(loc, transformed.operand()); + auto exp = rewriter.create(loc, adaptor.operand()); rewriter.replaceOpWithNewOp(op, operandType, exp, one); return success(); } @@ -66,7 +65,7 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( - op.getOperation(), operands, *getTypeConverter(), + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( @@ -88,10 +87,9 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(math::Log1pOp op, ArrayRef operands, + matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - math::Log1pOp::Adaptor transformed(operands); - auto operandType = transformed.operand().getType(); + auto operandType = adaptor.operand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return rewriter.notifyMatchFailure(op, "unsupported operand type"); @@ -111,7 +109,7 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern { : rewriter.create(loc, operandType, floatOne); auto add = rewriter.create(loc, operandType, one, - transformed.operand()); + adaptor.operand()); rewriter.replaceOpWithNewOp(op, operandType, add); return success(); } @@ -121,7 +119,7 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( - op.getOperation(), operands, *getTypeConverter(), + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( @@ -143,10 +141,9 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(math::RsqrtOp op, ArrayRef operands, + matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - math::RsqrtOp::Adaptor transformed(operands); - auto operandType = transformed.operand().getType(); + auto operandType = adaptor.operand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return failure(); @@ -165,7 +162,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { } else { one = rewriter.create(loc, operandType, floatOne); } - auto sqrt = rewriter.create(loc, transformed.operand()); + auto sqrt = rewriter.create(loc, adaptor.operand()); rewriter.replaceOpWithNewOp(op, operandType, one, sqrt); return success(); } @@ -175,7 +172,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { return failure(); return LLVM::detail::handleMultidimensionalVectors( - op.getOperation(), operands, *getTypeConverter(), + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 85c05f0..e43be68 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -194,7 +194,7 @@ struct AllocaScopeOpLowering using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, ArrayRef operands, + matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { OpBuilder::InsertionGuard guard(rewriter); Location loc = allocaScopeOp.getLoc(); @@ -249,10 +249,9 @@ struct AssumeAlignmentOpLowering memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::AssumeAlignmentOp op, ArrayRef operands, + matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - memref::AssumeAlignmentOp::Adaptor transformed(operands); - Value memref = transformed.memref(); + Value memref = adaptor.memref(); unsigned alignment = op.alignment(); auto loc = op.getLoc(); @@ -293,14 +292,11 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(converter) {} LogicalResult - matchAndRewrite(memref::DeallocOp op, ArrayRef operands, + matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() == 1 && "dealloc takes one operand"); - memref::DeallocOp::Adaptor transformed(operands); - // Insert the `free` declaration if it is not already present. auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType()); - MemRefDescriptor memref(transformed.memref()); + MemRefDescriptor memref(adaptor.memref()); Value casted = rewriter.create( op.getLoc(), getVoidPtrType(), memref.allocatedPtr(rewriter, op.getLoc())); @@ -316,18 +312,20 @@ struct DimOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::DimOp dimOp, ArrayRef operands, + matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type operandType = dimOp.source().getType(); if (operandType.isa()) { - rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef( - operandType, dimOp, operands, rewriter)}); + rewriter.replaceOp( + dimOp, {extractSizeOfUnrankedMemRef( + operandType, dimOp, adaptor.getOperands(), rewriter)}); return success(); } if (operandType.isa()) { - rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef( - operandType, dimOp, operands, rewriter)}); + rewriter.replaceOp( + dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, + adaptor.getOperands(), rewriter)}); return success(); } llvm_unreachable("expected MemRefType or UnrankedMemRefType"); @@ -335,10 +333,9 @@ struct DimOpLowering : public ConvertOpToLLVMPattern { private: Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp, - ArrayRef operands, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = dimOp.getLoc(); - memref::DimOp::Adaptor transformed(operands); auto unrankedMemRefType = operandType.cast(); auto scalarMemRefType = @@ -348,7 +345,7 @@ private: // Extract pointer to the underlying ranked descriptor and bitcast it to a // memref descriptor pointer to minimize the number of GEP // operations. - UnrankedMemRefDescriptor unrankedDesc(transformed.source()); + UnrankedMemRefDescriptor unrankedDesc(adaptor.source()); Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); Value scalarMemRefDescPtr = rewriter.create( loc, @@ -369,7 +366,7 @@ private: // The size value that we have to extract can be obtained using GEPop with // `dimOp.index() + 1` index argument. Value idxPlusOne = rewriter.create( - loc, createIndexConstant(rewriter, loc, 1), transformed.index()); + loc, createIndexConstant(rewriter, loc, 1), adaptor.index()); Value sizePtr = rewriter.create(loc, indexPtrTy, offsetPtr, ValueRange({idxPlusOne})); return rewriter.create(loc, sizePtr); @@ -386,26 +383,26 @@ private: } Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp, - ArrayRef operands, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = dimOp.getLoc(); - memref::DimOp::Adaptor transformed(operands); + // Take advantage if index is constant. MemRefType memRefType = operandType.cast(); if (Optional index = getConstantDimIndex(dimOp)) { int64_t i = index.getValue(); if (memRefType.isDynamicDim(i)) { // extract dynamic size from the memref descriptor. - MemRefDescriptor descriptor(transformed.source()); + MemRefDescriptor descriptor(adaptor.source()); return descriptor.size(rewriter, loc, i); } // Use constant for static size. int64_t dimSize = memRefType.getDimSize(i); return createIndexConstant(rewriter, loc, dimSize); } - Value index = transformed.index(); + Value index = adaptor.index(); int64_t rank = memRefType.getRank(); - MemRefDescriptor memrefDescriptor(transformed.source()); + MemRefDescriptor memrefDescriptor(adaptor.source()); return memrefDescriptor.size(rewriter, loc, index, rank); } }; @@ -432,7 +429,7 @@ struct GlobalMemrefOpLowering using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::GlobalOp global, ArrayRef operands, + matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MemRefType type = global.type(); if (!isConvertibleAndHasIdentityMaps(type)) @@ -536,14 +533,12 @@ struct LoadOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult - matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, + matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - memref::LoadOp::Adaptor transformed(operands); auto type = loadOp.getMemRefType(); - Value dataPtr = - getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(), - transformed.indices(), rewriter); + Value dataPtr = getStridedElementPtr( + loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp(loadOp, dataPtr); return success(); } @@ -555,16 +550,13 @@ struct StoreOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult - matchAndRewrite(memref::StoreOp op, ArrayRef operands, + matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = op.getMemRefType(); - memref::StoreOp::Adaptor transformed(operands); - Value dataPtr = - getStridedElementPtr(op.getLoc(), type, transformed.memref(), - transformed.indices(), rewriter); - rewriter.replaceOpWithNewOp(op, transformed.value(), - dataPtr); + Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(), + adaptor.indices(), rewriter); + rewriter.replaceOpWithNewOp(op, adaptor.value(), dataPtr); return success(); } }; @@ -575,14 +567,13 @@ struct PrefetchOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult - matchAndRewrite(memref::PrefetchOp prefetchOp, ArrayRef operands, + matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - memref::PrefetchOp::Adaptor transformed(operands); auto type = prefetchOp.getMemRefType(); auto loc = prefetchOp.getLoc(); - Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(), - transformed.indices(), rewriter); + Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(), + adaptor.indices(), rewriter); // Replace with llvm.prefetch. auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32)); @@ -627,10 +618,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { : failure(); } - void rewrite(memref::CastOp memRefCastOp, ArrayRef operands, + void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - memref::CastOp::Adaptor transformed(operands); - auto srcType = memRefCastOp.getOperand().getType(); auto dstType = memRefCastOp.getType(); auto targetStructType = typeConverter->convertType(memRefCastOp.getType()); @@ -638,7 +627,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { // For ranked/ranked case, just keep the original descriptor. if (srcType.isa() && dstType.isa()) - return rewriter.replaceOp(memRefCastOp, {transformed.source()}); + return rewriter.replaceOp(memRefCastOp, {adaptor.source()}); if (srcType.isa() && dstType.isa()) { // Casting ranked to unranked memref type @@ -649,7 +638,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { int64_t rank = srcMemRefType.getRank(); // ptr = AllocaOp sizeof(MemRefDescriptor) auto ptr = getTypeConverter()->promoteOneMemRefDescriptor( - loc, transformed.source(), rewriter); + loc, adaptor.source(), rewriter); // voidptr = BitCastOp srcType* to void* auto voidPtr = rewriter.create(loc, getVoidPtrType(), ptr) @@ -671,7 +660,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { // Casting from unranked type to ranked. // The operation is assumed to be doing a correct cast. If the destination // type mismatches the unranked the type, it is undefined behavior. - UnrankedMemRefDescriptor memRefDesc(transformed.source()); + UnrankedMemRefDescriptor memRefDesc(adaptor.source()); // ptr = ExtractValueOp src, 1 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); // castPtr = BitCastOp i8* to structTy* @@ -693,10 +682,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::CopyOp op, ArrayRef operands, + matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - memref::CopyOp::Adaptor adaptor(operands); auto srcType = op.source().getType().cast(); auto targetType = op.target().getType().cast(); @@ -799,10 +787,8 @@ struct MemRefReinterpretCastOpLowering memref::ReinterpretCastOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::ReinterpretCastOp castOp, ArrayRef operands, + matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - memref::ReinterpretCastOp::Adaptor adaptor(operands, - castOp->getAttrDictionary()); Type srcType = castOp.source().getType(); Value descriptor; @@ -867,17 +853,15 @@ struct MemRefReshapeOpLowering using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::ReshapeOp reshapeOp, ArrayRef operands, + matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto *op = reshapeOp.getOperation(); - memref::ReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary()); Type srcType = reshapeOp.source().getType(); Value descriptor; if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp, adaptor, &descriptor))) return failure(); - rewriter.replaceOp(op, {descriptor}); + rewriter.replaceOp(reshapeOp, {descriptor}); return success(); } @@ -1152,7 +1136,7 @@ public: using ReshapeOpAdaptor = typename ReshapeOp::Adaptor; LogicalResult - matchAndRewrite(ReshapeOp reshapeOp, ArrayRef operands, + matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { MemRefType dstType = reshapeOp.getResultType(); MemRefType srcType = reshapeOp.getSrcType(); @@ -1168,7 +1152,6 @@ public: reshapeOp, "failed to get stride and offset exprs"); } - ReshapeOpAdaptor adaptor(operands); MemRefDescriptor srcDesc(adaptor.src()); Location loc = reshapeOp->getLoc(); auto dstDesc = MemRefDescriptor::undef( @@ -1217,7 +1200,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::SubViewOp subViewOp, ArrayRef operands, + matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = subViewOp.getLoc(); @@ -1249,9 +1232,9 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern { return failure(); // Create the descriptor. - if (!LLVM::isCompatibleType(operands.front().getType())) + if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType())) return failure(); - MemRefDescriptor sourceMemRef(operands.front()); + MemRefDescriptor sourceMemRef(adaptor.getOperands().front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. @@ -1296,7 +1279,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern { Value offset = // TODO: need OpFoldResult ODS adaptor to clean this up. subViewOp.isDynamicOffset(i) - ? operands[subViewOp.getIndexOfDynamicOffset(i)] + ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i))); @@ -1346,7 +1329,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern { // TODO: need OpFoldResult ODS adaptor to clean this up. size = subViewOp.isDynamicSize(i) - ? operands[subViewOp.getIndexOfDynamicSize(i)] + ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); @@ -1354,12 +1337,13 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern { stride = rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); } else { - stride = subViewOp.isDynamicStride(i) - ? operands[subViewOp.getIndexOfDynamicStride(i)] - : rewriter.create( - loc, llvmIndexType, - rewriter.getI64IntegerAttr( - subViewOp.getStaticStride(i))); + stride = + subViewOp.isDynamicStride(i) + ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)] + : rewriter.create( + loc, llvmIndexType, + rewriter.getI64IntegerAttr( + subViewOp.getStaticStride(i))); stride = rewriter.create(loc, stride, strideValues[i]); } } @@ -1385,10 +1369,9 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::TransposeOp transposeOp, ArrayRef operands, + matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = transposeOp.getLoc(); - memref::TransposeOpAdaptor adaptor(operands); MemRefDescriptor viewMemRef(adaptor.in()); // No permutation, early exit. @@ -1465,10 +1448,9 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { } LogicalResult - matchAndRewrite(memref::ViewOp viewOp, ArrayRef operands, + matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = viewOp.getLoc(); - memref::ViewOpAdaptor adaptor(operands); auto viewMemRefType = viewOp.getType(); auto targetElementTy = diff --git a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp index 511af47..fdff881 100644 --- a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp +++ b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp @@ -79,7 +79,7 @@ class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Op op, ArrayRef operands, + matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &builder) const override { Location loc = op.getLoc(); TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); @@ -87,8 +87,8 @@ class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern { unsigned numDataOperand = op.getNumDataOperands(); // Keep the non data operands without modification. - auto nonDataOperands = - operands.take_front(operands.size() - numDataOperand); + auto nonDataOperands = adaptor.getOperands().take_front( + adaptor.getOperands().size() - numDataOperand); SmallVector convertedOperands; convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end()); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index e0b3ed8..0e6010c 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -29,10 +29,10 @@ struct RegionOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(OpType curOp, ArrayRef operands, + matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto newOp = rewriter.create(curOp.getLoc(), TypeRange(), operands, - curOp->getAttrs()); + auto newOp = rewriter.create( + curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); rewriter.inlineRegionBefore(curOp.region(), newOp.region(), newOp.region().end()); if (failed(rewriter.convertRegionTypes(&newOp.region(), diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index eef6f2c..2f54a38 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -157,7 +157,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef operands, + matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *op = launchOp.getOperation(); MLIRContext *context = rewriter.getContext(); @@ -206,7 +206,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { Location loc = launchOp.getLoc(); SmallVector copyInfo; auto numKernelOperands = launchOp.getNumKernelOperands(); - auto kernelOperands = operands.take_back(numKernelOperands); + auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands); for (auto operand : llvm::enumerate(kernelOperands)) { // Check if the kernel's operand is a ranked memref. auto memRefType = launchOp.getKernelOperand(operand.index()) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 67583f9..ba942bb 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -178,7 +178,7 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef operands, + matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only 1-D vectors can be lowered to LLVM. VectorType resultTy = bitCastOp.getType(); @@ -186,7 +186,7 @@ public: return failure(); Type newResultTy = typeConverter->convertType(resultTy); rewriter.replaceOpWithNewOp(bitCastOp, newResultTy, - operands[0]); + adaptor.getOperands()[0]); return success(); } }; @@ -199,9 +199,8 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef operands, + matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto adaptor = vector::MatmulOpAdaptor(operands); rewriter.replaceOpWithNewOp( matmulOp, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), @@ -218,9 +217,8 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef operands, + matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto adaptor = vector::FlatTransposeOpAdaptor(operands); rewriter.replaceOpWithNewOp( transOp, typeConverter->convertType(transOp.res().getType()), adaptor.matrix(), transOp.rows(), transOp.columns()); @@ -270,7 +268,8 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef operands, + matchAndRewrite(LoadOrStoreOp loadOrStoreOp, + typename LoadOrStoreOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only 1-D vectors can be lowered to LLVM. VectorType vectorTy = loadOrStoreOp.getVectorType(); @@ -278,7 +277,6 @@ public: return failure(); auto loc = loadOrStoreOp->getLoc(); - auto adaptor = LoadOrStoreOpAdaptor(operands); MemRefType memRefTy = loadOrStoreOp.getMemRefType(); // Resolve alignment. @@ -306,10 +304,9 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::GatherOp gather, ArrayRef operands, + matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = gather->getLoc(); - auto adaptor = vector::GatherOpAdaptor(operands); MemRefType memRefType = gather.getMemRefType(); // Resolve alignment. @@ -341,10 +338,9 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::ScatterOp scatter, ArrayRef operands, + matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = scatter->getLoc(); - auto adaptor = vector::ScatterOpAdaptor(operands); MemRefType memRefType = scatter.getMemRefType(); // Resolve alignment. @@ -376,10 +372,9 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef operands, + matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = expand->getLoc(); - auto adaptor = vector::ExpandLoadOpAdaptor(operands); MemRefType memRefType = expand.getMemRefType(); // Resolve address. @@ -400,10 +395,9 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::CompressStoreOp compress, ArrayRef operands, + matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = compress->getLoc(); - auto adaptor = vector::CompressStoreOpAdaptor(operands); MemRefType memRefType = compress.getMemRefType(); // Resolve address. @@ -426,42 +420,43 @@ public: reassociateFPReductions(reassociateFPRed) {} LogicalResult - matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef operands, + matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto kind = reductionOp.kind(); Type eltType = reductionOp.dest().getType(); Type llvmType = typeConverter->convertType(eltType); + Value operand = adaptor.getOperands()[0]; if (eltType.isIntOrIndex()) { // Integer reductions: add/mul/min/max/and/or/xor. if (kind == "add") - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + rewriter.replaceOpWithNewOp(reductionOp, + llvmType, operand); else if (kind == "mul") - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + rewriter.replaceOpWithNewOp(reductionOp, + llvmType, operand); else if (kind == "min" && (eltType.isIndex() || eltType.isUnsignedInteger())) rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + reductionOp, llvmType, operand); else if (kind == "min") rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + reductionOp, llvmType, operand); else if (kind == "max" && (eltType.isIndex() || eltType.isUnsignedInteger())) rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + reductionOp, llvmType, operand); else if (kind == "max") rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + reductionOp, llvmType, operand); else if (kind == "and") - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + rewriter.replaceOpWithNewOp(reductionOp, + llvmType, operand); else if (kind == "or") - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + rewriter.replaceOpWithNewOp(reductionOp, + llvmType, operand); else if (kind == "xor") - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + rewriter.replaceOpWithNewOp(reductionOp, + llvmType, operand); else return failure(); return success(); @@ -473,29 +468,30 @@ public: // Floating-point reductions: add/mul/min/max if (kind == "add") { // Optional accumulator (or zero). - Value acc = operands.size() > 1 ? operands[1] - : rewriter.create( - reductionOp->getLoc(), llvmType, - rewriter.getZeroAttr(eltType)); + Value acc = adaptor.getOperands().size() > 1 + ? adaptor.getOperands()[1] + : rewriter.create( + reductionOp->getLoc(), llvmType, + rewriter.getZeroAttr(eltType)); rewriter.replaceOpWithNewOp( - reductionOp, llvmType, acc, operands[0], + reductionOp, llvmType, acc, operand, rewriter.getBoolAttr(reassociateFPReductions)); } else if (kind == "mul") { // Optional accumulator (or one). - Value acc = operands.size() > 1 - ? operands[1] + Value acc = adaptor.getOperands().size() > 1 + ? adaptor.getOperands()[1] : rewriter.create( reductionOp->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0)); rewriter.replaceOpWithNewOp( - reductionOp, llvmType, acc, operands[0], + reductionOp, llvmType, acc, operand, rewriter.getBoolAttr(reassociateFPReductions)); } else if (kind == "min") - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + rewriter.replaceOpWithNewOp(reductionOp, + llvmType, operand); else if (kind == "max") - rewriter.replaceOpWithNewOp( - reductionOp, llvmType, operands[0]); + rewriter.replaceOpWithNewOp(reductionOp, + llvmType, operand); else return failure(); return success(); @@ -511,10 +507,9 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef operands, + matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = shuffleOp->getLoc(); - auto adaptor = vector::ShuffleOpAdaptor(operands); auto v1Type = shuffleOp.getV1VectorType(); auto v2Type = shuffleOp.getV2VectorType(); auto vectorType = shuffleOp.getVectorType(); @@ -573,10 +568,8 @@ public: vector::ExtractElementOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::ExtractElementOp extractEltOp, - ArrayRef operands, + matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto adaptor = vector::ExtractElementOpAdaptor(operands); auto vectorType = extractEltOp.getVectorType(); auto llvmType = typeConverter->convertType(vectorType.getElementType()); @@ -596,10 +589,9 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::ExtractOp extractOp, ArrayRef operands, + matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = extractOp->getLoc(); - auto adaptor = vector::ExtractOpAdaptor(operands); auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult().getType(); auto llvmResultType = typeConverter->convertType(resultType); @@ -667,9 +659,8 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::FMAOp fmaOp, ArrayRef operands, + matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto adaptor = vector::FMAOpAdaptor(operands); VectorType vType = fmaOp.getVectorType(); if (vType.getRank() != 1) return failure(); @@ -685,9 +676,8 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef operands, + matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto adaptor = vector::InsertElementOpAdaptor(operands); auto vectorType = insertEltOp.getDestVectorType(); auto llvmType = typeConverter->convertType(vectorType); @@ -708,10 +698,9 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::InsertOp insertOp, ArrayRef operands, + matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = insertOp->getLoc(); - auto adaptor = vector::InsertOpAdaptor(operands); auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter->convertType(destVectorType); @@ -984,7 +973,7 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::TypeCastOp castOp, ArrayRef operands, + matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = castOp->getLoc(); MemRefType sourceMemRefType = @@ -997,10 +986,10 @@ public: return failure(); auto llvmSourceDescriptorTy = - operands[0].getType().dyn_cast(); + adaptor.getOperands()[0].getType().dyn_cast(); if (!llvmSourceDescriptorTy) return failure(); - MemRefDescriptor sourceMemRef(operands[0]); + MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) .dyn_cast_or_null(); @@ -1074,9 +1063,8 @@ public: // TODO: rely solely on libc in future? something else? // LogicalResult - matchAndRewrite(vector::PrintOp printOp, ArrayRef operands, + matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto adaptor = vector::PrintOpAdaptor(operands); Type printType = printOp.getPrintType(); if (typeConverter->convertType(printType) == nullptr) diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp index 9685faf..cc54b7f 100644 --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -30,7 +30,7 @@ using namespace mlir; using namespace mlir::vector; static LogicalResult replaceTransferOpWithMubuf( - ConversionPatternRewriter &rewriter, ArrayRef operands, + ConversionPatternRewriter &rewriter, ValueRange operands, LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp, Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes, Value &glc, Value &slc) { @@ -40,7 +40,7 @@ static LogicalResult replaceTransferOpWithMubuf( } static LogicalResult replaceTransferOpWithMubuf( - ConversionPatternRewriter &rewriter, ArrayRef operands, + ConversionPatternRewriter &rewriter, ValueRange operands, LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp, Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes, Value &glc, Value &slc) { @@ -62,10 +62,8 @@ public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ConcreteOp xferOp, ArrayRef operands, + matchAndRewrite(ConcreteOp xferOp, typename ConcreteOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - typename ConcreteOp::Adaptor adaptor(operands, xferOp->getAttrDictionary()); - if (xferOp.getVectorType().getRank() > 1 || llvm::size(xferOp.indices()) == 0) return failure(); @@ -139,8 +137,8 @@ public: loc, toLLVMTy(i32Ty), rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0)); return replaceTransferOpWithMubuf( - rewriter, operands, *this->getTypeConverter(), loc, xferOp, vecTy, - dwordConfig, int32Zero, int32Zero, int1False, int1False); + rewriter, adaptor.getOperands(), *this->getTypeConverter(), loc, xferOp, + vecTy, dwordConfig, int32Zero, int32Zero, int1False, int1False); } }; } // end anonymous namespace diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp index a8cf8c1..007a1c0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -244,9 +244,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::InsertSliceOp op, ArrayRef operands, + matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - tensor::InsertSliceOpAdaptor adaptor(operands, op->getAttrDictionary()); Value sourceMemRef = adaptor.source(); assert(sourceMemRef.getType().isa()); @@ -273,12 +272,10 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferReadOp readOp, ArrayRef operands, + matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { if (readOp.getShapedType().isa()) return failure(); - vector::TransferReadOp::Adaptor adaptor(operands, - readOp->getAttrDictionary()); rewriter.replaceOpWithNewOp( readOp, readOp.getType(), adaptor.source(), adaptor.indices(), adaptor.permutation_map(), adaptor.padding(), adaptor.mask(), @@ -293,12 +290,10 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef operands, + matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { if (writeOp.getShapedType().isa()) return failure(); - vector::TransferWriteOp::Adaptor adaptor(operands, - writeOp->getAttrDictionary()); rewriter.create( writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(), adaptor.permutation_map(), diff --git a/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp b/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp index 873a9a8..43f57fe 100644 --- a/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp +++ b/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp @@ -25,7 +25,7 @@ public: test::TestTypeProducerOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(test::TestTypeProducerOp op, ArrayRef operands, + matchAndRewrite(test::TestTypeProducerOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, getVoidPtrType()); return success(); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index d51cf5e..4be8cb1 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -783,7 +783,7 @@ struct OneVResOneVOperandOp1Converter using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef operands, + matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto origOps = op.getOperands(); assert(std::distance(origOps.begin(), origOps.end()) == 1 && @@ -878,7 +878,7 @@ struct TestTypeConversionProducer : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(TestTypeProducerOp op, ArrayRef operands, + matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Type resultType = op.getType(); if (resultType.isa()) @@ -900,7 +900,7 @@ struct TestSignatureConversionUndo using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef operands, + matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); return failure(); @@ -914,9 +914,10 @@ struct TestTypeConsumerForward using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(TestTypeConsumerOp op, ArrayRef operands, + matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); + rewriter.updateRootInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; @@ -1022,7 +1023,7 @@ struct TestMergeBlock : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(TestMergeBlocksOp op, ArrayRef operands, + matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Block &firstBlock = op.body().front(); Operation *branchOp = firstBlock.getTerminator(); @@ -1065,7 +1066,7 @@ struct TestMergeSingleBlockOps SingleBlockImplicitTerminatorOp>::OpConversionPattern; LogicalResult - matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef operands, + matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { SingleBlockImplicitTerminatorOp parentOp = op->getParentOfType(); -- 2.7.4