From: River Riddle Date: Fri, 24 Sep 2021 17:50:58 +0000 (+0000) Subject: [mlir:OpConversionPattern] Add overloads for taking an Adaptor instead of ArrayRef X-Git-Tag: upstream/15.0.7~30577 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b54c724be0b490f231af534696b3b7ef072a7ca1;p=platform%2Fupstream%2Fllvm.git [mlir:OpConversionPattern] Add overloads for taking an Adaptor instead of ArrayRef This has been a TODO for a long time, and it brings about many advantages (namely nice accessors, and less fragile code). The existing overloads that accept ArrayRef are now treated as deprecated and will be removed in a followup (after a small grace period). Most of the upstream MLIR usages have been fixed by this commit, the rest will be handled in a followup. Differential Revision: https://reviews.llvm.org/D110293 --- diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md index ac2e068..4a6db34 100644 --- a/mlir/docs/Bufferization.md +++ b/mlir/docs/Bufferization.md @@ -139,10 +139,10 @@ class BufferizeCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::CastOp op, ArrayRef operands, + matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto resultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resultType, operands[0]); + rewriter.replaceOpWithNewOp(op, resultType, adaptor.source()); return success(); } }; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 21e2591..81358dc 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -131,6 +131,8 @@ protected: template class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: + using OpAdaptor = typename SourceOp::Adaptor; + explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) : ConvertToLLVMPattern(SourceOp::getOperationName(), @@ -140,7 +142,8 @@ public: /// Wrappers around the RewritePattern methods that pass the derived op type. void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), operands, rewriter); + rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), + rewriter); } LogicalResult match(Operation *op) const final { return match(cast(op)); @@ -148,28 +151,53 @@ public: LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast(op), operands, rewriter); + return matchAndRewrite(cast(op), + OpAdaptor(operands, op->getAttrDictionary()), + rewriter); } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. + /// NOTICE: These methods are deprecated and will be removed. All new code + /// should use the adaptor methods below instead. virtual void rewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override rewrite or matchAndRewrite"); } - virtual LogicalResult match(SourceOp op) const { - llvm_unreachable("must override match or matchAndRewrite"); - } virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (succeeded(match(op))) { - rewrite(op, operands, rewriter); + rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter); return success(); } return failure(); } + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual LogicalResult match(SourceOp op) const { + llvm_unreachable("must override match or matchAndRewrite"); + } + virtual void rewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ValueRange operands = adaptor.getOperands(); + rewrite(op, + ArrayRef(operands.getBase().get(), + operands.size()), + rewriter); + } + virtual LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ValueRange operands = adaptor.getOperands(); + return matchAndRewrite( + op, + ArrayRef(operands.getBase().get(), + operands.size()), + rewriter); + } + private: using ConvertToLLVMPattern::match; using ConvertToLLVMPattern::matchAndRewrite; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 7354b55..9fe9690 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -366,79 +366,121 @@ private: using RewritePattern::rewrite; }; -namespace detail { -/// OpOrInterfaceConversionPatternBase is a wrapper around ConversionPattern -/// that allows for matching and rewriting against an instance of a derived -/// operation class or an Interface as opposed to a raw Operation. +/// OpConversionPattern is a wrapper around ConversionPattern that allows for +/// matching and rewriting against an instance of a derived operation class as +/// opposed to a raw Operation. template -struct OpOrInterfaceConversionPatternBase : public ConversionPattern { - using ConversionPattern::ConversionPattern; +class OpConversionPattern : public ConversionPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) + : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} + OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit, + context) {} /// Wrappers around the ConversionPattern methods that pass the derived op /// type. void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), operands, rewriter); + rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), + rewriter); } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast(op), operands, rewriter); + return matchAndRewrite(cast(op), + OpAdaptor(operands, op->getAttrDictionary()), + rewriter); } - // TODO: Use OperandAdaptor when it supports access to unnamed operands. - - /// Rewrite and Match methods that operate on the SourceOp type. These must be - /// overridden by the derived pattern class. + /// Rewrite and Match methods that operate on the SourceOp type and accept the + /// raw operand values. + /// NOTICE: These methods are deprecated and will be removed. All new code + /// should use the adaptor methods below instead. virtual void rewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } - virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (failed(match(op))) return failure(); - rewrite(op, operands, rewriter); + rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter); return success(); } + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual void rewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ValueRange operands = adaptor.getOperands(); + rewrite(op, + ArrayRef(operands.getBase().get(), + operands.size()), + rewriter); + } + virtual LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ValueRange operands = adaptor.getOperands(); + return matchAndRewrite( + op, + ArrayRef(operands.getBase().get(), + operands.size()), + rewriter); + } + private: using ConversionPattern::matchAndRewrite; }; -} // namespace detail - -/// OpConversionPattern is a wrapper around ConversionPattern that allows for -/// matching and rewriting against an instance of a derived operation class as -/// opposed to a raw Operation. -template -struct OpConversionPattern - : public detail::OpOrInterfaceConversionPatternBase { - OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) - : detail::OpOrInterfaceConversionPatternBase( - SourceOp::getOperationName(), benefit, context) {} - OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, - PatternBenefit benefit = 1) - : detail::OpOrInterfaceConversionPatternBase( - typeConverter, SourceOp::getOperationName(), benefit, context) {} -}; /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that /// allows for matching and rewriting against an instance of an OpInterface /// class as opposed to a raw Operation. template -struct OpInterfaceConversionPattern - : public detail::OpOrInterfaceConversionPatternBase { +class OpInterfaceConversionPattern : public ConversionPattern { +public: OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) - : detail::OpOrInterfaceConversionPatternBase( - Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(), - benefit, context) {} + : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(), + SourceOp::getInterfaceID(), benefit, context) {} OpInterfaceConversionPattern(TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) - : detail::OpOrInterfaceConversionPatternBase( - typeConverter, Pattern::MatchInterfaceOpTypeTag(), - SourceOp::getInterfaceID(), benefit, context) {} + : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(), + SourceOp::getInterfaceID(), benefit, context) {} + + /// Wrappers around the ConversionPattern methods that pass the derived op + /// type. + void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast(op), operands, rewriter); + } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(cast(op), operands, rewriter); + } + + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual void rewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + llvm_unreachable("must override matchAndRewrite or a rewrite method"); + } + virtual LogicalResult + matchAndRewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + if (failed(match(op))) + return failure(); + rewrite(op, operands, rewriter); + return success(); + } + +private: + using ConversionPattern::matchAndRewrite; }; /// Add a pattern to the given pattern list to convert the signature of a diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 4f4dd0e..8e0c1a2 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -326,7 +326,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroIdOp op, ArrayRef operands, + matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto token = AsyncAPI::tokenType(op->getContext()); auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); @@ -356,7 +356,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroBeginOp op, ArrayRef operands, + matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); auto loc = op->getLoc(); @@ -371,7 +371,7 @@ public: ValueRange(coroSize.getResult())); // Begin a coroutine: @llvm.coro.begin. - auto coroId = CoroBeginOpAdaptor(operands).id(); + auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id(); rewriter.replaceOpWithNewOp( op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); @@ -390,13 +390,14 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroFreeOp op, ArrayRef operands, + matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); auto loc = op->getLoc(); // Get a pointer to the coroutine frame memory: @llvm.coro.free. - auto coroMem = rewriter.create(loc, i8Ptr, operands); + auto coroMem = + rewriter.create(loc, i8Ptr, adaptor.getOperands()); // Free the memory. rewriter.replaceOpWithNewOp( @@ -418,14 +419,14 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroEndOp op, ArrayRef operands, + matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We are not in the block that is part of the unwind sequence. auto constFalse = rewriter.create( op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); // Mark the end of a coroutine: @llvm.coro.end. - auto coroHdl = CoroEndOpAdaptor(operands).handle(); + auto coroHdl = adaptor.handle(); rewriter.create(op->getLoc(), rewriter.getI1Type(), ValueRange({coroHdl, constFalse})); rewriter.eraseOp(op); @@ -445,11 +446,11 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroSaveOp op, ArrayRef operands, + matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Save the coroutine state: @llvm.coro.save rewriter.replaceOpWithNewOp( - op, AsyncAPI::tokenType(op->getContext()), operands); + op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); return success(); } @@ -491,7 +492,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroSuspendOp op, ArrayRef operands, + matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto i8 = rewriter.getIntegerType(8); auto i32 = rewriter.getI32Type(); @@ -502,7 +503,7 @@ public: loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); // Suspend a coroutine: @llvm.coro.suspend - auto coroState = CoroSuspendOpAdaptor(operands).state(); + auto coroState = adaptor.state(); auto coroSuspend = rewriter.create( loc, i8, ValueRange({coroState, constFalse})); @@ -541,7 +542,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeCreateOp op, ArrayRef operands, + matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { TypeConverter *converter = getTypeConverter(); Type resultType = op->getResultTypes()[0]; @@ -595,13 +596,14 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef operands, + matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { TypeConverter *converter = getTypeConverter(); Type resultType = op.getResult().getType(); - rewriter.replaceOpWithNewOp( - op, kCreateGroup, converter->convertType(resultType), operands); + rewriter.replaceOpWithNewOp(op, kCreateGroup, + converter->convertType(resultType), + adaptor.getOperands()); return success(); } }; @@ -618,14 +620,15 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef operands, + matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.operand().getType()) .Case([](Type) { return kEmplaceToken; }) .Case([](Type) { return kEmplaceValue; }); - rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), operands); + rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), + adaptor.getOperands()); return success(); } @@ -643,14 +646,15 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeSetErrorOp op, ArrayRef operands, + matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.operand().getType()) .Case([](Type) { return kSetTokenError; }) .Case([](Type) { return kSetValueError; }); - rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), operands); + rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), + adaptor.getOperands()); return success(); } @@ -667,7 +671,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeIsErrorOp op, ArrayRef operands, + matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.operand().getType()) @@ -676,7 +680,7 @@ public: .Case([](Type) { return kIsValueError; }); rewriter.replaceOpWithNewOp(op, apiFuncName, rewriter.getI1Type(), - operands); + adaptor.getOperands()); return success(); } }; @@ -692,7 +696,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeAwaitOp op, ArrayRef operands, + matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.operand().getType()) @@ -700,7 +704,8 @@ public: .Case([](Type) { return kAwaitValue; }) .Case([](Type) { return kAwaitGroup; }); - rewriter.create(op->getLoc(), apiFuncName, TypeRange(), operands); + rewriter.create(op->getLoc(), apiFuncName, TypeRange(), + adaptor.getOperands()); rewriter.eraseOp(op); return success(); @@ -719,7 +724,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef operands, + matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.operand().getType()) @@ -727,8 +732,8 @@ public: .Case([](Type) { return kAwaitValueAndExecute; }) .Case([](Type) { return kAwaitAllAndExecute; }); - Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand(); - Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle(); + Value operand = adaptor.operand(); + Value handle = adaptor.handle(); // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType()); @@ -755,7 +760,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeResumeOp op, ArrayRef operands, + matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType()); @@ -764,7 +769,7 @@ public: op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); // Call async runtime API to execute a coroutine in the managed thread. - auto coroHdl = RuntimeResumeOpAdaptor(operands).handle(); + auto coroHdl = adaptor.handle(); rewriter.replaceOpWithNewOp(op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.res()})); @@ -783,13 +788,13 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeStoreOp op, ArrayRef operands, + matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); // Get a pointer to the async value storage from the runtime. auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); - auto storage = RuntimeStoreOpAdaptor(operands).storage(); + auto storage = adaptor.storage(); auto storagePtr = rewriter.create(loc, kGetValueStorage, TypeRange(i8Ptr), storage); @@ -805,7 +810,7 @@ public: storagePtr.getResult(0)); // Store the yielded value into the async value storage. - auto value = RuntimeStoreOpAdaptor(operands).value(); + auto value = adaptor.value(); rewriter.create(loc, value, castedStoragePtr.getResult()); // Erase the original runtime store operation. @@ -826,13 +831,13 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeLoadOp op, ArrayRef operands, + matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); // Get a pointer to the async value storage from the runtime. auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); - auto storage = RuntimeLoadOpAdaptor(operands).storage(); + auto storage = adaptor.storage(); auto storagePtr = rewriter.create(loc, kGetValueStorage, TypeRange(i8Ptr), storage); @@ -866,15 +871,15 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef operands, + matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Currently we can only add tokens to the group. if (!op.operand().getType().isa()) return rewriter.notifyMatchFailure(op, "only token type is supported"); // Replace with a runtime API function call. - rewriter.replaceOpWithNewOp(op, kAddTokenToGroup, - rewriter.getI64Type(), operands); + rewriter.replaceOpWithNewOp( + op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); return success(); } @@ -896,13 +901,13 @@ public: apiFunctionName(apiFunctionName) {} LogicalResult - matchAndRewrite(RefCountingOp op, ArrayRef operands, + matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto count = rewriter.create(op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(op.count())); - auto operand = typename RefCountingOp::Adaptor(operands).operand(); + auto operand = adaptor.operand(); rewriter.replaceOpWithNewOp(op, TypeRange(), apiFunctionName, ValueRange({operand, count})); @@ -937,9 +942,9 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; @@ -1032,7 +1037,7 @@ class ConvertExecuteOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ExecuteOp op, ArrayRef operands, + matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ExecuteOp newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); @@ -1040,7 +1045,7 @@ public: newOp.getRegion().end()); // Set operands and update block argument and result types. - newOp->setOperands(operands); + newOp->setOperands(adaptor.getOperands()); if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) return failure(); for (auto result : newOp.getResults()) @@ -1056,9 +1061,9 @@ class ConvertAwaitOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AwaitOp op, ArrayRef operands, + matchAndRewrite(AwaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands.front()); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands().front()); return success(); } }; @@ -1068,9 +1073,9 @@ class ConvertYieldOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(async::YieldOp op, ArrayRef operands, + matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index f651eed..6ca60d0 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -26,16 +26,13 @@ struct AbsOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::AbsOp op, ArrayRef operands, + matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::AbsOp::Adaptor transformed(operands); auto loc = op.getLoc(); auto type = op.getType(); - Value real = - rewriter.create(loc, type, transformed.complex()); - Value imag = - rewriter.create(loc, type, transformed.complex()); + Value real = rewriter.create(loc, type, adaptor.complex()); + Value imag = rewriter.create(loc, type, adaptor.complex()); Value realSqr = rewriter.create(loc, real, real); Value imagSqr = rewriter.create(loc, imag, imag); Value sqNorm = rewriter.create(loc, realSqr, imagSqr); @@ -53,23 +50,16 @@ struct ComparisonOpConversion : public OpConversionPattern { AndOp, OrOp>; LogicalResult - matchAndRewrite(ComparisonOp op, ArrayRef operands, + matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - typename ComparisonOp::Adaptor transformed(operands); auto loc = op.getLoc(); - auto type = transformed.lhs() - .getType() - .template cast() - .getElementType(); - - Value realLhs = - rewriter.create(loc, type, transformed.lhs()); - Value imagLhs = - rewriter.create(loc, type, transformed.lhs()); - Value realRhs = - rewriter.create(loc, type, transformed.rhs()); - Value imagRhs = - rewriter.create(loc, type, transformed.rhs()); + auto type = + adaptor.lhs().getType().template cast().getElementType(); + + Value realLhs = rewriter.create(loc, type, adaptor.lhs()); + Value imagLhs = rewriter.create(loc, type, adaptor.lhs()); + Value realRhs = rewriter.create(loc, type, adaptor.rhs()); + Value imagRhs = rewriter.create(loc, type, adaptor.rhs()); Value realComparison = rewriter.create(loc, p, realLhs, realRhs); Value imagComparison = rewriter.create(loc, p, imagLhs, imagRhs); @@ -87,19 +77,18 @@ struct BinaryComplexOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(BinaryComplexOp op, ArrayRef operands, + matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - typename BinaryComplexOp::Adaptor transformed(operands); - auto type = transformed.lhs().getType().template cast(); + auto type = adaptor.lhs().getType().template cast(); auto elementType = type.getElementType().template cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value realLhs = b.create(elementType, transformed.lhs()); - Value realRhs = b.create(elementType, transformed.rhs()); + Value realLhs = b.create(elementType, adaptor.lhs()); + Value realRhs = b.create(elementType, adaptor.rhs()); Value resultReal = b.create(elementType, realLhs, realRhs); - Value imagLhs = b.create(elementType, transformed.lhs()); - Value imagRhs = b.create(elementType, transformed.rhs()); + Value imagLhs = b.create(elementType, adaptor.lhs()); + Value imagRhs = b.create(elementType, adaptor.rhs()); Value resultImag = b.create(elementType, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, type, resultReal, @@ -112,21 +101,20 @@ struct DivOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::DivOp op, ArrayRef operands, + matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::DivOp::Adaptor transformed(operands); auto loc = op.getLoc(); - auto type = transformed.lhs().getType().cast(); + auto type = adaptor.lhs().getType().cast(); auto elementType = type.getElementType().cast(); Value lhsReal = - rewriter.create(loc, elementType, transformed.lhs()); + rewriter.create(loc, elementType, adaptor.lhs()); Value lhsImag = - rewriter.create(loc, elementType, transformed.lhs()); + rewriter.create(loc, elementType, adaptor.lhs()); Value rhsReal = - rewriter.create(loc, elementType, transformed.rhs()); + rewriter.create(loc, elementType, adaptor.rhs()); Value rhsImag = - rewriter.create(loc, elementType, transformed.rhs()); + rewriter.create(loc, elementType, adaptor.rhs()); // Smith's algorithm to divide complex numbers. It is just a bit smarter // way to compute the following formula: @@ -321,17 +309,16 @@ struct ExpOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::ExpOp op, ArrayRef operands, + matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::ExpOp::Adaptor transformed(operands); auto loc = op.getLoc(); - auto type = transformed.complex().getType().cast(); + auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); Value real = - rewriter.create(loc, elementType, transformed.complex()); + rewriter.create(loc, elementType, adaptor.complex()); Value imag = - rewriter.create(loc, elementType, transformed.complex()); + rewriter.create(loc, elementType, adaptor.complex()); Value expReal = rewriter.create(loc, real); Value cosImag = rewriter.create(loc, imag); Value resultReal = rewriter.create(loc, expReal, cosImag); @@ -348,17 +335,16 @@ struct LogOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::LogOp op, ArrayRef operands, + matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::LogOp::Adaptor transformed(operands); - auto type = transformed.complex().getType().cast(); + auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value abs = b.create(elementType, transformed.complex()); + Value abs = b.create(elementType, adaptor.complex()); Value resultReal = b.create(elementType, abs); - Value real = b.create(elementType, transformed.complex()); - Value imag = b.create(elementType, transformed.complex()); + Value real = b.create(elementType, adaptor.complex()); + Value imag = b.create(elementType, adaptor.complex()); Value resultImag = b.create(elementType, imag, real); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -370,15 +356,14 @@ struct Log1pOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::Log1pOp op, ArrayRef operands, + matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::Log1pOp::Adaptor transformed(operands); - auto type = transformed.complex().getType().cast(); + auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create(elementType, transformed.complex()); - Value imag = b.create(elementType, transformed.complex()); + Value real = b.create(elementType, adaptor.complex()); + Value imag = b.create(elementType, adaptor.complex()); Value one = b.create(elementType, b.getFloatAttr(elementType, 1)); Value realPlusOne = b.create(real, one); @@ -392,20 +377,19 @@ struct MulOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::MulOp op, ArrayRef operands, + matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::MulOp::Adaptor transformed(operands); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto type = transformed.lhs().getType().cast(); + auto type = adaptor.lhs().getType().cast(); auto elementType = type.getElementType().cast(); - Value lhsReal = b.create(elementType, transformed.lhs()); + Value lhsReal = b.create(elementType, adaptor.lhs()); Value lhsRealAbs = b.create(lhsReal); - Value lhsImag = b.create(elementType, transformed.lhs()); + Value lhsImag = b.create(elementType, adaptor.lhs()); Value lhsImagAbs = b.create(lhsImag); - Value rhsReal = b.create(elementType, transformed.rhs()); + Value rhsReal = b.create(elementType, adaptor.rhs()); Value rhsRealAbs = b.create(rhsReal); - Value rhsImag = b.create(elementType, transformed.rhs()); + Value rhsImag = b.create(elementType, adaptor.rhs()); Value rhsImagAbs = b.create(rhsImag); Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); @@ -530,17 +514,16 @@ struct NegOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::NegOp op, ArrayRef operands, + matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::NegOp::Adaptor transformed(operands); auto loc = op.getLoc(); - auto type = transformed.complex().getType().cast(); + auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); Value real = - rewriter.create(loc, elementType, transformed.complex()); + rewriter.create(loc, elementType, adaptor.complex()); Value imag = - rewriter.create(loc, elementType, transformed.complex()); + rewriter.create(loc, elementType, adaptor.complex()); Value negReal = rewriter.create(loc, real); Value negImag = rewriter.create(loc, imag); rewriter.replaceOpWithNewOp(op, type, negReal, negImag); @@ -552,25 +535,23 @@ struct SignOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::SignOp op, ArrayRef operands, + matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::SignOp::Adaptor transformed(operands); - auto type = transformed.complex().getType().cast(); + auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create(elementType, transformed.complex()); - Value imag = b.create(elementType, transformed.complex()); + Value real = b.create(elementType, adaptor.complex()); + Value imag = b.create(elementType, adaptor.complex()); Value zero = b.create(elementType, b.getZeroAttr(elementType)); Value realIsZero = b.create(CmpFPredicate::OEQ, real, zero); Value imagIsZero = b.create(CmpFPredicate::OEQ, imag, zero); Value isZero = b.create(realIsZero, imagIsZero); - auto abs = b.create(elementType, transformed.complex()); + auto abs = b.create(elementType, adaptor.complex()); Value realSign = b.create(real, abs); Value imagSign = b.create(imag, abs); Value sign = b.create(type, realSign, imagSign); - rewriter.replaceOpWithNewOp(op, isZero, transformed.complex(), - sign); + rewriter.replaceOpWithNewOp(op, isZero, adaptor.complex(), sign); return success(); } }; diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index a303a87..88eca46 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -33,7 +33,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -45,7 +45,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -58,7 +58,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::BlockDimOp op, ArrayRef operands, + matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -68,7 +68,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef operands, + matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; private: @@ -81,7 +81,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef operands, + matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -91,7 +91,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef operands, + matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.eraseOp(endOp); return success(); @@ -105,7 +105,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef operands, + matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -129,7 +129,7 @@ static Optional getLaunchConfigIndex(Operation *op) { template LogicalResult LaunchConfigConversion::matchAndRewrite( - SourceOp op, ArrayRef operands, + SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { auto index = getLaunchConfigIndex(op); if (!index) @@ -150,7 +150,7 @@ LogicalResult LaunchConfigConversion::matchAndRewrite( template LogicalResult SingleDimLaunchConfigConversion::matchAndRewrite( - SourceOp op, ArrayRef operands, + SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { auto *typeConverter = this->template getTypeConverter(); auto indexType = typeConverter->getIndexType(); @@ -162,7 +162,7 @@ SingleDimLaunchConfigConversion::matchAndRewrite( } LogicalResult WorkGroupSizeConversion::matchAndRewrite( - gpu::BlockDimOp op, ArrayRef operands, + gpu::BlockDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto index = getLaunchConfigIndex(op); if (!index) @@ -264,7 +264,7 @@ getDefaultABIAttrs(MLIRContext *context, gpu::GPUFuncOp funcOp, } LogicalResult GPUFuncOpConversion::matchAndRewrite( - gpu::GPUFuncOp funcOp, ArrayRef operands, + gpu::GPUFuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (!gpu::GPUDialect::isKernel(funcOp)) return failure(); @@ -306,7 +306,7 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite( //===----------------------------------------------------------------------===// LogicalResult GPUModuleConversion::matchAndRewrite( - gpu::GPUModuleOp moduleOp, ArrayRef operands, + gpu::GPUModuleOp moduleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(moduleOp); spirv::AddressingModel addressingModel = spirv::getAddressingModel(targetEnv); @@ -336,9 +336,9 @@ LogicalResult GPUModuleConversion::matchAndRewrite( //===----------------------------------------------------------------------===// LogicalResult GPUReturnOpConversion::matchAndRewrite( - gpu::ReturnOp returnOp, ArrayRef operands, + gpu::ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (!operands.empty()) + if (!adaptor.getOperands().empty()) return failure(); rewriter.replaceOpWithNewOp(returnOp); diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp index bd1e4ad..a43ffbc 100644 --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -55,7 +55,7 @@ struct SingleWorkgroupReduction final matchAsPerformingReduction(linalg::GenericOp genericOp); LogicalResult - matchAndRewrite(linalg::GenericOp genericOp, ArrayRef operands, + matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -109,7 +109,7 @@ SingleWorkgroupReduction::matchAsPerformingReduction( } LogicalResult SingleWorkgroupReduction::matchAndRewrite( - linalg::GenericOp genericOp, ArrayRef operands, + linalg::GenericOp genericOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Operation *op = genericOp.getOperation(); auto originalInputType = op->getOperand(0).getType().cast(); @@ -134,7 +134,8 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite( // TODO: Query the target environment to make sure the current // workload fits in a local workgroup. - Value convertedInput = operands[0], convertedOutput = operands[1]; + Value convertedInput = adaptor.getOperands()[0]; + Value convertedOutput = adaptor.getOperands()[1]; Location loc = genericOp.getLoc(); auto *typeConverter = getTypeConverter(); diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index e30cbc0..04e8869 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -37,9 +37,9 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(StdOp operation, ArrayRef operands, + matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() <= 2); + assert(adaptor.getOperands().size() <= 2); auto dstType = this->getTypeConverter()->convertType(operation.getType()); if (!dstType) return failure(); @@ -48,7 +48,8 @@ public: return operation.emitError( "bitwidth emulation is not implemented yet on unsigned op"); } - rewriter.template replaceOpWithNewOp(operation, dstType, operands); + rewriter.template replaceOpWithNewOp(operation, dstType, + adaptor.getOperands()); return success(); } }; @@ -62,14 +63,15 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(math::Log1pOp operation, ArrayRef operands, + matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() == 1); + assert(adaptor.getOperands().size() == 1); Location loc = operation.getLoc(); auto type = this->getTypeConverter()->convertType(operation.operand().getType()); auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); - auto onePlus = rewriter.create(loc, one, operands[0]); + auto onePlus = + rewriter.create(loc, one, adaptor.getOperands()[0]); rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); } diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 9f9f115..7fb0f02 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -158,7 +158,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::AllocOp operation, ArrayRef operands, + matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -169,7 +169,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::DeallocOp operation, ArrayRef operands, + matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -179,7 +179,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, + matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -189,7 +189,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, + matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -199,7 +199,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::StoreOp storeOp, ArrayRef operands, + matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -209,7 +209,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::StoreOp storeOp, ArrayRef operands, + matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -220,8 +220,7 @@ public: //===----------------------------------------------------------------------===// LogicalResult -AllocOpPattern::matchAndRewrite(memref::AllocOp operation, - ArrayRef operands, +AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MemRefType allocType = operation.getType(); if (!isAllocationSupported(allocType)) @@ -260,7 +259,7 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, LogicalResult DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, - ArrayRef operands, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MemRefType deallocType = operation.memref().getType().cast(); if (!isAllocationSupported(deallocType)) @@ -274,10 +273,8 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, //===----------------------------------------------------------------------===// LogicalResult -IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, - ArrayRef operands, +IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - memref::LoadOpAdaptor loadOperands(operands); auto loc = loadOp.getLoc(); auto memrefType = loadOp.memref().getType().cast(); if (!memrefType.getElementType().isSignlessInteger()) @@ -285,8 +282,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, auto &typeConverter = *getTypeConverter(); spirv::AccessChainOp accessChainOp = - spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), - loadOperands.indices(), loc, rewriter); + spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(), + adaptor.indices(), loc, rewriter); if (!accessChainOp) return failure(); @@ -372,15 +369,14 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, } LogicalResult -LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, +LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - memref::LoadOpAdaptor loadOperands(operands); auto memrefType = loadOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); auto loadPtr = spirv::getElementPtr( - *getTypeConverter(), memrefType, - loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); + *getTypeConverter(), memrefType, adaptor.memref(), + adaptor.indices(), loadOp.getLoc(), rewriter); if (!loadPtr) return failure(); @@ -390,10 +386,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, } LogicalResult -IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, - ArrayRef operands, +IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - memref::StoreOpAdaptor storeOperands(operands); auto memrefType = storeOp.memref().getType().cast(); if (!memrefType.getElementType().isSignlessInteger()) return failure(); @@ -401,8 +395,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, auto loc = storeOp.getLoc(); auto &typeConverter = *getTypeConverter(); spirv::AccessChainOp accessChainOp = - spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), - storeOperands.indices(), loc, rewriter); + spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(), + adaptor.indices(), loc, rewriter); if (!accessChainOp) return failure(); @@ -427,7 +421,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, assert(dstBits % srcBits == 0); if (srcBits == dstBits) { - Value storeVal = storeOperands.value(); + Value storeVal = adaptor.value(); if (isBool) storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); rewriter.replaceOpWithNewOp( @@ -458,7 +452,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, rewriter.create(loc, dstType, mask, offset); clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); - Value storeVal = storeOperands.value(); + Value storeVal = adaptor.value(); if (isBool) storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); @@ -487,23 +481,20 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, } LogicalResult -StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, - ArrayRef operands, +StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - memref::StoreOpAdaptor storeOperands(operands); auto memrefType = storeOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); - auto storePtr = - spirv::getElementPtr(*getTypeConverter(), memrefType, - storeOperands.memref(), storeOperands.indices(), - storeOp.getLoc(), rewriter); + auto storePtr = spirv::getElementPtr( + *getTypeConverter(), memrefType, adaptor.memref(), + adaptor.indices(), storeOp.getLoc(), rewriter); if (!storePtr) return failure(); rewriter.replaceOpWithNewOp(storeOp, storePtr, - storeOperands.value()); + adaptor.value()); return success(); } diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index 08e3d3f..dc5bdc7 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -84,7 +84,7 @@ public: using SCFToSPIRVPattern::SCFToSPIRVPattern; LogicalResult - matchAndRewrite(scf::ForOp forOp, ArrayRef operands, + matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -95,7 +95,7 @@ public: using SCFToSPIRVPattern::SCFToSPIRVPattern; LogicalResult - matchAndRewrite(scf::IfOp ifOp, ArrayRef operands, + matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -104,7 +104,7 @@ public: using SCFToSPIRVPattern::SCFToSPIRVPattern; LogicalResult - matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef operands, + matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace @@ -146,14 +146,13 @@ static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, //===----------------------------------------------------------------------===// LogicalResult -ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, +ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // scf::ForOp can be lowered to the structured control flow represented by // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop // latch and the merge block the exit block. The resulting spirv::LoopOp has a // single back edge from the continue to header block, and a single exit from // header to merge. - scf::ForOpAdaptor forOperands(operands); auto loc = forOp.getLoc(); auto loopOp = rewriter.create(loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(); @@ -165,9 +164,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); // Create the new induction variable to use. - BlockArgument newIndVar = - header->addArgument(forOperands.lowerBound().getType()); - for (Value arg : forOperands.initArgs()) + BlockArgument newIndVar = header->addArgument(adaptor.lowerBound().getType()); + for (Value arg : adaptor.initArgs()) header->addArgument(arg.getType()); Block *body = forOp.getBody(); @@ -187,8 +185,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(), std::next(loopOp.body().begin(), 2)); - SmallVector args(1, forOperands.lowerBound()); - args.append(forOperands.initArgs().begin(), forOperands.initArgs().end()); + SmallVector args(1, adaptor.lowerBound()); + args.append(adaptor.initArgs().begin(), adaptor.initArgs().end()); // Branch into it from the entry. rewriter.setInsertionPointToEnd(&(loopOp.body().front())); rewriter.create(loc, header, args); @@ -197,7 +195,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, rewriter.setInsertionPointToEnd(header); auto *mergeBlock = loopOp.getMergeBlock(); auto cmpOp = rewriter.create( - loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); + loc, rewriter.getI1Type(), newIndVar, adaptor.upperBound()); rewriter.create( loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); @@ -209,7 +207,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, // Add the step to the induction variable and branch to the header. Value updatedIndVar = rewriter.create( - loc, newIndVar.getType(), newIndVar, forOperands.step()); + loc, newIndVar.getType(), newIndVar, adaptor.step()); rewriter.create(loc, header, updatedIndVar); // Infer the return types from the init operands. Vector type may get @@ -217,7 +215,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, // extra logic to figure out the right type we just infer it from the Init // operands. SmallVector initTypes; - for (auto arg : forOperands.initArgs()) + for (auto arg : adaptor.initArgs()) initTypes.push_back(arg.getType()); replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes); return success(); @@ -228,12 +226,11 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, //===----------------------------------------------------------------------===// LogicalResult -IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef operands, +IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // When lowering `scf::IfOp` we explicitly create a selection header block // before the control flow diverges and a merge block where control flow // subsequently converges. - scf::IfOpAdaptor ifOperands(operands); auto loc = ifOp.getLoc(); // Create `spv.selection` operation, selection header block and merge block. @@ -267,7 +264,7 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef operands, // Create a `spv.BranchConditional` operation for selection header block. rewriter.setInsertionPointToEnd(selectionHeaderBlock); - rewriter.create(loc, ifOperands.condition(), + rewriter.create(loc, adaptor.condition(), thenBlock, ArrayRef(), elseBlock, ArrayRef()); @@ -289,8 +286,10 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef operands, /// parent region. For loops we also need to update the branch looping back to /// the header with the loop carried values. LogicalResult TerminatorOpConversion::matchAndRewrite( - scf::YieldOp terminatorOp, ArrayRef operands, + scf::YieldOp terminatorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + ValueRange operands = adaptor.getOperands(); + // If the region is return values, store each value into the associated // VariableOp created during lowering of the parent region. if (!operands.empty()) { diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 8d957e0..348d8ad 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -302,7 +302,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::AccessChainOp op, ArrayRef operands, + matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(op.component_ptr().getType()); if (!dstType) @@ -327,7 +327,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::AddressOfOp op, ArrayRef operands, + matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(op.pointer().getType()); if (!dstType) @@ -343,7 +343,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef operands, + matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); @@ -387,7 +387,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::ConstantOp constOp, ArrayRef operands, + matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = constOp.getType(); if (!srcType.isa() && !srcType.isIntOrFloat()) @@ -419,8 +419,8 @@ public: rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } - rewriter.replaceOpWithNewOp(constOp, dstType, operands, - constOp->getAttrs()); + rewriter.replaceOpWithNewOp( + constOp, dstType, adaptor.getOperands(), constOp->getAttrs()); return success(); } }; @@ -431,7 +431,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef operands, + matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); @@ -484,7 +484,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef operands, + matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); @@ -518,9 +518,9 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::BranchOp branchOp, ArrayRef operands, + matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(branchOp, operands, + rewriter.replaceOpWithNewOp(branchOp, adaptor.getOperands(), branchOp.getTarget()); return success(); } @@ -533,7 +533,7 @@ public: spirv::BranchConditionalOp>::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef operands, + matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // If branch weights exist, map them to 32-bit integer vector. ElementsAttr branchWeights = nullptr; @@ -560,7 +560,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef operands, + matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) @@ -590,7 +590,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef operands, + matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) @@ -619,13 +619,13 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, ArrayRef operands, + matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); - rewriter.template replaceOpWithNewOp(operation, dstType, operands, - operation->getAttrs()); + rewriter.template replaceOpWithNewOp( + operation, dstType, adaptor.getOperands(), operation->getAttrs()); return success(); } }; @@ -638,7 +638,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef operands, + matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // First, create the global struct's name that would be associated with // this entry point's execution mode. We set it to be: @@ -717,7 +717,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef operands, + matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Currently, there is no support of initialization with a constant value in // SPIR-V dialect. Specialization constants are not considered as well. @@ -767,7 +767,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, ArrayRef operands, + matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type fromType = operation.operand().getType(); @@ -779,12 +779,12 @@ public: if (getBitWidth(fromType) < getBitWidth(toType)) { rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + adaptor.getOperands()); return success(); } if (getBitWidth(fromType) > getBitWidth(toType)) { rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + adaptor.getOperands()); return success(); } return failure(); @@ -797,18 +797,18 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef operands, + matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (callOp.getNumResults() == 0) { - rewriter.replaceOpWithNewOp(callOp, llvm::None, operands, - callOp->getAttrs()); + rewriter.replaceOpWithNewOp( + callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs()); return success(); } // Function returns a single result. auto dstType = typeConverter.convertType(callOp.getType(0)); - rewriter.replaceOpWithNewOp(callOp, dstType, operands, - callOp->getAttrs()); + rewriter.replaceOpWithNewOp( + callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); return success(); } }; @@ -820,7 +820,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, ArrayRef operands, + matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); @@ -841,7 +841,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, ArrayRef operands, + matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); @@ -861,7 +861,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef operands, + matchAndRewrite(spirv::GLSLInverseSqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); @@ -877,15 +877,14 @@ public: }; /// Converts `spv.Load` and `spv.Store` to LLVM dialect. -template -class LoadStorePattern : public SPIRVToLLVMConversion { +template +class LoadStorePattern : public SPIRVToLLVMConversion { public: - using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVop op, ArrayRef operands, + matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!op.memory_access().hasValue()) { return replaceWithLoadOrStore( op, rewriter, this->typeConverter, /*alignment=*/0, @@ -918,9 +917,8 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp notOp, ArrayRef operands, + matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcType = notOp.getType(); auto dstType = this->typeConverter.convertType(srcType); if (!dstType) @@ -947,7 +945,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp op, ArrayRef operands, + matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.eraseOp(op); return success(); @@ -959,7 +957,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef operands, + matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(returnOp, ArrayRef(), ArrayRef()); @@ -972,10 +970,10 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef operands, + matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(returnValueOp, ArrayRef(), - operands); + adaptor.getOperands()); return success(); } }; @@ -1033,7 +1031,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::LoopOp loopOp, ArrayRef operands, + matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // There is no support of loop control at the moment. if (loopOp.loop_control() != spirv::LoopControl::None) @@ -1080,7 +1078,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::SelectionOp op, ArrayRef operands, + matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // There is no support for `Flatten` or `DontFlatten` selection control at // the moment. This are just compiler hints and can be performed during the @@ -1149,7 +1147,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, ArrayRef operands, + matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); @@ -1161,7 +1159,7 @@ public: if (op1Type == op2Type) { rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + adaptor.getOperands()); return success(); } @@ -1186,7 +1184,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef operands, + matchAndRewrite(spirv::GLSLTanOp tanOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(tanOp.getType()); if (!dstType) @@ -1211,7 +1209,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef operands, + matchAndRewrite(spirv::GLSLTanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = tanhOp.getType(); auto dstType = typeConverter.convertType(srcType); @@ -1239,7 +1237,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::VariableOp varOp, ArrayRef operands, + matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = varOp.getType(); // Initialization is supported for scalars and vectors only. @@ -1274,7 +1272,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::FuncOp funcOp, ArrayRef operands, + matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Convert function signature. At the moment LLVMType converter is enough @@ -1337,7 +1335,7 @@ public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef operands, + matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newModuleOp = diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index b662b52..f622d5e 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -29,19 +29,17 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AnyOp op, ArrayRef operands, + matchAndRewrite(AnyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult -AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef operands, +AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - AnyOp::Adaptor transformed(operands); - // Replace `any` with its first operand. // Any operand would be a valid substitution. - rewriter.replaceOp(op, {transformed.inputs().front()}); + rewriter.replaceOp(op, {adaptor.inputs().front()}); return success(); } @@ -52,16 +50,13 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SrcOpTy op, ArrayRef operands, + matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - typename SrcOpTy::Adaptor transformed(operands); - // For now, only error-free types are supported by this lowering. if (op.getType().template isa()) return failure(); - rewriter.replaceOpWithNewOp(op, transformed.lhs(), - transformed.rhs()); + rewriter.replaceOpWithNewOp(op, adaptor.lhs(), adaptor.rhs()); return success(); } }; @@ -72,7 +67,7 @@ struct BroadcastOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(BroadcastOp op, ArrayRef operands, + matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -120,7 +115,7 @@ Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, } // namespace LogicalResult BroadcastOpConverter::matchAndRewrite( - BroadcastOp op, ArrayRef operands, + BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering is only defined on `tensor` operands, not // on shapes. @@ -129,7 +124,6 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - BroadcastOp::Adaptor transformed(operands); Value zero = lb.create(0); Type indexTy = lb.getIndexType(); @@ -138,7 +132,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( // representing the shape extents, the rank is the extent of the only // dimension in the tensor. SmallVector ranks, rankDiffs; - llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) { + llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) { return lb.create(v, zero); })); @@ -157,9 +151,8 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( Value replacement = lb.create( getExtentTensorType(lb.getContext()), ValueRange{maxRank}, [&](OpBuilder &b, Location loc, ValueRange args) { - Value broadcastedDim = - getBroadcastedDim(ImplicitLocOpBuilder(loc, b), - transformed.shapes(), rankDiffs, args[0]); + Value broadcastedDim = getBroadcastedDim( + ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, args[0]); b.create(loc, broadcastedDim); }); @@ -175,13 +168,13 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConstShapeOp op, ArrayRef operands, + matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult ConstShapeOpConverter::matchAndRewrite( - ConstShapeOp op, ArrayRef operands, + ConstShapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering supports only extent tensors, not `shape.shape` @@ -209,13 +202,13 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConstSizeOp op, ArrayRef operands, + matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult ConstSizeOpConversion::matchAndRewrite( - ConstSizeOp op, ArrayRef operands, + ConstSizeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp(op, op.value().getSExtValue()); return success(); @@ -227,17 +220,16 @@ struct IsBroadcastableOpConverter using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(IsBroadcastableOp op, ArrayRef operands, + matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult IsBroadcastableOpConverter::matchAndRewrite( - IsBroadcastableOp op, ArrayRef operands, + IsBroadcastableOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering is only defined on `tensor` operands, not // on shapes. - IsBroadcastableOp::Adaptor transformed(operands); if (!llvm::all_of(op.shapes(), [](Value v) { return !v.getType().isa(); })) return failure(); @@ -252,7 +244,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( // representing the shape extents, the rank is the extent of the only // dimension in the tensor. SmallVector ranks, rankDiffs; - llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) { + llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) { return lb.create(v, zero); })); @@ -279,10 +271,10 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( // could reuse the Broadcast lowering entirely, but we redo the work // here to make optimizations easier between the two loops. Value broadcastedDim = getBroadcastedDim( - ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv); + ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, iv); Value broadcastable = iterArgs[0]; - for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) { + for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) { Value shape, rankDiff; std::tie(shape, rankDiff) = tup; Value outOfBounds = @@ -327,16 +319,14 @@ class GetExtentOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(GetExtentOp op, ArrayRef operands, + matchAndRewrite(GetExtentOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult GetExtentOpConverter::matchAndRewrite( - GetExtentOp op, ArrayRef operands, + GetExtentOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - GetExtentOp::Adaptor transformed(operands); - // For now, only error-free types are supported by this lowering. if (op.getType().isa()) return failure(); @@ -346,14 +336,13 @@ LogicalResult GetExtentOpConverter::matchAndRewrite( if (auto shapeOfOp = op.shape().getDefiningOp()) { if (shapeOfOp.arg().getType().isa()) { rewriter.replaceOpWithNewOp(op, shapeOfOp.arg(), - transformed.dim()); + adaptor.dim()); return success(); } } - rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), - transformed.shape(), - ValueRange{transformed.dim()}); + rewriter.replaceOpWithNewOp( + op, rewriter.getIndexType(), adaptor.shape(), ValueRange{adaptor.dim()}); return success(); } @@ -363,20 +352,19 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(shape::RankOp op, ArrayRef operands, + matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult -RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef operands, +RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering supports only error-free types. if (op.getType().isa()) return failure(); - shape::RankOp::Adaptor transformed(operands); - rewriter.replaceOpWithNewOp(op, transformed.shape(), 0); + rewriter.replaceOpWithNewOp(op, adaptor.shape(), 0); return success(); } @@ -387,32 +375,30 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(shape::ReduceOp op, ArrayRef operands, + matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final; }; } // namespace LogicalResult -ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef operands, +ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering is only defined on `tensor` operands. if (op.shape().getType().isa()) return failure(); auto loc = op.getLoc(); - shape::ReduceOp::Adaptor transformed(operands); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); Type indexTy = rewriter.getIndexType(); Value rank = - rewriter.create(loc, indexTy, transformed.shape(), zero); + rewriter.create(loc, indexTy, adaptor.shape(), zero); auto loop = rewriter.create( loc, zero, rank, one, op.initVals(), [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { - Value extent = - b.create(loc, transformed.shape(), iv); + Value extent = b.create(loc, adaptor.shape(), iv); SmallVector mappedValues{iv, extent}; mappedValues.append(args.begin(), args.end()); @@ -468,13 +454,13 @@ struct ShapeEqOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ShapeEqOp op, ArrayRef operands, + matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult -ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef operands, +ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (!llvm::all_of(op.shapes(), [](Value v) { return !v.getType().isa(); })) @@ -487,16 +473,15 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef operands, return success(); } - ShapeEqOp::Adaptor transformed(operands); auto loc = op.getLoc(); Type indexTy = rewriter.getIndexType(); Value zero = rewriter.create(loc, 0); - Value firstShape = transformed.shapes().front(); + Value firstShape = adaptor.shapes().front(); Value firstRank = rewriter.create(loc, indexTy, firstShape, zero); Value result = nullptr; // Generate a linear sequence of compares, all with firstShape as lhs. - for (Value shape : transformed.shapes().drop_front(1)) { + for (Value shape : adaptor.shapes().drop_front(1)) { Value rank = rewriter.create(loc, indexTy, shape, zero); Value eqRank = rewriter.create(loc, CmpIPredicate::eq, firstRank, rank); @@ -536,13 +521,13 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ShapeOfOp op, ArrayRef operands, + matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult ShapeOfOpConversion::matchAndRewrite( - ShapeOfOp op, ArrayRef operands, + ShapeOfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, only error-free types are supported by this lowering. @@ -551,8 +536,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( // For ranked tensor arguments, lower to `tensor.from_elements`. auto loc = op.getLoc(); - ShapeOfOp::Adaptor transformed(operands); - Value tensor = transformed.arg(); + Value tensor = adaptor.arg(); Type tensorTy = tensor.getType(); if (tensorTy.isa()) { @@ -599,13 +583,13 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SplitAtOp op, ArrayRef operands, + matchAndRewrite(SplitAtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult SplitAtOpConversion::matchAndRewrite( - SplitAtOp op, ArrayRef operands, + SplitAtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // Error conditions are not implemented, only lower if all operands and // results are extent tensors. @@ -613,13 +597,12 @@ LogicalResult SplitAtOpConversion::matchAndRewrite( [](Value v) { return v.getType().isa(); })) return failure(); - SplitAtOp::Adaptor transformed(op); ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value zero = b.create(0); - Value rank = b.create(transformed.operand(), zero); + Value rank = b.create(adaptor.operand(), zero); // index < 0 ? index + rank : index - Value originalIndex = transformed.index(); + Value originalIndex = adaptor.index(); Value add = b.create(originalIndex, rank); Value indexIsNegative = b.create(CmpIPredicate::slt, originalIndex, zero); @@ -627,10 +610,10 @@ LogicalResult SplitAtOpConversion::matchAndRewrite( Value one = b.create(1); Value head = - b.create(transformed.operand(), zero, index, one); + b.create(adaptor.operand(), zero, index, one); Value tailSize = b.create(rank, index); - Value tail = b.create(transformed.operand(), index, - tailSize, one); + Value tail = + b.create(adaptor.operand(), index, tailSize, one); rewriter.replaceOp(op, {head, tail}); return success(); } @@ -642,10 +625,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToExtentTensorOp op, ArrayRef operands, + matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ToExtentTensorOpAdaptor adaptor(operands); - if (!adaptor.input().getType().isa()) return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 7a59330..8327492 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -292,7 +292,7 @@ struct FuncOpConversion : public FuncOpConversionBase { : FuncOpConversionBase(converter) {} LogicalResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) @@ -319,7 +319,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase { using FuncOpConversionBase::FuncOpConversionBase; LogicalResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO: bare ptr conversion could be handled by argument materialization @@ -442,10 +442,9 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(AssertOp op, ArrayRef operands, + matchAndRewrite(AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - AssertOp::Adaptor transformed(operands); // Insert the `abort` declaration if necessary. auto module = op->getParentOfType(); @@ -471,7 +470,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { // Generate assertion test. rewriter.setInsertionPointToEnd(opBlock); rewriter.replaceOpWithNewOp( - op, transformed.arg(), continuationBlock, failureBlock); + op, adaptor.arg(), continuationBlock, failureBlock); return success(); } @@ -481,7 +480,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ConstantOp op, ArrayRef operands, + matchAndRewrite(ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // If constant refers to a function, convert it to "addressof". if (auto symbolRef = op.getValue().dyn_cast()) { @@ -506,8 +505,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern { op, "referring to a symbol outside of the current module"); return LLVM::detail::oneToOneRewrite( - op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(), - rewriter); + op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), + *getTypeConverter(), rewriter); } }; @@ -520,10 +519,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { using Base = ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(CallOpType callOp, ArrayRef operands, + matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - typename CallOpType::Adaptor transformed(operands); - // Pack the result types into a struct. Type packedResult = nullptr; unsigned numResults = callOp.getNumResults(); @@ -536,8 +533,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { } auto promoted = this->getTypeConverter()->promoteOperands( - callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands, - rewriter); + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); auto newOp = rewriter.create( callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), promoted, callOp->getAttrs()); @@ -591,22 +588,21 @@ struct UnrealizedConversionCastOpLowering UnrealizedConversionCastOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(UnrealizedConversionCastOp op, ArrayRef operands, + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - UnrealizedConversionCastOp::Adaptor transformed(operands); SmallVector convertedTypes; if (succeeded(typeConverter->convertTypes(op.outputs().getTypes(), convertedTypes)) && - convertedTypes == transformed.inputs().getTypes()) { - rewriter.replaceOp(op, transformed.inputs()); + convertedTypes == adaptor.inputs().getTypes()) { + rewriter.replaceOp(op, adaptor.inputs()); return success(); } convertedTypes.clear(); - if (succeeded(typeConverter->convertTypes(transformed.inputs().getTypes(), + if (succeeded(typeConverter->convertTypes(adaptor.inputs().getTypes(), convertedTypes)) && convertedTypes == op.outputs().getType()) { - rewriter.replaceOp(op, transformed.inputs()); + rewriter.replaceOp(op, adaptor.inputs()); return success(); } return failure(); @@ -617,12 +613,12 @@ struct RankOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(RankOp op, ArrayRef operands, + matchAndRewrite(RankOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type operandType = op.memrefOrTensor().getType(); if (auto unrankedMemRefType = operandType.dyn_cast()) { - UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor()); + UnrankedMemRefDescriptor desc(adaptor.memrefOrTensor()); rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); return success(); } @@ -658,10 +654,8 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(IndexCastOp indexCastOp, ArrayRef operands, + matchAndRewrite(IndexCastOp indexCastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - IndexCastOpAdaptor transformed(operands); - auto targetType = typeConverter->convertType(indexCastOp.getResult().getType()); auto targetElementType = @@ -669,18 +663,18 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern { ->convertType(getElementTypeOrSelf(indexCastOp.getResult())) .cast(); auto sourceElementType = - getElementTypeOrSelf(transformed.in()).cast(); + getElementTypeOrSelf(adaptor.in()).cast(); unsigned targetBits = targetElementType.getWidth(); unsigned sourceBits = sourceElementType.getWidth(); if (targetBits == sourceBits) - rewriter.replaceOp(indexCastOp, transformed.in()); + rewriter.replaceOp(indexCastOp, adaptor.in()); else if (targetBits < sourceBits) rewriter.replaceOpWithNewOp(indexCastOp, targetType, - transformed.in()); + adaptor.in()); else rewriter.replaceOpWithNewOp(indexCastOp, targetType, - transformed.in()); + adaptor.in()); return success(); } }; @@ -696,10 +690,9 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(CmpIOp cmpiOp, ArrayRef operands, + matchAndRewrite(CmpIOp cmpiOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - CmpIOpAdaptor transformed(operands); - auto operandType = transformed.lhs().getType(); + auto operandType = adaptor.lhs().getType(); auto resultType = cmpiOp.getResult().getType(); // Handle the scalar and 1D vector cases. @@ -707,7 +700,7 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern { rewriter.replaceOpWithNewOp( cmpiOp, typeConverter->convertType(resultType), convertCmpPredicate(cmpiOp.getPredicate()), - transformed.lhs(), transformed.rhs()); + adaptor.lhs(), adaptor.rhs()); return success(); } @@ -716,13 +709,13 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( - cmpiOp.getOperation(), operands, *getTypeConverter(), + cmpiOp.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { - CmpIOpAdaptor transformed(operands); + CmpIOpAdaptor adaptor(operands); return rewriter.create( cmpiOp.getLoc(), llvm1DVectorTy, convertCmpPredicate(cmpiOp.getPredicate()), - transformed.lhs(), transformed.rhs()); + adaptor.lhs(), adaptor.rhs()); }, rewriter); @@ -734,10 +727,9 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(CmpFOp cmpfOp, ArrayRef operands, + matchAndRewrite(CmpFOp cmpfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - CmpFOpAdaptor transformed(operands); - auto operandType = transformed.lhs().getType(); + auto operandType = adaptor.lhs().getType(); auto resultType = cmpfOp.getResult().getType(); // Handle the scalar and 1D vector cases. @@ -745,7 +737,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern { rewriter.replaceOpWithNewOp( cmpfOp, typeConverter->convertType(resultType), convertCmpPredicate(cmpfOp.getPredicate()), - transformed.lhs(), transformed.rhs()); + adaptor.lhs(), adaptor.rhs()); return success(); } @@ -754,13 +746,13 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( - cmpfOp.getOperation(), operands, *getTypeConverter(), + cmpfOp.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { - CmpFOpAdaptor transformed(operands); + CmpFOpAdaptor adaptor(operands); return rewriter.create( cmpfOp.getLoc(), llvm1DVectorTy, convertCmpPredicate(cmpfOp.getPredicate()), - transformed.lhs(), transformed.rhs()); + adaptor.lhs(), adaptor.rhs()); }, rewriter); } @@ -774,10 +766,10 @@ struct OneToOneLLVMTerminatorLowering using Super = OneToOneLLVMTerminatorLowering; LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands, op->getSuccessors(), - op->getAttrs()); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands(), + op->getSuccessors(), op->getAttrs()); return success(); } }; @@ -792,7 +784,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); unsigned numArguments = op.getNumOperands(); @@ -801,7 +793,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { if (getTypeConverter()->getOptions().useBarePtrCallConv) { // For the bare-ptr calling convention, extract the aligned pointer to // be returned from the memref descriptor. - for (auto it : llvm::zip(op->getOperands(), operands)) { + for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { Type oldTy = std::get<0>(it).getType(); Value newOperand = std::get<1>(it); if (oldTy.isa()) { @@ -815,7 +807,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { updatedOperands.push_back(newOperand); } } else { - updatedOperands = llvm::to_vector<4>(operands); + updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), updatedOperands, /*toDynamic=*/true); @@ -870,14 +862,12 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(SplatOp splatOp, ArrayRef operands, + matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() != 1) return failure(); - SplatOp::Adaptor adaptor(operands); - // First insert it into an undef vector so we can shuffle it. auto vectorType = typeConverter->convertType(splatOp.getType()); Value undef = rewriter.create(splatOp.getLoc(), vectorType); @@ -907,9 +897,8 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(SplatOp splatOp, ArrayRef operands, + matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SplatOp::Adaptor adaptor(operands); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() == 1) return failure(); @@ -984,14 +973,13 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult - matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef operands, + matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(match(atomicOp))) return failure(); auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) return failure(); - AtomicRMWOp::Adaptor adaptor(operands); auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); auto dataPtr = @@ -1036,11 +1024,10 @@ struct GenericAtomicRMWOpLowering using Base::Base; LogicalResult - matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef operands, + matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = atomicOp.getLoc(); - GenericAtomicRMWOp::Adaptor adaptor(operands); Type valueType = typeConverter->convertType(atomicOp.getResult().getType()); // Split the block into initial, loop, and ending parts. diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index 0da2209..fe8b925 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -144,9 +144,9 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(StdOp operation, ArrayRef operands, + matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() <= 2); + assert(adaptor.getOperands().size() <= 2); auto dstType = this->getTypeConverter()->convertType(operation.getType()); if (!dstType) return failure(); @@ -155,7 +155,8 @@ public: return operation.emitError( "bitwidth emulation is not implemented yet on unsigned op"); } - rewriter.template replaceOpWithNewOp(operation, dstType, operands); + rewriter.template replaceOpWithNewOp(operation, dstType, + adaptor.getOperands()); return success(); } }; @@ -169,7 +170,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SignedRemIOp remOp, ArrayRef operands, + matchAndRewrite(SignedRemIOp remOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -183,19 +184,19 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(StdOp operation, ArrayRef operands, + matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() == 2); + assert(adaptor.getOperands().size() == 2); auto dstType = this->getTypeConverter()->convertType(operation.getResult().getType()); if (!dstType) return failure(); - if (isBoolScalarOrVector(operands.front().getType())) { - rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { + rewriter.template replaceOpWithNewOp( + operation, dstType, adaptor.getOperands()); } else { - rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + rewriter.template replaceOpWithNewOp( + operation, dstType, adaptor.getOperands()); } return success(); } @@ -208,7 +209,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConstantOp constOp, ArrayRef operands, + matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -218,7 +219,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConstantOp constOp, ArrayRef operands, + matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -228,7 +229,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, + matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -239,7 +240,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, + matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -250,7 +251,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, + matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -260,7 +261,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -270,7 +271,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -280,7 +281,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp returnOp, ArrayRef operands, + matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -289,7 +290,7 @@ class SelectOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SelectOp op, ArrayRef operands, + matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -299,7 +300,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SplatOp op, ArrayRef operands, + matchAndRewrite(SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -310,9 +311,9 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ZeroExtendIOp op, ArrayRef operands, + matchAndRewrite(ZeroExtendIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcType = operands.front().getType(); + auto srcType = adaptor.getOperands().front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); @@ -322,7 +323,7 @@ public: Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.template replaceOpWithNewOp( - op, dstType, operands.front(), one, zero); + op, dstType, adaptor.getOperands().front(), one, zero); return success(); } }; @@ -338,7 +339,7 @@ public: byteCountThreshold(threshold) {} LogicalResult - matchAndRewrite(tensor::ExtractOp extractOp, ArrayRef operands, + matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { TensorType tensorType = extractOp.tensor().getType().cast(); @@ -351,7 +352,6 @@ public: "exceeding byte count threshold"); Location loc = extractOp.getLoc(); - tensor::ExtractOp::Adaptor adaptor(operands); int64_t rank = tensorType.getRank(); SmallVector strides(rank, 1); @@ -396,7 +396,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(TruncateIOp op, ArrayRef operands, + matchAndRewrite(TruncateIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->getTypeConverter()->convertType(op.getResult().getType()); @@ -404,11 +404,11 @@ public: return failure(); Location loc = op.getLoc(); - auto srcType = operands.front().getType(); + auto srcType = adaptor.getOperands().front().getType(); // Check if (x & 1) == 1. Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = - rewriter.create(loc, srcType, operands[0], mask); + Value maskedSrc = rewriter.create( + loc, srcType, adaptor.getOperands()[0], mask); Value isOne = rewriter.create(loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); @@ -425,9 +425,9 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(UIToFPOp op, ArrayRef operands, + matchAndRewrite(UIToFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcType = operands.front().getType(); + auto srcType = adaptor.getOperands().front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); @@ -437,7 +437,7 @@ public: Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.template replaceOpWithNewOp( - op, dstType, operands.front(), one, zero); + op, dstType, adaptor.getOperands().front(), one, zero); return success(); } }; @@ -449,10 +449,10 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(StdOp operation, ArrayRef operands, + matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() == 1); - auto srcType = operands.front().getType(); + assert(adaptor.getOperands().size() == 1); + auto srcType = adaptor.getOperands().front().getType(); auto dstType = this->getTypeConverter()->convertType(operation.getResult().getType()); if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) @@ -460,10 +460,10 @@ public: if (dstType == srcType) { // Due to type conversion, we are seeing the same source and target type. // Then we can just erase this operation by forwarding its operand. - rewriter.replaceOp(operation, operands.front()); + rewriter.replaceOp(operation, adaptor.getOperands().front()); } else { rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + adaptor.getOperands()); } return success(); } @@ -475,7 +475,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(XOrOp xorOp, ArrayRef operands, + matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -486,7 +486,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(XOrOp xorOp, ArrayRef operands, + matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -497,10 +497,11 @@ public: //===----------------------------------------------------------------------===// LogicalResult SignedRemIOpPattern::matchAndRewrite( - SignedRemIOp remOp, ArrayRef operands, + SignedRemIOp remOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Value result = emulateSignedRemainder(remOp.getLoc(), operands[0], - operands[1], operands[0], rewriter); + Value result = emulateSignedRemainder( + remOp.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], + adaptor.getOperands()[0], rewriter); rewriter.replaceOp(remOp, result); return success(); @@ -514,7 +515,7 @@ LogicalResult SignedRemIOpPattern::matchAndRewrite( // so that the tensor case can be moved to TensorToSPIRV conversion. But, // std.constant is for the standard dialect though. LogicalResult ConstantCompositeOpPattern::matchAndRewrite( - ConstantOp constOp, ArrayRef operands, + ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto srcType = constOp.getType().dyn_cast(); if (!srcType) @@ -599,7 +600,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite( //===----------------------------------------------------------------------===// LogicalResult ConstantScalarOpPattern::matchAndRewrite( - ConstantOp constOp, ArrayRef operands, + ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type srcType = constOp.getType(); if (!srcType.isIntOrIndexOrFloat()) @@ -653,16 +654,13 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite( //===----------------------------------------------------------------------===// LogicalResult -CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, +CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - CmpFOpAdaptor cmpFOpOperands(operands); - switch (cmpFOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(cmpFOp, cmpFOp.getResult().getType(), \ - cmpFOpOperands.lhs(), \ - cmpFOpOperands.rhs()); \ + adaptor.lhs(), adaptor.rhs()); \ return success(); // Ordered. @@ -689,19 +687,17 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, } LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( - CmpFOp cmpFOp, ArrayRef operands, + CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - CmpFOpAdaptor cmpFOpOperands(operands); - if (cmpFOp.getPredicate() == CmpFPredicate::ORD) { - rewriter.replaceOpWithNewOp(cmpFOp, cmpFOpOperands.lhs(), - cmpFOpOperands.rhs()); + rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), + adaptor.rhs()); return success(); } if (cmpFOp.getPredicate() == CmpFPredicate::UNO) { - rewriter.replaceOpWithNewOp( - cmpFOp, cmpFOpOperands.lhs(), cmpFOpOperands.rhs()); + rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), + adaptor.rhs()); return success(); } @@ -709,17 +705,16 @@ LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( } LogicalResult CmpFOpNanNonePattern::matchAndRewrite( - CmpFOp cmpFOp, ArrayRef operands, + CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (cmpFOp.getPredicate() != CmpFPredicate::ORD && cmpFOp.getPredicate() != CmpFPredicate::UNO) return failure(); - CmpFOpAdaptor cmpFOpOperands(operands); Location loc = cmpFOp.getLoc(); - Value lhsIsNan = rewriter.create(loc, cmpFOpOperands.lhs()); - Value rhsIsNan = rewriter.create(loc, cmpFOpOperands.rhs()); + Value lhsIsNan = rewriter.create(loc, adaptor.lhs()); + Value rhsIsNan = rewriter.create(loc, adaptor.rhs()); Value replace = rewriter.create(loc, lhsIsNan, rhsIsNan); if (cmpFOp.getPredicate() == CmpFPredicate::ORD) @@ -734,10 +729,8 @@ LogicalResult CmpFOpNanNonePattern::matchAndRewrite( //===----------------------------------------------------------------------===// LogicalResult -BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, +BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - CmpIOpAdaptor cmpIOpOperands(operands); - Type operandType = cmpIOp.lhs().getType(); if (!isBoolScalarOrVector(operandType)) return failure(); @@ -746,8 +739,7 @@ BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ - cmpIOpOperands.lhs(), \ - cmpIOpOperands.rhs()); \ + adaptor.lhs(), adaptor.rhs()); \ return success(); DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp); @@ -760,10 +752,8 @@ BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, } LogicalResult -CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, +CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - CmpIOpAdaptor cmpIOpOperands(operands); - Type operandType = cmpIOp.lhs().getType(); if (isBoolScalarOrVector(operandType)) return failure(); @@ -777,8 +767,7 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, "bitwidth emulation is not implemented yet on unsigned op"); \ } \ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ - cmpIOpOperands.lhs(), \ - cmpIOpOperands.rhs()); \ + adaptor.lhs(), adaptor.rhs()); \ return success(); DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); @@ -802,13 +791,14 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, //===----------------------------------------------------------------------===// LogicalResult -ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef operands, +ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (returnOp.getNumOperands() > 1) return failure(); if (returnOp.getNumOperands() == 1) { - rewriter.replaceOpWithNewOp(returnOp, operands[0]); + rewriter.replaceOpWithNewOp(returnOp, + adaptor.getOperands()[0]); } else { rewriter.replaceOpWithNewOp(returnOp); } @@ -820,12 +810,10 @@ ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef operands, //===----------------------------------------------------------------------===// LogicalResult -SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef operands, +SelectOpPattern::matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - SelectOpAdaptor selectOperands(operands); - rewriter.replaceOpWithNewOp(op, selectOperands.condition(), - selectOperands.true_value(), - selectOperands.false_value()); + rewriter.replaceOpWithNewOp( + op, adaptor.condition(), adaptor.true_value(), adaptor.false_value()); return success(); } @@ -834,12 +822,11 @@ SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef operands, //===----------------------------------------------------------------------===// LogicalResult -SplatPattern::matchAndRewrite(SplatOp op, ArrayRef operands, +SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto dstVecType = op.getType().dyn_cast(); if (!dstVecType || !spirv::CompositeType::isValid(dstVecType)) return failure(); - SplatOp::Adaptor adaptor(operands); SmallVector source(dstVecType.getNumElements(), adaptor.input()); rewriter.replaceOpWithNewOp(op, dstVecType, source); @@ -851,34 +838,35 @@ SplatPattern::matchAndRewrite(SplatOp op, ArrayRef operands, //===----------------------------------------------------------------------===// LogicalResult -XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef operands, +XOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - assert(operands.size() == 2); + assert(adaptor.getOperands().size() == 2); - if (isBoolScalarOrVector(operands.front().getType())) + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) return failure(); auto dstType = getTypeConverter()->convertType(xorOp.getType()); if (!dstType) return failure(); - rewriter.replaceOpWithNewOp(xorOp, dstType, operands); + rewriter.replaceOpWithNewOp(xorOp, dstType, + adaptor.getOperands()); return success(); } LogicalResult -BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef operands, +BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - assert(operands.size() == 2); + assert(adaptor.getOperands().size() == 2); - if (!isBoolScalarOrVector(operands.front().getType())) + if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) return failure(); auto dstType = getTypeConverter()->convertType(xorOp.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp(xorOp, dstType, - operands); + adaptor.getOperands()); return success(); } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index ed1b03c..8ee0c43 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -947,7 +947,7 @@ class ConvConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::Conv2DOp op, ArrayRef args, + matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); Value input = op->getOperand(0); @@ -1111,7 +1111,7 @@ class DepthwiseConvConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::DepthwiseConv2DOp op, ArrayRef args, + matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); Value input = op->getOperand(0); @@ -1266,7 +1266,7 @@ class TransposeConvConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::TransposeConv2DOp op, ArrayRef args, + matchAndRewrite(tosa::TransposeConv2DOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); Value input = op->getOperand(0); @@ -1336,10 +1336,8 @@ class MatMulConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::MatMulOp op, ArrayRef args, + matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - tosa::MatMulOp::Adaptor adaptor(args); - Location loc = op.getLoc(); auto outputTy = op.getType().cast(); @@ -1377,7 +1375,7 @@ class FullyConnectedConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef args, + matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); auto outputTy = op.getType().cast(); @@ -1486,15 +1484,13 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef args, + matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - typename tosa::ReshapeOp::Adaptor operands(args); - - ShapedType operandTy = operands.input1().getType().cast(); + ShapedType operandTy = adaptor.input1().getType().cast(); ShapedType resultTy = reshape.getType().template cast(); if (operandTy == resultTy) { - rewriter.replaceOp(reshape, args[0]); + rewriter.replaceOp(reshape, adaptor.getOperands()[0]); return success(); } @@ -1575,19 +1571,20 @@ public: auto collapsedTy = RankedTensorType::get({totalElems}, elemTy); Value collapsedOp = rewriter.create( - loc, collapsedTy, args[0], collapsingMap); + loc, collapsedTy, adaptor.getOperands()[0], collapsingMap); rewriter.replaceOpWithNewOp( reshape, resultTy, collapsedOp, expandingMap); return success(); } - if (resultTy.getRank() < args[0].getType().cast().getRank()) + if (resultTy.getRank() < + adaptor.getOperands()[0].getType().cast().getRank()) rewriter.replaceOpWithNewOp( - reshape, resultTy, args[0], reassociationMap); + reshape, resultTy, adaptor.getOperands()[0], reassociationMap); else rewriter.replaceOpWithNewOp( - reshape, resultTy, args[0], reassociationMap); + reshape, resultTy, adaptor.getOperands()[0], reassociationMap); return success(); } @@ -2117,7 +2114,7 @@ struct ConcatConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::ConcatOp op, ArrayRef args, + matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto resultType = op.getType().dyn_cast(); if (!resultType || !resultType.hasStaticShape()) { @@ -2136,11 +2133,12 @@ struct ConcatConverter : public OpConversionPattern { offsets.resize(rank, rewriter.create(loc, 0)); for (int i = 0; i < rank; ++i) { - sizes.push_back(rewriter.create(loc, args[0], i)); + sizes.push_back( + rewriter.create(loc, adaptor.getOperands()[0], i)); } Value resultDimSize = sizes[axis]; - for (auto arg : args.drop_front()) { + for (auto arg : adaptor.getOperands().drop_front()) { auto size = rewriter.create(loc, arg, axisValue); resultDimSize = rewriter.create(loc, resultDimSize, size); } @@ -2154,7 +2152,7 @@ struct ConcatConverter : public OpConversionPattern { Value result = rewriter.create(loc, zeroVal, init).getResult(0); - for (auto arg : args) { + for (auto arg : adaptor.getOperands()) { sizes[axis] = rewriter.create(loc, arg, axisValue); result = rewriter.create(loc, arg, result, offsets, sizes, strides); @@ -2230,7 +2228,7 @@ struct TileConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::TileOp op, ArrayRef args, + matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto input = op.input1(); @@ -2488,10 +2486,10 @@ class GatherConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::GatherOp op, ArrayRef args, + matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - auto input = args[0]; - auto indices = args[1]; + auto input = adaptor.getOperands()[0]; + auto indices = adaptor.getOperands()[1]; auto inputTy = input.getType().cast(); auto indicesTy = indices.getType().cast(); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index de9dfa1..27037cb 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -36,13 +36,12 @@ struct VectorBitcastConvert final using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef operands, + matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); if (!dstType) return failure(); - vector::BitCastOp::Adaptor adaptor(operands); if (dstType == adaptor.source().getType()) rewriter.replaceOp(bitcastOp, adaptor.source()); else @@ -58,12 +57,11 @@ struct VectorBroadcastConvert final using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef operands, + matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (broadcastOp.source().getType().isa() || !spirv::CompositeType::isValid(broadcastOp.getVectorType())) return failure(); - vector::BroadcastOp::Adaptor adaptor(operands); SmallVector source(broadcastOp.getVectorType().getNumElements(), adaptor.source()); rewriter.replaceOpWithNewOp( @@ -77,7 +75,7 @@ struct VectorExtractOpConvert final using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::ExtractOp extractOp, ArrayRef operands, + matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only support extracting a scalar value now. VectorType resultVectorType = extractOp.getType().dyn_cast(); @@ -88,7 +86,6 @@ struct VectorExtractOpConvert final if (!dstType) return failure(); - vector::ExtractOp::Adaptor adaptor(operands); if (adaptor.vector().getType().isa()) { rewriter.replaceOp(extractOp, adaptor.vector()); return success(); @@ -106,8 +103,7 @@ struct VectorExtractStridedSliceOpConvert final using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::ExtractStridedSliceOp extractOp, - ArrayRef operands, + matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = getTypeConverter()->convertType(extractOp.getType()); if (!dstType) @@ -120,7 +116,7 @@ struct VectorExtractStridedSliceOpConvert final if (stride != 1) return failure(); - Value srcVector = operands.front(); + Value srcVector = adaptor.getOperands().front(); // Extract vector<1xT> case. if (dstType.isa()) { @@ -144,11 +140,10 @@ struct VectorFmaOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::FMAOp fmaOp, ArrayRef operands, + matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) return failure(); - vector::FMAOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp( fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); return success(); @@ -160,12 +155,11 @@ struct VectorInsertOpConvert final using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::InsertOp insertOp, ArrayRef operands, + matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (insertOp.getSourceType().isa() || !spirv::CompositeType::isValid(insertOp.getDestVectorType())) return failure(); - vector::InsertOp::Adaptor adaptor(operands); int32_t id = getFirstIntValue(insertOp.position()); rewriter.replaceOpWithNewOp( insertOp, adaptor.source(), adaptor.dest(), id); @@ -178,12 +172,10 @@ struct VectorExtractElementOpConvert final using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::ExtractElementOp extractElementOp, - ArrayRef operands, + matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) return failure(); - vector::ExtractElementOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp( extractElementOp, extractElementOp.getType(), adaptor.vector(), extractElementOp.position()); @@ -196,12 +188,10 @@ struct VectorInsertElementOpConvert final using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::InsertElementOp insertElementOp, - ArrayRef operands, + matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) return failure(); - vector::InsertElementOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp( insertElementOp, insertElementOp.getType(), insertElementOp.dest(), adaptor.source(), insertElementOp.position()); @@ -214,11 +204,10 @@ struct VectorInsertStridedSliceOpConvert final using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::InsertStridedSliceOp insertOp, - ArrayRef operands, + matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value srcVector = operands.front(); - Value dstVector = operands.back(); + Value srcVector = adaptor.getOperands().front(); + Value dstVector = adaptor.getOperands().back(); // Insert scalar values not supported yet. if (srcVector.getType().isa() || diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp index 6a94299..5e60bed 100644 --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -84,7 +84,7 @@ Value castPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr) { struct TileZeroConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileZeroOp op, ArrayRef operands, + matchAndRewrite(TileZeroOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType vType = op.getVectorType(); // Determine m x n tile sizes. @@ -102,9 +102,8 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileLoadOp op, ArrayRef operands, + matchAndRewrite(TileLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TileLoadOp::Adaptor adaptor(operands); MemRefType mType = op.getMemRefType(); VectorType vType = op.getVectorType(); // Determine m x n tile sizes. @@ -130,9 +129,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileStoreOp op, ArrayRef operands, + matchAndRewrite(TileStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TileStoreOp::Adaptor adaptor(operands); MemRefType mType = op.getMemRefType(); VectorType vType = op.getVectorType(); // Determine m x n tile sizes. @@ -156,9 +154,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern { struct TileMulFConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileMulFOp op, ArrayRef operands, + matchAndRewrite(TileMulFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TileMulFOp::Adaptor adaptor(operands); VectorType aType = op.getLhsVectorType(); VectorType bType = op.getRhsVectorType(); VectorType cType = op.getVectorType(); @@ -179,9 +176,8 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern { struct TileMulIConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileMulIOp op, ArrayRef operands, + matchAndRewrite(TileMulIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TileMulIOp::Adaptor adaptor(operands); VectorType aType = op.getLhsVectorType(); VectorType bType = op.getRhsVectorType(); VectorType cType = op.getVectorType(); diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index ed50f45..c4770195 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -46,12 +46,13 @@ class ForwardOperands : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(OpTy op, ArrayRef operands, + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const final { - if (ValueRange(operands).getTypes() == op->getOperands().getTypes()) + if (adaptor.getOperands().getTypes() == op->getOperands().getTypes()) return rewriter.notifyMatchFailure(op, "operand types already match"); - rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); + rewriter.updateRootInPlace( + op, [&]() { op->setOperands(adaptor.getOperands()); }); return success(); } }; @@ -61,9 +62,10 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); + rewriter.updateRootInPlace( + op, [&]() { op->setOperands(adaptor.getOperands()); }); return success(); } }; @@ -118,13 +120,12 @@ struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ScalableLoadOp loadOp, ArrayRef operands, + matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = loadOp.getMemRefType(); if (!isConvertibleAndHasIdentityMaps(type)) return failure(); - ScalableLoadOp::Adaptor transformed(operands); LLVMTypeConverter converter(loadOp.getContext()); auto resultType = loadOp.result().getType(); @@ -138,9 +139,8 @@ struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern { converter) .getValue()); } - Value dataPtr = - getStridedElementPtr(loadOp.getLoc(), type, transformed.base(), - transformed.index(), rewriter); + Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(), + adaptor.index(), rewriter); Value bitCastedPtr = rewriter.create( loadOp.getLoc(), llvmDataTypePtr, dataPtr); rewriter.replaceOpWithNewOp(loadOp, bitCastedPtr); @@ -155,13 +155,12 @@ struct ScalableStoreOpLowering using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ScalableStoreOp storeOp, ArrayRef operands, + matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = storeOp.getMemRefType(); if (!isConvertibleAndHasIdentityMaps(type)) return failure(); - ScalableStoreOp::Adaptor transformed(operands); LLVMTypeConverter converter(storeOp.getContext()); auto resultType = storeOp.value().getType(); @@ -175,12 +174,11 @@ struct ScalableStoreOpLowering converter) .getValue()); } - Value dataPtr = - getStridedElementPtr(storeOp.getLoc(), type, transformed.base(), - transformed.index(), rewriter); + Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(), + adaptor.index(), rewriter); Value bitCastedPtr = rewriter.create( storeOp.getLoc(), llvmDataTypePtr, dataPtr); - rewriter.replaceOpWithNewOp(storeOp, transformed.value(), + rewriter.replaceOpWithNewOp(storeOp, adaptor.value(), bitCastedPtr); return success(); } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp index 2127d7d..2d0886c 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -337,10 +337,10 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CreateGroupOp op, ArrayRef operands, + matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( - op, GroupType::get(op->getContext()), operands); + op, GroupType::get(op->getContext()), adaptor.getOperands()); return success(); } }; @@ -356,10 +356,10 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AddToGroupOp op, ArrayRef operands, + matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), operands); + op, rewriter.getIndexType(), adaptor.getOperands()); return success(); } }; @@ -382,7 +382,7 @@ public: outlinedFunctions(outlinedFunctions) {} LogicalResult - matchAndRewrite(AwaitType op, ArrayRef operands, + matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We can only await on one the `AwaitableType` (for `await` it can be // a `token` or a `value`, for `await_all` it must be a `group`). @@ -395,7 +395,7 @@ public: const bool isInCoroutine = outlined != outlinedFunctions.end(); Location loc = op->getLoc(); - Value operand = AwaitAdaptor(operands).operand(); + Value operand = adaptor.operand(); Type i1 = rewriter.getI1Type(); @@ -520,7 +520,7 @@ public: outlinedFunctions(outlinedFunctions) {} LogicalResult - matchAndRewrite(async::YieldOp op, ArrayRef operands, + matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if yield operation is inside the async coroutine function. auto func = op->template getParentOfType(); @@ -534,7 +534,7 @@ public: // Store yielded values into the async values storage and switch async // values state to available. - for (auto tuple : llvm::zip(operands, coro.returnValues)) { + for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { Value yieldValue = std::get<0>(tuple); Value asyncValue = std::get<1>(tuple); rewriter.create(loc, yieldValue, asyncValue); @@ -563,7 +563,7 @@ public: outlinedFunctions(outlinedFunctions) {} LogicalResult - matchAndRewrite(AssertOp op, ArrayRef operands, + matchAndRewrite(AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if assert operation is inside the async coroutine function. auto func = op->template getParentOfType(); @@ -577,7 +577,7 @@ public: Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); rewriter.setInsertionPointToEnd(cont->getPrevNode()); - rewriter.create(loc, AssertOpAdaptor(operands).arg(), + rewriter.create(loc, adaptor.arg(), /*trueDest=*/cont, /*trueArgs=*/ArrayRef(), /*falseDest=*/setupSetErrorBlock(coro), diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp index 79a111f..a8cf8c1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -104,9 +104,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(InitTensorOp op, ArrayRef operands, + matchAndRewrite(InitTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()).cast(), adaptor.sizes()); @@ -126,9 +125,8 @@ public: memref::ExpandShapeOp, memref::CollapseShapeOp>; LogicalResult - matchAndRewrite(TensorReshapeOp op, ArrayRef operands, + matchAndRewrite(TensorReshapeOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const final { - Adaptor adaptor(operands, op->getAttrDictionary()); rewriter.replaceOpWithNewOp(op, this->getTypeConverter() ->convertType(op.getType()) @@ -145,9 +143,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(FillOp op, ArrayRef operands, + matchAndRewrite(FillOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary()); if (!op.output().getType().isa()) return rewriter.notifyMatchFailure(op, "operand must be of a tensor type"); @@ -208,9 +205,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::ExtractSliceOp op, ArrayRef operands, + matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - tensor::ExtractSliceOpAdaptor adaptor(operands, op->getAttrDictionary()); Value sourceMemref = adaptor.source(); assert(sourceMemref.getType().isa()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 08278bc..15b7f9e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -60,7 +60,7 @@ class DetensorizeGenericOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(GenericOp op, ArrayRef operands, + matchAndRewrite(GenericOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Block *originalBlock = op->getBlock(); @@ -78,7 +78,7 @@ public: rewriter.replaceOp(op, yieldOp->getOperands()); // No need for these intermediate blocks, merge them into 1. - rewriter.mergeBlocks(opEntryBlock, originalBlock, operands); + rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands()); rewriter.mergeBlocks(newBlock, originalBlock, {}); rewriter.eraseOp(&*Block::iterator(yieldOp)); diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index c34660b..b84a6cb 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -21,7 +21,7 @@ class ConvertForOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ForOp op, ArrayRef operands, + matchAndRewrite(ForOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector newResultTypes; for (auto type : op.getResultTypes()) { @@ -63,7 +63,7 @@ public: } // Change the clone to use the updated operands. We could have cloned with // a BlockAndValueMapping, but this seems a bit more direct. - newOp->setOperands(operands); + newOp->setOperands(adaptor.getOperands()); // Update the result types to the new converted types. for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) std::get<0>(t).setType(std::get<1>(t)); @@ -79,7 +79,7 @@ class ConvertIfOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(IfOp op, ArrayRef operands, + matchAndRewrite(IfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO: Generalize this to any type conversion, not just 1:1. // @@ -108,7 +108,7 @@ public: newOp.elseRegion().end()); // Update the operands and types. - newOp->setOperands(operands); + newOp->setOperands(adaptor.getOperands()); for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) std::get<0>(t).setType(std::get<1>(t)); rewriter.replaceOp(op, newOp.getResults()); @@ -125,9 +125,9 @@ class ConvertYieldOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(scf::YieldOp op, ArrayRef operands, + matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; @@ -139,7 +139,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(WhileOp op, ArrayRef operands, + matchAndRewrite(WhileOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *converter = getTypeConverter(); assert(converter); @@ -147,7 +147,6 @@ public: if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) return failure(); - WhileOp::Adaptor adaptor(operands); auto newOp = rewriter.create(op.getLoc(), newResultTypes, adaptor.getOperands()); for (auto i : {0u, 1u}) { @@ -167,9 +166,10 @@ class ConvertConditionOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConditionOp op, ArrayRef operands, + matchAndRewrite(ConditionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); + rewriter.updateRootInPlace( + op, [&]() { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 20a793d..10a3ba6 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -156,7 +156,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(spirv::FuncOp funcOp, ArrayRef operands, + matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -168,7 +168,7 @@ class LowerABIAttributesPass final } // namespace LogicalResult ProcessInterfaceVarABI::matchAndRewrite( - spirv::FuncOp funcOp, ArrayRef operands, + spirv::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (!funcOp->getAttrOfType( spirv::getEntryPointABIAttrName())) { diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 76abf15..e2981bf 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -541,13 +541,13 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult -FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, +FuncOpConversion::matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto fnType = funcOp.getType(); if (fnType.getNumResults() > 1) diff --git a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp index b58fa4d..008164a 100644 --- a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp @@ -20,7 +20,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AssumingOp op, ArrayRef operands, + matchAndRewrite(AssumingOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { SmallVector newResultTypes; newResultTypes.reserve(op.getNumResults()); @@ -48,9 +48,9 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AssumingYieldOp op, ArrayRef operands, + matchAndRewrite(AssumingYieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 5df5477..328bf8e 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -249,9 +249,9 @@ class SparseReturnConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; @@ -262,7 +262,7 @@ class SparseTensorToDimSizeConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::DimOp op, ArrayRef operands, + matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); auto enc = getSparseTensorEncoding(op.source().getType()); @@ -278,7 +278,7 @@ public: // Generate the call. StringRef name = "sparseDimSize"; SmallVector params; - params.push_back(operands[0]); + params.push_back(adaptor.getOperands()[0]); params.push_back( rewriter.create(op.getLoc(), rewriter.getIndexAttr(idx))); rewriter.replaceOpWithNewOp( @@ -291,14 +291,15 @@ public: class SparseTensorNewConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(NewOp op, ArrayRef operands, + matchAndRewrite(NewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); auto enc = getSparseTensorEncoding(resType); if (!enc) return failure(); Value perm; - rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0])); + rewriter.replaceOp( + op, genNewCall(rewriter, op, enc, 0, perm, adaptor.getOperands()[0])); return success(); } }; @@ -307,7 +308,7 @@ class SparseTensorNewConverter : public OpConversionPattern { class SparseTensorConvertConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConvertOp op, ArrayRef operands, + matchAndRewrite(ConvertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); auto encDst = getSparseTensorEncoding(resType); @@ -320,7 +321,8 @@ class SparseTensorConvertConverter : public OpConversionPattern { // yield the fastest conversion but avoids the need for a full // O(N^2) conversion matrix. Value perm; - Value coo = genNewCall(rewriter, op, encDst, 3, perm, operands[0]); + Value coo = + genNewCall(rewriter, op, encDst, 3, perm, adaptor.getOperands()[0]); rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo)); return success(); } @@ -349,7 +351,7 @@ class SparseTensorConvertConverter : public OpConversionPattern { MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType()); Value perm; Value ptr = genNewCall(rewriter, op, encDst, 2, perm); - Value tensor = operands[0]; + Value tensor = adaptor.getOperands()[0]; Value arg = rewriter.create( loc, rewriter.getIndexAttr(shape.getRank())); Value ind = rewriter.create(loc, memTp, ValueRange{arg}); @@ -381,7 +383,7 @@ class SparseTensorToPointersConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToPointersOp op, ArrayRef operands, + matchAndRewrite(ToPointersOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); Type eltType = resType.cast().getElementType(); @@ -398,10 +400,11 @@ public: name = "sparsePointers8"; else return failure(); - rewriter.replaceOpWithNewOp( - op, resType, - getFunc(op, name, resType, operands, /*emitCInterface=*/true), - operands); + rewriter.replaceOpWithNewOp(op, resType, + getFunc(op, name, resType, + adaptor.getOperands(), + /*emitCInterface=*/true), + adaptor.getOperands()); return success(); } }; @@ -411,7 +414,7 @@ class SparseTensorToIndicesConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToIndicesOp op, ArrayRef operands, + matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); Type eltType = resType.cast().getElementType(); @@ -428,10 +431,11 @@ public: name = "sparseIndices8"; else return failure(); - rewriter.replaceOpWithNewOp( - op, resType, - getFunc(op, name, resType, operands, /*emitCInterface=*/true), - operands); + rewriter.replaceOpWithNewOp(op, resType, + getFunc(op, name, resType, + adaptor.getOperands(), + /*emitCInterface=*/true), + adaptor.getOperands()); return success(); } }; @@ -441,7 +445,7 @@ class SparseTensorToValuesConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToValuesOp op, ArrayRef operands, + matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); Type eltType = resType.cast().getElementType(); @@ -460,10 +464,11 @@ public: name = "sparseValuesI8"; else return failure(); - rewriter.replaceOpWithNewOp( - op, resType, - getFunc(op, name, resType, operands, /*emitCInterface=*/true), - operands); + rewriter.replaceOpWithNewOp(op, resType, + getFunc(op, name, resType, + adaptor.getOperands(), + /*emitCInterface=*/true), + adaptor.getOperands()); return success(); } }; @@ -474,12 +479,12 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult // Simply fold the operator into the pointer to the sparse storage scheme. - matchAndRewrite(ToTensorOp op, ArrayRef operands, + matchAndRewrite(ToTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check that all arguments of the tensor reconstruction operators are calls // into the support library that query exactly the same opaque pointer. Value ptr; - for (Value op : operands) { + for (Value op : adaptor.getOperands()) { if (auto call = op.getDefiningOp()) { Value arg = call.getOperand(0); if (!arg.getType().isa()) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp index 06f6c12..23b7019 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -27,9 +27,8 @@ class BufferizeIndexCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(IndexCastOp op, ArrayRef operands, + matchAndRewrite(IndexCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - IndexCastOp::Adaptor adaptor(operands); auto tensorType = op.getType().cast(); rewriter.replaceOpWithNewOp( op, adaptor.in(), @@ -42,12 +41,11 @@ class BufferizeSelectOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SelectOp op, ArrayRef operands, + matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!op.condition().getType().isa()) return rewriter.notifyMatchFailure(op, "requires scalar condition"); - SelectOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp( op, adaptor.condition(), adaptor.true_value(), adaptor.false_value()); return success(); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp index 7636bc7..3686568 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp @@ -61,7 +61,7 @@ struct DecomposeCallGraphTypesForFuncArgs DecomposeCallGraphTypesOpConversionPattern; LogicalResult - matchAndRewrite(FuncOp op, ArrayRef operands, + matchAndRewrite(FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto functionType = op.getType(); @@ -106,10 +106,10 @@ struct DecomposeCallGraphTypesForReturnOp using DecomposeCallGraphTypesOpConversionPattern:: DecomposeCallGraphTypesOpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { SmallVector newOperands; - for (Value operand : operands) + for (Value operand : adaptor.getOperands()) decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(), operand, newOperands); rewriter.replaceOpWithNewOp(op, newOperands); @@ -131,12 +131,12 @@ struct DecomposeCallGraphTypesForCallOp DecomposeCallGraphTypesOpConversionPattern; LogicalResult - matchAndRewrite(CallOp op, ArrayRef operands, + matchAndRewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { // Create the operands list of the new `CallOp`. SmallVector newOperands; - for (Value operand : operands) + for (Value operand : adaptor.getOperands()) decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(), operand, newOperands); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp index 49aaade..8756fcf 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp @@ -20,7 +20,7 @@ struct CallOpSignatureConversion : public OpConversionPattern { /// Hook for derived classes to implement combined matching and rewriting. LogicalResult - matchAndRewrite(CallOp callOp, ArrayRef operands, + matchAndRewrite(CallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Convert the original function results. SmallVector convertedResults; @@ -30,8 +30,8 @@ struct CallOpSignatureConversion : public OpConversionPattern { // Substitute with the new result types from the corresponding FuncType // conversion. - rewriter.replaceOpWithNewOp(callOp, callOp.callee(), - convertedResults, operands); + rewriter.replaceOpWithNewOp( + callOp, callOp.callee(), convertedResults, adaptor.getOperands()); return success(); } }; @@ -96,13 +96,12 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { // For a return, all operands go to the results of the parent, so // rewrite them all. - Operation *operation = op.getOperation(); - rewriter.updateRootInPlace( - op, [operands, operation]() { operation->setOperands(operands); }); + rewriter.updateRootInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp index c916a73..035251f 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -66,7 +66,7 @@ public: globals(globals) {} LogicalResult - matchAndRewrite(ConstantOp op, ArrayRef operands, + matchAndRewrite(ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = op.getType().dyn_cast(); if (!type) diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp index e8c3865..f35d701 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -26,10 +26,11 @@ class BufferizeCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::CastOp op, ArrayRef operands, + matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto resultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resultType, operands[0]); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getOperands()[0]); return success(); } }; @@ -40,9 +41,8 @@ class BufferizeDimOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::DimOp op, ArrayRef operands, + matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - tensor::DimOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp(op, adaptor.source(), adaptor.index()); return success(); @@ -55,9 +55,8 @@ class BufferizeExtractOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::ExtractOp op, ArrayRef operands, + matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - tensor::ExtractOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp(op, adaptor.tensor(), adaptor.indices()); return success(); @@ -71,7 +70,7 @@ class BufferizeFromElementsOp public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::FromElementsOp op, ArrayRef operands, + matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { int numberOfElements = op.elements().size(); auto resultType = MemRefType::get( @@ -95,16 +94,15 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::GenerateOp op, ArrayRef operands, + matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { // Allocate memory. Location loc = op.getLoc(); - tensor::GenerateOp::Adaptor transformed(operands); RankedTensorType tensorType = op.getType().cast(); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - Value result = rewriter.create( - loc, memrefType, transformed.dynamicExtents()); + Value result = rewriter.create(loc, memrefType, + adaptor.dynamicExtents()); // Collect loop bounds. int64_t rank = tensorType.getRank(); @@ -117,7 +115,7 @@ public: for (int i = 0; i < rank; i++) { Value upperBound = tensorType.isDynamicDim(i) - ? transformed.dynamicExtents()[nextDynamicIndex++] + ? adaptor.dynamicExtents()[nextDynamicIndex++] : rewriter.create(loc, memrefType.getDimSize(i)); upperBounds.push_back(upperBound); } diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp index c2a7a19..7b174e1 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -46,18 +46,18 @@ struct LowerToIntrinsic : public OpConversionPattern { } LogicalResult - matchAndRewrite(OpTy op, ArrayRef operands, + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type elementType = getSrcVectorElementType(op); unsigned bitwidth = elementType.getIntOrFloatBitWidth(); if (bitwidth == 32) return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), - operands, getTypeConverter(), - rewriter); + adaptor.getOperands(), + getTypeConverter(), rewriter); if (bitwidth == 64) return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), - operands, getTypeConverter(), - rewriter); + adaptor.getOperands(), + getTypeConverter(), rewriter); return rewriter.notifyMatchFailure( op, "expected 'src' to be either f32 or f64"); } @@ -68,9 +68,8 @@ struct MaskCompressOpConversion using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(MaskCompressOp op, ArrayRef operands, + matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MaskCompressOp::Adaptor adaptor(operands); auto opType = adaptor.a().getType(); Value src; @@ -95,10 +94,8 @@ struct RsqrtOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(RsqrtOp op, ArrayRef operands, + matchAndRewrite(RsqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RsqrtOp::Adaptor adaptor(operands); - auto opType = adaptor.a().getType(); rewriter.replaceOpWithNewOp(op, opType, adaptor.a()); return success(); @@ -109,9 +106,8 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(DotOp op, ArrayRef operands, + matchAndRewrite(DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - DotOp::Adaptor adaptor(operands); auto opType = adaptor.a().getType(); Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); // Dot product of all elements, broadcasted to all elements. diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp index 7ed7526..27d27c0 100644 --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -58,9 +58,8 @@ class BufferizeTensorLoadOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::TensorLoadOp op, ArrayRef operands, + matchAndRewrite(memref::TensorLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - memref::TensorLoadOp::Adaptor adaptor(operands); rewriter.replaceOp(op, adaptor.memref()); return success(); } @@ -74,9 +73,8 @@ class BufferizeCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::BufferCastOp op, ArrayRef operands, + matchAndRewrite(memref::BufferCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - memref::BufferCastOp::Adaptor adaptor(operands); rewriter.replaceOp(op, adaptor.tensor()); return success(); }