From 7df761217cd7d0026ffff23c4bdac846bb60f185 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Markus=20B=C3=B6ck?= Date: Tue, 10 Jan 2023 21:27:18 +0100 Subject: [PATCH] [mlir][NFC] Migrate rest of the dialects to the new fold API --- mlir/include/mlir/Dialect/Affine/IR/AffineOps.td | 1 + .../Dialect/Bufferization/IR/BufferizationBase.td | 1 + .../include/mlir/Dialect/Complex/IR/ComplexBase.td | 1 + mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td | 1 + mlir/include/mlir/Dialect/Func/IR/FuncOps.td | 1 + mlir/include/mlir/Dialect/GPU/IR/GPUBase.td | 1 + mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 1 + mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td | 1 + mlir/include/mlir/Dialect/Quant/QuantOpsBase.td | 1 + mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 1 + .../Dialect/SparseTensor/IR/SparseTensorBase.td | 1 + .../mlir/Dialect/Transform/IR/TransformDialect.td | 2 ++ mlir/include/mlir/IR/BuiltinDialect.td | 2 ++ mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 24 ++++++------- .../Dialect/Bufferization/IR/BufferizationOps.cpp | 6 ++-- mlir/lib/Dialect/Complex/IR/ComplexOps.cpp | 40 +++++++--------------- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 3 +- mlir/lib/Dialect/Func/IR/FuncOps.cpp | 3 +- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 4 +-- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 12 +++---- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 3 +- mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 2 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 2 +- .../SparseTensor/IR/SparseTensorDialect.cpp | 4 +-- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 2 +- mlir/lib/IR/BuiltinDialect.cpp | 2 +- mlir/test/lib/Dialect/Test/TestDialect.cpp | 15 ++++---- mlir/test/lib/Dialect/Test/TestDialect.td | 1 + mlir/test/lib/Dialect/Test/TestOps.td | 8 ++--- mlir/test/lib/Dialect/Test/TestTraits.cpp | 4 +-- .../mlir-linalg-ods-yaml-gen.cpp | 2 +- 31 files changed, 72 insertions(+), 80 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index d19e4d2..4a383ec 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -24,6 +24,7 @@ def Affine_Dialect : Dialect { let cppNamespace = "mlir"; let hasConstantMaterializer = 1; let dependentDialects = ["arith::ArithDialect"]; + let useFoldAPI = kEmitFoldAdaptorFolder; } // Base class for Affine dialect ops. diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td index 280bfdb..ecb8ec9 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td @@ -69,6 +69,7 @@ def Bufferization_Dialect : Dialect { kEscapeAttrName = "bufferization.escape"; }]; let hasOperationAttrVerify = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // BUFFERIZATION_BASE diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td index 31135fc..20bc712 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td @@ -22,6 +22,7 @@ def Complex_Dialect : Dialect { let dependentDialects = ["arith::ArithDialect"]; let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // COMPLEX_BASE diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td index 375dbcb..19b2a32 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td @@ -31,6 +31,7 @@ def EmitC_Dialect : Dialect { let hasConstantMaterializer = 1; let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // MLIR_DIALECT_EMITC_IR_EMITCBASE diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td index 4922689..58a33a1 100644 --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -23,6 +23,7 @@ def Func_Dialect : Dialect { let cppNamespace = "::mlir::func"; let dependentDialects = ["cf::ControlFlowDialect"]; let hasConstantMaterializer = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } // Base class for Func dialect ops. diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td index d5b36e4..7b02ed7 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td @@ -56,6 +56,7 @@ def GPU_Dialect : Dialect { let dependentDialects = ["arith::ArithDialect"]; let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } def GPU_AsyncToken : DialectType< diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index 80f8265..37365f6 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -31,6 +31,7 @@ def LLVM_Dialect : Dialect { let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; let hasOperationAttrVerify = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; let extraClassDeclaration = [{ /// Name of the data layout attributes. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td index 1ee9d2a..706b2ff 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -46,6 +46,7 @@ def Linalg_Dialect : Dialect { let hasCanonicalizer = 1; let hasOperationAttrVerify = 1; let hasConstantMaterializer = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; let extraClassDeclaration = [{ /// Attribute name used to to memoize indexing maps for named ops. constexpr const static ::llvm::StringLiteral diff --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td index 74d69c0..4a2e11d 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td +++ b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td @@ -20,6 +20,7 @@ def Quantization_Dialect : Dialect { let cppNamespace = "::mlir::quant"; let useDefaultTypePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index d5a1505..a610562 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -25,6 +25,7 @@ def SCF_Dialect : Dialect { let name = "scf"; let cppNamespace = "::mlir::scf"; let dependentDialects = ["arith::ArithDialect"]; + let useFoldAPI = kEmitFoldAdaptorFolder; } // Base class for SCF dialect ops. diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td index 2506da4..12066f3 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td @@ -83,6 +83,7 @@ def SparseTensor_Dialect : Dialect { let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // SPARSETENSOR_BASE diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index a7bb75e..05b094c 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -23,6 +23,8 @@ def Transform_Dialect : Dialect { "::mlir::pdl_interp::PDLInterpDialect", ]; + let useFoldAPI = kEmitFoldAdaptorFolder; + let extraClassDeclaration = [{ /// Returns the named PDL constraint functions available in the dialect /// as a map from their name to the function. diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td index a3d0f0c..e9d0096 100644 --- a/mlir/include/mlir/IR/BuiltinDialect.td +++ b/mlir/include/mlir/IR/BuiltinDialect.td @@ -34,6 +34,8 @@ def Builtin_Dialect : Dialect { public: }]; + + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // BUILTIN_BASE diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index be58118..f4c851c 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -562,7 +562,7 @@ bool AffineApplyOp::isValidSymbol(Region *region) { }); } -OpFoldResult AffineApplyOp::fold(ArrayRef operands) { +OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) { auto map = getAffineMap(); // Fold dims and symbols to existing values. @@ -574,7 +574,7 @@ OpFoldResult AffineApplyOp::fold(ArrayRef operands) { // Otherwise, default to folding the map. SmallVector result; - if (failed(map.constantFold(operands, result))) + if (failed(map.constantFold(adaptor.getMapOperands(), result))) return {}; return result[0]; } @@ -2135,7 +2135,7 @@ static bool hasTrivialZeroTripCount(AffineForOp op) { return tripCount && *tripCount == 0; } -LogicalResult AffineForOp::fold(ArrayRef operands, +LogicalResult AffineForOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { bool folded = succeeded(foldLoopBounds(*this)); folded |= succeeded(canonicalizeLoopBounds(*this)); @@ -2723,7 +2723,7 @@ static void composeSetAndOperands(IntegerSet &set, } /// Canonicalize an affine if op's conditional (integer set + operands). -LogicalResult AffineIfOp::fold(ArrayRef, +LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl &) { auto set = getIntegerSet(); SmallVector operands(getOperands()); @@ -2858,7 +2858,7 @@ void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add>(context); } -OpFoldResult AffineLoadOp::fold(ArrayRef cstOperands) { +OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) { /// load(memrefcast) -> load if (succeeded(memref::foldMemRefCast(*this))) return getResult(); @@ -2975,7 +2975,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add>(context); } -LogicalResult AffineStoreOp::fold(ArrayRef cstOperands, +LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { /// store(memrefcast) -> store return memref::foldMemRefCast(*this, getValueToStore()); @@ -3282,8 +3282,8 @@ struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern { // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) // -OpFoldResult AffineMinOp::fold(ArrayRef operands) { - return foldMinMaxOp(*this, operands); +OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) { + return foldMinMaxOp(*this, adaptor.getOperands()); } void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -3310,8 +3310,8 @@ void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); } // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0) // -OpFoldResult AffineMaxOp::fold(ArrayRef operands) { - return foldMinMaxOp(*this, operands); +OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) { + return foldMinMaxOp(*this, adaptor.getOperands()); } void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -3431,7 +3431,7 @@ void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add>(context); } -LogicalResult AffinePrefetchOp::fold(ArrayRef cstOperands, +LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { /// prefetch(memrefcast) -> prefetch return memref::foldMemRefCast(*this); @@ -3705,7 +3705,7 @@ static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) { return success(); } -LogicalResult AffineParallelOp::fold(ArrayRef operands, +LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { return canonicalizeLoopBounds(*this); } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index f021a59..59263f6 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -458,7 +458,7 @@ void CloneOp::getEffects( SideEffects::DefaultResource::get()); } -OpFoldResult CloneOp::fold(ArrayRef operands) { +OpFoldResult CloneOp::fold(FoldAdaptor adaptor) { return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value(); } @@ -560,7 +560,7 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, // ToTensorOp //===----------------------------------------------------------------------===// -OpFoldResult ToTensorOp::fold(ArrayRef) { +OpFoldResult ToTensorOp::fold(FoldAdaptor) { if (auto toMemref = getMemref().getDefiningOp()) // Approximate alias analysis by conservatively folding only when no there // is no interleaved operation. @@ -596,7 +596,7 @@ void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, // ToMemrefOp //===----------------------------------------------------------------------===// -OpFoldResult ToMemrefOp::fold(ArrayRef) { +OpFoldResult ToMemrefOp::fold(FoldAdaptor) { if (auto memrefToTensor = getTensor().getDefiningOp()) if (memrefToTensor.getMemref().getType() == getType()) return memrefToTensor.getMemref(); diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index 44ebf69..c71e2e0 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -17,8 +17,7 @@ using namespace mlir::complex; // ConstantOp //===----------------------------------------------------------------------===// -OpFoldResult ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } @@ -68,8 +67,7 @@ LogicalResult ConstantOp::verify() { // CreateOp //===----------------------------------------------------------------------===// -OpFoldResult CreateOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary op takes two operands"); +OpFoldResult CreateOp::fold(FoldAdaptor adaptor) { // Fold complex.create(complex.re(op), complex.im(op)). if (auto reOp = getOperand(0).getDefiningOp()) { if (auto imOp = getOperand(1).getDefiningOp()) { @@ -85,9 +83,8 @@ OpFoldResult CreateOp::fold(ArrayRef operands) { // ImOp //===----------------------------------------------------------------------===// -OpFoldResult ImOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - ArrayAttr arrayAttr = operands[0].dyn_cast_or_null(); +OpFoldResult ImOp::fold(FoldAdaptor adaptor) { + ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null(); if (arrayAttr && arrayAttr.size() == 2) return arrayAttr[1]; if (auto createOp = getOperand().getDefiningOp()) @@ -99,9 +96,8 @@ OpFoldResult ImOp::fold(ArrayRef operands) { // ReOp //===----------------------------------------------------------------------===// -OpFoldResult ReOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - ArrayAttr arrayAttr = operands[0].dyn_cast_or_null(); +OpFoldResult ReOp::fold(FoldAdaptor adaptor) { + ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null(); if (arrayAttr && arrayAttr.size() == 2) return arrayAttr[0]; if (auto createOp = getOperand().getDefiningOp()) @@ -113,9 +109,7 @@ OpFoldResult ReOp::fold(ArrayRef operands) { // AddOp //===----------------------------------------------------------------------===// -OpFoldResult AddOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary op takes 2 operands"); - +OpFoldResult AddOp::fold(FoldAdaptor adaptor) { // complex.add(complex.sub(a, b), b) -> a if (auto sub = getLhs().getDefiningOp()) if (getRhs() == sub.getRhs()) @@ -142,9 +136,7 @@ OpFoldResult AddOp::fold(ArrayRef operands) { // SubOp //===----------------------------------------------------------------------===// -OpFoldResult SubOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary op takes 2 operands"); - +OpFoldResult SubOp::fold(FoldAdaptor adaptor) { // complex.sub(complex.add(a, b), b) -> a if (auto add = getLhs().getDefiningOp()) if (getRhs() == add.getRhs()) @@ -166,9 +158,7 @@ OpFoldResult SubOp::fold(ArrayRef operands) { // NegOp //===----------------------------------------------------------------------===// -OpFoldResult NegOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - +OpFoldResult NegOp::fold(FoldAdaptor adaptor) { // complex.neg(complex.neg(a)) -> a if (auto negOp = getOperand().getDefiningOp()) return negOp.getOperand(); @@ -180,9 +170,7 @@ OpFoldResult NegOp::fold(ArrayRef operands) { // LogOp //===----------------------------------------------------------------------===// -OpFoldResult LogOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - +OpFoldResult LogOp::fold(FoldAdaptor adaptor) { // complex.log(complex.exp(a)) -> a if (auto expOp = getOperand().getDefiningOp()) return expOp.getOperand(); @@ -194,9 +182,7 @@ OpFoldResult LogOp::fold(ArrayRef operands) { // ExpOp //===----------------------------------------------------------------------===// -OpFoldResult ExpOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - +OpFoldResult ExpOp::fold(FoldAdaptor adaptor) { // complex.exp(complex.log(a)) -> a if (auto logOp = getOperand().getDefiningOp()) return logOp.getOperand(); @@ -208,9 +194,7 @@ OpFoldResult ExpOp::fold(ArrayRef operands) { // ConjOp //===----------------------------------------------------------------------===// -OpFoldResult ConjOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - +OpFoldResult ConjOp::fold(FoldAdaptor adaptor) { // complex.conj(complex.conj(a)) -> a if (auto conjOp = getOperand().getDefiningOp()) return conjOp.getOperand(); diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 91046c7..75d9222 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -129,8 +129,7 @@ LogicalResult emitc::ConstantOp::verify() { return success(); } -OpFoldResult emitc::ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); +OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp index 7c26e6a..de96e46 100644 --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -201,8 +201,7 @@ LogicalResult ConstantOp::verify() { return success(); } -OpFoldResult ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 8ef5483..e55e9a6 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1286,12 +1286,12 @@ LogicalResult SubgroupMmaComputeOp::verify() { return success(); } -LogicalResult MemcpyOp::fold(ArrayRef operands, +LogicalResult MemcpyOp::fold(FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) { return memref::foldMemRefCast(*this); } -LogicalResult MemsetOp::fold(ArrayRef operands, +LogicalResult MemsetOp::fold(FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) { return memref::foldMemRefCast(*this); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index e00c688..bb47aeb 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1441,7 +1441,7 @@ static Type getInsertExtractValueElementType(Type llvmType, return llvmType; } -OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef operands) { +OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) { auto insertValueOp = getContainer().getDefiningOp(); OpFoldResult result = {}; while (insertValueOp) { @@ -2275,7 +2275,7 @@ LogicalResult LLVM::ConstantOp::verify() { } // Constant op constant-folds to its value. -OpFoldResult LLVM::ConstantOp::fold(ArrayRef) { return getValue(); } +OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); } //===----------------------------------------------------------------------===// // Utility functions for parsing atomic ops @@ -2513,7 +2513,7 @@ LogicalResult FenceOp::verify() { // Folder for LLVM::BitcastOp //===----------------------------------------------------------------------===// -OpFoldResult LLVM::BitcastOp::fold(ArrayRef operands) { +OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) { // bitcast(x : T0, T0) -> x if (getArg().getType() == getType()) return getArg(); @@ -2528,7 +2528,7 @@ OpFoldResult LLVM::BitcastOp::fold(ArrayRef operands) { // Folder for LLVM::AddrSpaceCastOp //===----------------------------------------------------------------------===// -OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef operands) { +OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) { // addrcast(x : T0, T0) -> x if (getArg().getType() == getType()) return getArg(); @@ -2543,9 +2543,9 @@ OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef operands) { // Folder for LLVM::GEPOp //===----------------------------------------------------------------------===// -OpFoldResult LLVM::GEPOp::fold(ArrayRef operands) { +OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) { GEPIndicesAdaptor> indices(getRawConstantIndicesAttr(), - operands.drop_front()); + adaptor.getDynamicIndices()); // gep %x:T, 0 -> %x if (getBase().getType() == getType() && indices.size() == 1) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 308d003..7f5bcce 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -980,8 +980,7 @@ void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } -LogicalResult GenericOp::fold(ArrayRef, - SmallVectorImpl &) { +LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl &) { return memref::foldMemRefCast(*this); } diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index fcb97ae..c9a6bbc 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -36,7 +36,7 @@ void QuantizationDialect::initialize() { addBytecodeInterface(this); } -OpFoldResult StorageCastOp::fold(ArrayRef operands) { +OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) { // Matches x -> [scast -> scast] -> y, replacing the second scast with the // value of x if the casts invert each other. auto srcScastOp = getArg().getDefiningOp(); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 25f953c..b98f9df 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1598,7 +1598,7 @@ void IfOp::getSuccessorRegions(std::optional index, regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion)); } -LogicalResult IfOp::fold(ArrayRef operands, +LogicalResult IfOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { // if (!c) then A() else B() -> if c then B() else A() if (getElseRegion().empty()) diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 4d2c4fb..59a4b13 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -467,7 +467,7 @@ LogicalResult ConvertOp::verify() { return emitError("unexpected type in convert"); } -OpFoldResult ConvertOp::fold(ArrayRef operands) { +OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { Type dstType = getType(); // Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse // convert for codegen to remove. This is because we use trivial @@ -531,7 +531,7 @@ static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { return op.getSpecifier().template getDefiningOp(); } -OpFoldResult GetStorageSpecifierOp::fold(ArrayRef operands) { +OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { StorageSpecifierKind kind = getSpecifierKind(); std::optional dim = getDim(); for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 20052bd..0b3391e 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -463,7 +463,7 @@ void transform::MergeHandlesOp::getEffects( // manipulation. } -OpFoldResult transform::MergeHandlesOp::fold(ArrayRef operands) { +OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) { if (getDeduplicate() || getHandles().size() != 1) return {}; diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp index b66346f..0668459 100644 --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -190,7 +190,7 @@ LogicalResult ModuleOp::verify() { //===----------------------------------------------------------------------===// LogicalResult -UnrealizedConversionCastOp::fold(ArrayRef attrOperands, +UnrealizedConversionCastOp::fold(FoldAdaptor adaptor, SmallVectorImpl &foldResults) { OperandRange operands = getInputs(); ResultRange results = getOutputs(); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 0f48bfc..7f0ad3b 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1099,32 +1099,31 @@ void TestOpWithRegionPattern::getCanonicalizationPatterns( results.add(context); } -OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { +OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) { return getOperand(); } -OpFoldResult TestOpConstant::fold(ArrayRef operands) { +OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } LogicalResult TestOpWithVariadicResultsAndFolder::fold( - ArrayRef operands, SmallVectorImpl &results) { + FoldAdaptor adaptor, SmallVectorImpl &results) { for (Value input : this->getOperands()) { results.push_back(input); } return success(); } -OpFoldResult TestOpInPlaceFold::fold(ArrayRef operands) { - assert(operands.size() == 1); - if (operands.front()) { - (*this)->setAttr("attr", operands.front()); +OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { + if (adaptor.getOp()) { + (*this)->setAttr("attr", adaptor.getOp()); return getResult(); } return {}; } -OpFoldResult TestPassthroughFold::fold(ArrayRef operands) { +OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) { return getOperand(); } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td index 9ec1274..9dc4203 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.td +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -23,6 +23,7 @@ def Test_Dialect : Dialect { let hasNonDefaultDestructor = 1; let useDefaultTypePrinterParser = 0; let useDefaultAttributePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; let isExtensible = 1; let dependentDialects = ["::mlir::DLTIDialect"]; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 7d66136..e9816eb 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1290,7 +1290,7 @@ def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> { let results = (outs Variadic); let hasFolder = 1; let extraClassDefinition = [{ - ::mlir::LogicalResult $cppClass::fold(ArrayRef operands, + ::mlir::LogicalResult $cppClass::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { return success(); } @@ -1315,11 +1315,7 @@ def TestOpFoldWithFoldAdaptor $op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword }]; - let hasFolder = 0; - - let extraClassDeclaration = [{ - ::mlir::OpFoldResult fold(FoldAdaptor adaptor); - }]; + let hasFolder = 1; } // An op that always fold itself. diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp index 0ccffc7..d9b67ef 100644 --- a/mlir/test/lib/Dialect/Test/TestTraits.cpp +++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -18,13 +18,13 @@ using namespace test; //===----------------------------------------------------------------------===// OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold( - ArrayRef operands) { + FoldAdaptor adaptor) { // This failure should cause the trait fold to run instead. return {}; } OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold( - ArrayRef operands) { + FoldAdaptor adaptor) { auto argumentOp = getOperand(); // The success case should cause the trait fold to be supressed. return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{}; diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 9b7816c..d92c367 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -654,7 +654,7 @@ ArrayAttr {0}::getIndexingMaps() {{ // Parameters: // {0}: Class name const char structuredOpFoldersFormat[] = R"FMT( -LogicalResult {0}::fold(ArrayRef, +LogicalResult {0}::fold(FoldAdaptor, SmallVectorImpl &) {{ return memref::foldMemRefCast(*this); } -- 2.7.4