Change OwningRewritePatternList to carry an MLIRContext with it.
authorChris Lattner <clattner@nondot.org>
Sat, 20 Mar 2021 23:29:41 +0000 (16:29 -0700)
committerChris Lattner <clattner@nondot.org>
Sun, 21 Mar 2021 17:06:31 +0000 (10:06 -0700)
This updates the codebase to pass the context when creating an instance of
OwningRewritePatternList, and starts removing extraneous MLIRContext
parameters.  There are many many more to be removed.

Differential Revision: https://reviews.llvm.org/D99028

134 files changed:
mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
mlir/include/mlir/Dialect/GPU/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/include/mlir/Dialect/SCF/Transforms.h
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/Bufferize.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/Transforms/Bufferize.cpp
mlir/lib/Transforms/Canonicalizer.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp
mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Dialect/Test/TestTraits.cpp
mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
mlir/test/lib/Transforms/TestConvVectorization.cpp
mlir/test/lib/Transforms/TestConvertCallOp.cpp
mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
mlir/test/lib/Transforms/TestExpandTanh.cpp
mlir/test/lib/Transforms/TestGpuRewrite.cpp
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
mlir/test/lib/Transforms/TestSparsification.cpp
mlir/test/lib/Transforms/TestVectorTransforms.cpp
mlir/unittests/Rewrite/PatternBenefit.cpp

index 4647cac..8d3301c 100644 (file)
@@ -18,7 +18,6 @@ class AffineMap;
 class AffineParallelOp;
 class Location;
 struct LogicalResult;
-class MLIRContext;
 class OpBuilder;
 class Pass;
 class RewritePattern;
@@ -43,13 +42,12 @@ Optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
 /// Collect a set of patterns to convert from the Affine dialect to the Standard
 /// dialect, in particular convert structured affine control flow into CFG
 /// branch-based control flow.
-void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
-                                           MLIRContext *ctx);
+void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns);
 
 /// Collect a set of patterns to convert vector-related Affine ops to the Vector
 /// dialect.
 void populateAffineToVectorConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx);
+    OwningRewritePatternList &patterns);
 
 /// Emit code that computes the lower bound of the given affine loop using
 /// standard arithmetic operations.
index 938c5cb..670942a 100644 (file)
@@ -33,8 +33,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass();
 /// the TypeConverter, but otherwise don't care what type conversions are
 /// happening.
 void populateAsyncStructuralTypeConversionsAndLegality(
-    MLIRContext *context, TypeConverter &typeConverter,
-    OwningRewritePatternList &patterns, ConversionTarget &target);
+    TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+    ConversionTarget &target);
 
 } // namespace mlir
 
index ad5dac0..e679b86 100644 (file)
@@ -21,8 +21,7 @@ class SPIRVTypeConverter;
 /// Appends to a pattern list additional patterns for translating GPU Ops to
 /// SPIR-V ops. For a gpu.func to be converted, it should have a
 /// spv.entry_point_abi attribute.
-void populateGPUToSPIRVPatterns(MLIRContext *context,
-                                SPIRVTypeConverter &typeConverter,
+void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 OwningRewritePatternList &patterns);
 } // namespace mlir
 
index b2fc9e4..8f94597 100644 (file)
@@ -20,8 +20,7 @@ class SPIRVTypeConverter;
 
 /// Appends to a pattern list additional patterns for translating Linalg ops to
 /// SPIR-V ops.
-void populateLinalgToSPIRVPatterns(MLIRContext *context,
-                                   SPIRVTypeConverter &typeConverter,
+void populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    OwningRewritePatternList &patterns);
 
 } // namespace mlir
index 3a6c8bb..240bc1f 100644 (file)
@@ -70,7 +70,7 @@ public:
 
 /// Populate the given list with patterns that convert from Linalg to Standard.
 void populateLinalgToStandardConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx);
+    OwningRewritePatternList &patterns);
 
 } // namespace linalg
 
index d6316f6..14c1608 100644 (file)
@@ -42,8 +42,7 @@ LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp,
 
 /// Adds the conversion pattern from `scf.parallel` to `gpu.launch` to the
 /// provided pattern list.
-void populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
-                                       MLIRContext *ctx);
+void populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns);
 
 /// Configures the rewrite target such that only `scf.parallel` operations that
 /// are not rewritten by the provided patterns are legal.
index e0bab27..5a14c9b 100644 (file)
@@ -15,7 +15,6 @@
 #include <memory>
 
 namespace mlir {
-class MLIRContext;
 class Pass;
 
 // Owning list of rewriting patterns.
@@ -35,8 +34,7 @@ private:
 
 /// Collects a set of patterns to lower from scf.for, scf.if, and
 /// loop.terminator to CFG operations within the SPIR-V dialect.
-void populateSCFToSPIRVPatterns(MLIRContext *context,
-                                SPIRVTypeConverter &typeConverter,
+void populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 ScfToSPIRVContext &scfToSPIRVContext,
                                 OwningRewritePatternList &patterns);
 } // namespace mlir
index fd85a3d..95667d8 100644 (file)
@@ -14,7 +14,6 @@
 
 namespace mlir {
 struct LogicalResult;
-class MLIRContext;
 class Pass;
 class RewritePattern;
 
@@ -24,8 +23,7 @@ class OwningRewritePatternList;
 /// Collect a set of patterns to lower from scf.for, scf.if, and
 /// loop.terminator to CFG operations within the Standard dialect, in particular
 /// convert structured control flow into CFG branch-based control flow.
-void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns,
-                                         MLIRContext *ctx);
+void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns);
 
 /// Creates a pass to convert scf.for, scf.if and loop.terminator ops to CFG.
 std::unique_ptr<Pass> createLowerToCFGPass();
index 3ba24ea..2f6b6d7 100644 (file)
@@ -40,20 +40,17 @@ void encodeBindAttribute(ModuleOp module);
 void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
 
 /// Populates the given list with patterns that convert from SPIR-V to LLVM.
-void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
-                                           LLVMTypeConverter &typeConverter,
+void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter,
                                            OwningRewritePatternList &patterns);
 
 /// Populates the given list with patterns for function conversion from SPIR-V
 /// to LLVM.
 void populateSPIRVToLLVMFunctionConversionPatterns(
-    MLIRContext *context, LLVMTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns);
+    LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns);
 
 /// Populates the given patterns for module conversion from SPIR-V to LLVM.
 void populateSPIRVToLLVMModuleConversionPatterns(
-    MLIRContext *context, LLVMTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns);
+    LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns);
 
 } // namespace mlir
 
index 176f101..7c94470 100644 (file)
 namespace mlir {
 
 class FuncOp;
-class MLIRContext;
 class ModuleOp;
 template <typename T>
 class OperationPass;
 class OwningRewritePatternList;
 
 void populateShapeToStandardConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx);
+    OwningRewritePatternList &patterns);
 
 std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();
 
 void populateConvertShapeConstraintsConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx);
+    OwningRewritePatternList &patterns);
 
 std::unique_ptr<OperationPass<FuncOp>> createConvertShapeConstraintsPass();
 
index 87946d3..18cf4f3 100644 (file)
@@ -21,8 +21,7 @@ class SPIRVTypeConverter;
 /// Appends to a pattern list additional patterns for translating standard ops
 /// to SPIR-V ops. Also adds the patterns to legalize ops not directly
 /// translated to SPIR-V dialect.
-void populateStandardToSPIRVPatterns(MLIRContext *context,
-                                     SPIRVTypeConverter &typeConverter,
+void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns);
 
 /// Appends to a pattern list additional patterns for translating tensor ops
@@ -37,15 +36,14 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
 /// variables. SPIR-V consumers in GPU drivers may or may not optimize that
 /// away. So this has implications over register pressure. Therefore, a
 /// threshold is used to control when the patterns should kick in.
-void populateTensorToSPIRVPatterns(MLIRContext *context,
-                                   SPIRVTypeConverter &typeConverter,
+void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    int64_t byteCountThreshold,
                                    OwningRewritePatternList &patterns);
 
 /// Appends to a pattern list patterns to legalize ops that are not directly
 /// lowered to SPIR-V.
 void populateStdLegalizationPatternsForSPIRVLowering(
-    MLIRContext *context, OwningRewritePatternList &patterns);
+    OwningRewritePatternList &patterns);
 
 } // namespace mlir
 
index 42493a5..7553839 100644 (file)
@@ -28,7 +28,7 @@ void addTosaToLinalgOnTensorsPasses(OpPassManager &pm);
 
 /// Populates conversion passes from TOSA dialect to Linalg dialect.
 void populateTosaToLinalgOnTensorsConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList *patterns);
+    OwningRewritePatternList *patterns);
 
 } // namespace tosa
 } // namespace mlir
index 68ed0e0..08b2fe9 100644 (file)
@@ -20,8 +20,7 @@ namespace tosa {
 
 std::unique_ptr<Pass> createTosaToSCF();
 
-void populateTosaToSCFConversionPatterns(MLIRContext *context,
-                                         OwningRewritePatternList *patterns);
+void populateTosaToSCFConversionPatterns(OwningRewritePatternList *patterns);
 
 /// Populates passes to convert from TOSA to SCF.
 void addTosaToSCFPasses(OpPassManager &pm);
index 5a63d78..f130471 100644 (file)
@@ -21,10 +21,10 @@ namespace tosa {
 std::unique_ptr<Pass> createTosaToStandard();
 
 void populateTosaToStandardConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList *patterns);
+    OwningRewritePatternList *patterns);
 
 void populateTosaRescaleToStandardConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList *patterns);
+    OwningRewritePatternList *patterns);
 
 /// Populates passes to convert from TOSA to Standard.
 void addTosaToStandardPasses(OpPassManager &pm);
index f34a576..e7478cf 100644 (file)
@@ -162,7 +162,7 @@ struct VectorTransferRewriter : public RewritePattern {
 
 /// Collect a set of patterns to convert from the Vector dialect to SCF + std.
 void populateVectorToSCFConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context,
+    OwningRewritePatternList &patterns,
     const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
 
 /// Create a pass to convert a subset of vector ops to SCF.
index 7908f6e..8fc606f 100644 (file)
@@ -20,8 +20,7 @@ class SPIRVTypeConverter;
 
 /// Appends to a pattern list additional patterns for translating Vector Ops to
 /// SPIR-V ops.
-void populateVectorToSPIRVPatterns(MLIRContext *context,
-                                   SPIRVTypeConverter &typeConverter,
+void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    OwningRewritePatternList &patterns);
 
 } // namespace mlir
index bfb5626..327f9d6 100644 (file)
@@ -31,13 +31,11 @@ std::unique_ptr<OperationPass<ModuleOp>> createGpuKernelOutliningPass();
 std::unique_ptr<OperationPass<FuncOp>> createGpuAsyncRegionPass();
 
 /// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
-void populateGpuAllReducePatterns(MLIRContext *context,
-                                  OwningRewritePatternList &patterns);
+void populateGpuAllReducePatterns(OwningRewritePatternList &patterns);
 
 /// Collect all patterns to rewrite ops within the GPU dialect.
-inline void populateGpuRewritePatterns(MLIRContext *context,
-                                       OwningRewritePatternList &patterns) {
-  populateGpuAllReducePatterns(context, patterns);
+inline void populateGpuRewritePatterns(OwningRewritePatternList &patterns) {
+  populateGpuAllReducePatterns(patterns);
 }
 
 namespace gpu {
index 34e2568..24f49b5 100644 (file)
@@ -53,7 +53,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
 /// Populate patterns that convert `ElementwiseMappable` ops to linalg
 /// parallel loops.
 void populateElementwiseToLinalgConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx);
+    OwningRewritePatternList &patterns);
 
 /// Create a pass to conver named Linalg operations to Linalg generic
 /// operations.
@@ -67,14 +67,14 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
 void populateFoldReshapeOpsByExpansionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns);
+    OwningRewritePatternList &patterns);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
 /// producer (consumer) generic/indexed_generic operation by linearizing the
 /// indexing map used to access the source (target) of the reshape operation in
 /// the generic/indexed_generic operation.
 void populateFoldReshapeOpsByLinearizationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns);
+    OwningRewritePatternList &patterns);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
 /// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -83,16 +83,15 @@ void populateFoldReshapeOpsByLinearizationPatterns(
 /// the tensor reshape involved is collapsing (introducing) unit-extent
 /// dimensions.
 void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns);
+    OwningRewritePatternList &patterns);
 
 /// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
-                                           OwningRewritePatternList &patterns);
+void populateLinalgTensorOpsFusionPatterns(OwningRewritePatternList &patterns);
 
 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
 /// tensors.
 void populateLinalgFoldUnitExtentDimsPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns);
+    OwningRewritePatternList &patterns);
 
 //===----------------------------------------------------------------------===//
 // Registration
index 872e763..421a544 100644 (file)
@@ -36,11 +36,11 @@ template <template <typename> class PatternType, typename ConcreteOpType,
           typename = std::enable_if_t<std::is_member_function_pointer<
               decltype(&ConcreteOpType::getOperationName)>::value>>
 void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
-                    MLIRContext *context, StringRef opName,
-                    linalg::LinalgTransformationFilter m) {
+                    StringRef opName, linalg::LinalgTransformationFilter m) {
   assert(opName == ConcreteOpType::getOperationName() &&
          "explicit name must match ConcreteOpType::getOperationName");
-  patternList.insert<PatternType<ConcreteOpType>>(context, options, m);
+  patternList.insert<PatternType<ConcreteOpType>>(patternList.getContext(),
+                                                  options, m);
 }
 
 /// SFINAE: Enqueue helper for OpType that do not have a `getOperationName`
@@ -48,25 +48,26 @@ void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
 template <template <typename> class PatternType, typename OpType,
           typename OptionsType>
 void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
-                    MLIRContext *context, StringRef opName,
-                    linalg::LinalgTransformationFilter m) {
+                    StringRef opName, linalg::LinalgTransformationFilter m) {
   assert(!opName.empty() && "opName must not be empty");
-  patternList.insert<PatternType<OpType>>(opName, context, options, m);
+  patternList.insert<PatternType<OpType>>(opName, patternList.getContext(),
+                                          options, m);
 }
 
 template <typename PatternType, typename OpType, typename OptionsType>
 void enqueue(OwningRewritePatternList &patternList, OptionsType options,
-             MLIRContext *context, StringRef opName,
-             linalg::LinalgTransformationFilter m) {
+             StringRef opName, linalg::LinalgTransformationFilter m) {
   if (!opName.empty())
-    patternList.insert<PatternType>(opName, context, options, m);
+    patternList.insert<PatternType>(opName, patternList.getContext(), options,
+                                    m);
   else
     patternList.insert<PatternType>(m.addOpFilter<OpType>(), options);
 }
 
 /// Promotion transformation enqueues a particular stage-1 pattern for
 /// `Tile<LinalgOpType>`with the appropriate `options`.
-template <typename LinalgOpType> struct Tile : public Transformation {
+template <typename LinalgOpType>
+struct Tile : public Transformation {
   explicit Tile(linalg::LinalgTilingOptions options,
                 linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
       : Transformation(f), opName(LinalgOpType::getOperationName()),
@@ -79,9 +80,9 @@ template <typename LinalgOpType> struct Tile : public Transformation {
   OwningRewritePatternList
   buildRewritePatterns(MLIRContext *context,
                        linalg::LinalgTransformationFilter m) override {
-    OwningRewritePatternList tilingPatterns;
+    OwningRewritePatternList tilingPatterns(context);
     sfinae_enqueue<linalg::LinalgTilingPattern, LinalgOpType>(
-        tilingPatterns, options, context, opName, m);
+        tilingPatterns, options, opName, m);
     return tilingPatterns;
   }
 
@@ -92,7 +93,8 @@ private:
 
 /// Promotion transformation enqueues a particular stage-1 pattern for
 /// `Promote<LinalgOpType>`with the appropriate `options`.
-template <typename LinalgOpType> struct Promote : public Transformation {
+template <typename LinalgOpType>
+struct Promote : public Transformation {
   explicit Promote(
       linalg::LinalgPromotionOptions options,
       linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
@@ -106,9 +108,9 @@ template <typename LinalgOpType> struct Promote : public Transformation {
   OwningRewritePatternList
   buildRewritePatterns(MLIRContext *context,
                        linalg::LinalgTransformationFilter m) override {
-    OwningRewritePatternList promotionPatterns;
+    OwningRewritePatternList promotionPatterns(context);
     sfinae_enqueue<linalg::LinalgPromotionPattern, LinalgOpType>(
-        promotionPatterns, options, context, opName, m);
+        promotionPatterns, options, opName, m);
     return promotionPatterns;
   }
 
@@ -134,9 +136,9 @@ struct Vectorize : public Transformation {
   OwningRewritePatternList
   buildRewritePatterns(MLIRContext *context,
                        linalg::LinalgTransformationFilter m) override {
-    OwningRewritePatternList vectorizationPatterns;
+    OwningRewritePatternList vectorizationPatterns(context);
     enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
-        vectorizationPatterns, options, context, opName, m);
+        vectorizationPatterns, options, opName, m);
     vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
                                  linalg::LinalgCopyVTWForwardingPattern>(
         context, /*benefit=*/2);
index 6d42838..318db82 100644 (file)
@@ -37,8 +37,7 @@ void populateConvVectorizationPatterns(
     ArrayRef<int64_t> tileSizes);
 
 /// Populates the given list with patterns to bufferize linalg ops.
-void populateLinalgBufferizePatterns(MLIRContext *context,
-                                     BufferizeTypeConverter &converter,
+void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
                                      OwningRewritePatternList &patterns);
 
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
@@ -445,7 +444,7 @@ struct LinalgTilingOptions {
 OwningRewritePatternList
 getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
 void populateLinalgTilingCanonicalizationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx);
+    OwningRewritePatternList &patterns);
 
 /// Base pattern that applied the tiling transformation specified by `options`.
 /// Abort and return failure in 2 cases:
@@ -692,11 +691,10 @@ template <
     typename = std::enable_if_t<detect_has_get_operation_name<OpType>::value>,
     typename = void>
 void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
-                                    MLIRContext *context,
                                     linalg::LinalgVectorizationOptions options,
                                     linalg::LinalgTransformationFilter f) {
   patternList.insert<linalg::LinalgVectorizationPattern>(
-      OpType::getOperationName(), context, options, f);
+      OpType::getOperationName(), patternList.getContext(), options, f);
 }
 
 /// SFINAE helper for single C++ class without a `getOperationName` method (e.g.
@@ -704,7 +702,6 @@ void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
 template <typename OpType, typename = std::enable_if_t<
                                !detect_has_get_operation_name<OpType>::value>>
 void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
-                                    MLIRContext *context,
                                     linalg::LinalgVectorizationOptions options,
                                     linalg::LinalgTransformationFilter f) {
   patternList.insert<linalg::LinalgVectorizationPattern>(
@@ -714,14 +711,14 @@ void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
 /// Variadic helper function to insert vectorization patterns for C++ ops.
 template <typename... OpTypes>
 void insertVectorizationPatterns(OwningRewritePatternList &patternList,
-                                 MLIRContext *context,
                                  linalg::LinalgVectorizationOptions options,
                                  linalg::LinalgTransformationFilter f =
                                      linalg::LinalgTransformationFilter()) {
   // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-  (void)std::initializer_list<int>{0, (insertVectorizationPatternImpl<OpTypes>(
-                                           patternList, context, options, f),
-                                       0)...};
+  (void)std::initializer_list<int>{
+      0, (insertVectorizationPatternImpl<OpTypes>(
+              patternList, patternList.getContext(), options, f),
+          0)...};
 }
 
 ///
@@ -793,13 +790,13 @@ private:
 /// Populates `patterns` with patterns to convert spec-generated named ops to
 /// linalg.generic ops.
 void populateLinalgNamedOpsGeneralizationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns,
+    OwningRewritePatternList &patterns,
     LinalgTransformationFilter filter = LinalgTransformationFilter());
 
 /// Populates `patterns` with patterns to convert linalg.conv ops to
 /// linalg.generic ops.
 void populateLinalgConvGeneralizationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns,
+    OwningRewritePatternList &patterns,
     LinalgTransformationFilter filter = LinalgTransformationFilter());
 
 //===----------------------------------------------------------------------===//
@@ -893,7 +890,7 @@ struct AffineMinSCFCanonicalizationPattern
                                 PatternRewriter &rewriter) const override;
 };
 
-  /// Helper struct to return the results of `substituteMin`.
+/// Helper struct to return the results of `substituteMin`.
 struct AffineMapAndOperands {
   AffineMap map;
   SmallVector<Value> dims;
@@ -914,8 +911,8 @@ struct AffineMapAndOperands {
 /// Return a new AffineMap, dims and symbols that have been canonicalized and
 /// simplified.
 AffineMapAndOperands substituteMin(
-  AffineMinOp affineMinOp,
-  llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
+    AffineMinOp affineMinOp,
+    llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
 
 /// Converts Convolution op into vector contraction.
 ///
@@ -1060,12 +1057,12 @@ struct SparsificationOptions {
 
 /// Sets up sparsification rewriting rules with the given options.
 void populateSparsificationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns,
+    OwningRewritePatternList &patterns,
     const SparsificationOptions &options = SparsificationOptions());
 
 /// Sets up sparsification conversion rules with the given options.
 void populateSparsificationConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns);
+    OwningRewritePatternList &patterns);
 
 } // namespace linalg
 } // namespace mlir
index c965bab..3ce88a1 100644 (file)
@@ -9,18 +9,14 @@
 #ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
 
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/Bufferize.h"
-
 namespace mlir {
 
 class OwningRewritePatternList;
 
-void populateExpandTanhPattern(OwningRewritePatternList &patterns,
-                               MLIRContext *ctx);
+void populateExpandTanhPattern(OwningRewritePatternList &patterns);
 
 void populateMathPolynomialApproximationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx);
+    OwningRewritePatternList &patterns);
 
 } // namespace mlir
 
index 456eb4e..914a1a0 100644 (file)
@@ -60,8 +60,8 @@ tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes);
 /// corresponding scf.yield ops need to update their types accordingly to the
 /// TypeConverter, but otherwise don't care what type conversions are happening.
 void populateSCFStructuralTypeConversionsAndLegality(
-    MLIRContext *context, TypeConverter &typeConverter,
-    OwningRewritePatternList &patterns, ConversionTarget &target);
+    TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+    ConversionTarget &target);
 
 } // namespace scf
 } // namespace mlir
index 1921dbb..098d4fd 100644 (file)
@@ -24,7 +24,7 @@
 namespace mlir {
 namespace spirv {
 void populateSPIRVGLSLCanonicalizationPatterns(
-    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context);
+    mlir::OwningRewritePatternList &results);
 } // namespace spirv
 } // namespace mlir
 
index 1ac7db1..d7cd76b 100644 (file)
@@ -67,8 +67,7 @@ private:
 /// `func` op to the SPIR-V dialect. These patterns do not handle shader
 /// interface/ABI; they convert function parameters to be of SPIR-V allowed
 /// types.
-void populateBuiltinFuncToSPIRVPatterns(MLIRContext *context,
-                                        SPIRVTypeConverter &typeConverter,
+void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns);
 
 namespace spirv {
index 6df1299..9e4b4af 100644 (file)
@@ -28,8 +28,7 @@ namespace mlir {
 std::unique_ptr<Pass> createShapeToShapeLowering();
 
 /// Collects a set of patterns to rewrite ops within the Shape dialect.
-void populateShapeRewritePatterns(MLIRContext *context,
-                                  OwningRewritePatternList &patterns);
+void populateShapeRewritePatterns(OwningRewritePatternList &patterns);
 
 // Collects a set of patterns to replace all constraints with passing witnesses.
 // This is intended to then allow all ShapeConstraint related ops and data to
@@ -37,8 +36,7 @@ void populateShapeRewritePatterns(MLIRContext *context,
 // canonicalization and dead code elimination.
 //
 // After this pass, no cstr_ operations exist.
-void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
-                                            MLIRContext *ctx);
+void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns);
 std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
 
 /// Populates patterns for shape dialect structural type conversions and sets up
@@ -53,8 +51,8 @@ std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
 /// do for a structural type conversion is to update both of their types
 /// consistently to the new types prescribed by the TypeConverter.
 void populateShapeStructuralTypeConversionsAndLegality(
-    MLIRContext *context, TypeConverter &typeConverter,
-    OwningRewritePatternList &patterns, ConversionTarget &target);
+    TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+    ConversionTarget &target);
 
 // Bufferizes shape dialect ops.
 //
index 1a0308d..a7eb59a 100644 (file)
@@ -25,7 +25,6 @@ class TypeConverter;
 /// Add a pattern to the given pattern list to convert the operand and result
 /// types of a CallOp with the given type converter.
 void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
-                                         MLIRContext *ctx,
                                          TypeConverter &converter);
 
 /// Add a pattern to the given pattern list to rewrite branch operations to use
@@ -33,8 +32,7 @@ void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
 /// be done if the branch operation implements the BranchOpInterface. Only
 /// needed for partial conversions.
 void populateBranchOpInterfaceTypeConversionPattern(
-    OwningRewritePatternList &patterns, MLIRContext *ctx,
-    TypeConverter &converter);
+    OwningRewritePatternList &patterns, TypeConverter &converter);
 
 /// Return true if op is a BranchOpInterface op whose operands are all legal
 /// according to converter.
@@ -44,7 +42,6 @@ bool isLegalForBranchOpInterfaceTypeConversionPattern(Operation *op,
 /// Add a pattern to the given pattern list to rewrite `return` ops to use
 /// operands that have been legalized by the conversion framework.
 void populateReturnOpTypeConversionPattern(OwningRewritePatternList &patterns,
-                                           MLIRContext *ctx,
                                            TypeConverter &converter);
 
 /// For ReturnLike ops (except `return`), return True. If op is a `return` &&
index a6fdca8..1e04b22 100644 (file)
@@ -21,8 +21,7 @@ namespace mlir {
 
 class OwningRewritePatternList;
 
-void populateStdBufferizePatterns(MLIRContext *context,
-                                  BufferizeTypeConverter &typeConverter,
+void populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
                                   OwningRewritePatternList &patterns);
 
 /// Creates an instance of std bufferization pass.
@@ -42,8 +41,7 @@ std::unique_ptr<Pass> createTensorConstantBufferizePass();
 std::unique_ptr<Pass> createStdExpandOpsPass();
 
 /// Collects a set of patterns to rewrite ops within the Std dialect.
-void populateStdExpandOpsPatterns(MLIRContext *context,
-                                  OwningRewritePatternList &patterns);
+void populateStdExpandOpsPatterns(OwningRewritePatternList &patterns);
 
 //===----------------------------------------------------------------------===//
 // Registration
index 436b3fc..72539c8 100644 (file)
@@ -16,8 +16,7 @@ namespace mlir {
 
 class OwningRewritePatternList;
 
-void populateTensorBufferizePatterns(MLIRContext *context,
-                                     BufferizeTypeConverter &typeConverter,
+void populateTensorBufferizePatterns(BufferizeTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns);
 
 /// Creates an instance of `tensor` dialect bufferization pass.
index 9e486d0..7d20e64 100644 (file)
@@ -39,11 +39,11 @@ struct BitmaskEnumStorage;
 
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context);
+    OwningRewritePatternList &patterns);
 
 /// Collect a set of vector-to-vector transformation patterns.
 void populateVectorToVectorTransformationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context);
+    OwningRewritePatternList &patterns);
 
 /// Collect a set of patterns to split transfer read/write ops.
 ///
@@ -54,7 +54,7 @@ void populateVectorToVectorTransformationPatterns(
 /// of being generic canonicalization patterns. Also one can let the
 /// `ignoreFilter` to return true to fail matching for fine-grained control.
 void populateSplitVectorTransferPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context,
+    OwningRewritePatternList &patterns,
     std::function<bool(Operation *)> ignoreFilter = nullptr);
 
 /// Collect a set of leading one dimension removal patterns.
@@ -64,15 +64,14 @@ void populateSplitVectorTransferPatterns(
 /// With them, there are more chances that we can cancel out extract-insert
 /// pairs or forward write-read pairs.
 void populateCastAwayVectorLeadingOneDimPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context);
+    OwningRewritePatternList &patterns);
 
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
 /// These patterns move vector.bitcast ops to be before insert ops or after
 /// extract ops where suitable. With them, bitcast will happen on smaller
 /// vectors and there are more chances to share extract/insert ops.
-void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
-                                           MLIRContext *context);
+void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns);
 
 /// Collect a set of vector slices transformation patterns:
 ///    ExtractSlicesOpLowering, InsertSlicesOpLowering
@@ -82,15 +81,13 @@ void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
 /// use for "slices" ops), this lowering removes all tuple related
 /// operations as well (through DCE and folding). If tuple values
 /// "leak" coming in, however, some tuple related ops will remain.
-void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
-                                          MLIRContext *context);
+void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns);
 
 /// Collect a set of transfer read/write lowering patterns.
 ///
 /// These patterns lower transfer ops to simpler ops like `vector.load`,
 /// `vector.store` and `vector.broadcast`.
-void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns,
-                                            MLIRContext *context);
+void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns);
 
 /// An attribute that specifies the combining function for `vector.contract`,
 /// and `vector.reduction`.
@@ -174,7 +171,7 @@ struct VectorTransformsOptions {
 /// These transformation express higher level vector ops in terms of more
 /// elementary extraction, insertion, reduction, product, and broadcast ops.
 void populateVectorContractLoweringPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context,
+    OwningRewritePatternList &patterns,
     VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
 
 /// Returns the integer type required for subscripts in the vector dialect.
index c797f53..bc49103 100644 (file)
@@ -255,7 +255,8 @@ public:
   PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
 
   /// Returns true if the type of the held value is `T`.
-  template <typename T> bool isa() const {
+  template <typename T>
+  bool isa() const {
     assert(value && "isa<> used on a null value");
     return kind == getKindOf<T>();
   }
@@ -271,7 +272,8 @@ public:
 
   /// Cast this value to type `T`, asserts if this value is not an instance of
   /// `T`.
-  template <typename T> T cast() const {
+  template <typename T>
+  T cast() const {
     assert(isa<T>() && "expected value to be of type `T`");
     return castImpl<T>();
   }
@@ -290,7 +292,8 @@ public:
 
 private:
   /// Find the index of a given type in a range of other types.
-  template <typename...> struct index_of_t;
+  template <typename...>
+  struct index_of_t;
   template <typename T, typename... R>
   struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
   template <typename T, typename F, typename... R>
@@ -298,7 +301,8 @@ private:
       : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
 
   /// Return the kind used for the given T.
-  template <typename T> static Kind getKindOf() {
+  template <typename T>
+  static Kind getKindOf() {
     return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
                                         TypeRange, Value, ValueRange>::value);
   }
@@ -718,14 +722,19 @@ class OwningRewritePatternList {
   using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
 
 public:
-  OwningRewritePatternList() = default;
+  OwningRewritePatternList(MLIRContext *context) : context(context) {}
 
   /// Construct a OwningRewritePatternList populated with the given pattern.
-  OwningRewritePatternList(std::unique_ptr<RewritePattern> pattern) {
+  OwningRewritePatternList(MLIRContext *context,
+                           std::unique_ptr<RewritePattern> pattern)
+      : context(context) {
     nativePatterns.emplace_back(std::move(pattern));
   }
   OwningRewritePatternList(PDLPatternModule &&pattern)
-      : pdlPatterns(std::move(pattern)) {}
+      : context(pattern.getModule()->getContext()),
+        pdlPatterns(std::move(pattern)) {}
+
+  MLIRContext *getContext() const { return context; }
 
   /// Return the native patterns held in this list.
   NativePatternListT &getNativePatterns() { return nativePatterns; }
@@ -750,7 +759,7 @@ public:
             typename... ConstructorArgs,
             typename = std::enable_if_t<sizeof...(Ts) != 0>>
   OwningRewritePatternList &insert(ConstructorArg &&arg,
-                                   ConstructorArgs &&...args) {
+                                   ConstructorArgs &&... args) {
     // The following expands a call to emplace_back for each of the pattern
     // types 'Ts'. This magic is necessary due to a limitation in the places
     // that a parameter pack can be expanded in c++11.
@@ -761,7 +770,8 @@ public:
 
   /// Add an instance of each of the pattern types 'Ts'. Return a reference to
   /// `this` for chaining insertions.
-  template <typename... Ts> OwningRewritePatternList &insert() {
+  template <typename... Ts>
+  OwningRewritePatternList &insert() {
     (void)std::initializer_list<int>{0, (insertImpl<Ts>(), 0)...};
     return *this;
   }
@@ -785,16 +795,17 @@ private:
   /// chaining insertions.
   template <typename T, typename... Args>
   std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
-  insertImpl(Args &&...args) {
+  insertImpl(Args &&... args) {
     nativePatterns.emplace_back(
         std::make_unique<T>(std::forward<Args>(args)...));
   }
   template <typename T, typename... Args>
   std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
-  insertImpl(Args &&...args) {
+  insertImpl(Args &&... args) {
     pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
   }
 
+  MLIRContext *const context;
   NativePatternListT nativePatterns;
   PDLPatternModule pdlPatterns;
 };
index 29e16c2..9f2c0e3 100644 (file)
@@ -56,8 +56,7 @@ void populateBufferizeMaterializationLegality(ConversionTarget &target);
 ///
 /// In particular, these are the tensor_load/buffer_cast ops.
 void populateEliminateBufferizeMaterializationsPatterns(
-    MLIRContext *context, BufferizeTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns);
+    BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns);
 
 } // end namespace mlir
 
index 5cc5d8a..b93fffa 100644 (file)
@@ -425,20 +425,18 @@ private:
 /// FunctionLike ops which use FunctionType to represent their type.
 void populateFunctionLikeTypeConversionPattern(
     StringRef functionLikeOpName, OwningRewritePatternList &patterns,
-    MLIRContext *ctx, TypeConverter &converter);
+    TypeConverter &converter);
 
 template <typename FuncOpT>
 void populateFunctionLikeTypeConversionPattern(
-    OwningRewritePatternList &patterns, MLIRContext *ctx,
-    TypeConverter &converter) {
+    OwningRewritePatternList &patterns, TypeConverter &converter) {
   populateFunctionLikeTypeConversionPattern(FuncOpT::getOperationName(),
-                                            patterns, ctx, converter);
+                                            patterns, converter);
 }
 
 /// Add a pattern to the given pattern list to convert the signature of a FuncOp
 /// with the given type converter.
 void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
-                                         MLIRContext *ctx,
                                          TypeConverter &converter);
 
 //===----------------------------------------------------------------------===//
@@ -604,22 +602,26 @@ public:
 
   /// Register a legality action for the given operation.
   void setOpAction(OperationName op, LegalizationAction action);
-  template <typename OpT> void setOpAction(LegalizationAction action) {
+  template <typename OpT>
+  void setOpAction(LegalizationAction action) {
     setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
   }
 
   /// Register the given operations as legal.
-  template <typename OpT> void addLegalOp() {
+  template <typename OpT>
+  void addLegalOp() {
     setOpAction<OpT>(LegalizationAction::Legal);
   }
-  template <typename OpT, typename OpT2, typename... OpTs> void addLegalOp() {
+  template <typename OpT, typename OpT2, typename... OpTs>
+  void addLegalOp() {
     addLegalOp<OpT>();
     addLegalOp<OpT2, OpTs...>();
   }
 
   /// Register the given operation as dynamically legal, i.e. requiring custom
   /// handling by the target via 'isDynamicallyLegal'.
-  template <typename OpT> void addDynamicallyLegalOp() {
+  template <typename OpT>
+  void addDynamicallyLegalOp() {
     setOpAction<OpT>(LegalizationAction::Dynamic);
   }
   template <typename OpT, typename OpT2, typename... OpTs>
@@ -651,10 +653,12 @@ public:
 
   /// Register the given operation as illegal, i.e. this operation is known to
   /// not be supported by this target.
-  template <typename OpT> void addIllegalOp() {
+  template <typename OpT>
+  void addIllegalOp() {
     setOpAction<OpT>(LegalizationAction::Illegal);
   }
-  template <typename OpT, typename OpT2, typename... OpTs> void addIllegalOp() {
+  template <typename OpT, typename OpT2, typename... OpTs>
+  void addIllegalOp() {
     addIllegalOp<OpT>();
     addIllegalOp<OpT2, OpTs...>();
   }
@@ -692,7 +696,8 @@ public:
     SmallVector<StringRef, 2> dialectNames({name, names...});
     setDialectAction(dialectNames, LegalizationAction::Legal);
   }
-  template <typename... Args> void addLegalDialect() {
+  template <typename... Args>
+  void addLegalDialect() {
     SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
     setDialectAction(dialectNames, LegalizationAction::Legal);
   }
@@ -736,7 +741,8 @@ public:
     SmallVector<StringRef, 2> dialectNames({name, names...});
     setDialectAction(dialectNames, LegalizationAction::Illegal);
   }
-  template <typename... Args> void addIllegalDialect() {
+  template <typename... Args>
+  void addIllegalDialect() {
     SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
     setDialectAction(dialectNames, LegalizationAction::Illegal);
   }
index de2e059..4c741d4 100644 (file)
@@ -747,7 +747,7 @@ public:
 } // end namespace
 
 void mlir::populateAffineToStdConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+    OwningRewritePatternList &patterns) {
   // clang-format off
   patterns.insert<
       AffineApplyLowering,
@@ -761,25 +761,25 @@ void mlir::populateAffineToStdConversionPatterns(
       AffineStoreLowering,
       AffineForLowering,
       AffineIfLowering,
-      AffineYieldOpLowering>(ctx);
+      AffineYieldOpLowering>(patterns.getContext());
   // clang-format on
 }
 
 void mlir::populateAffineToVectorConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+    OwningRewritePatternList &patterns) {
   // clang-format off
   patterns.insert<
       AffineVectorLoadLowering,
-      AffineVectorStoreLowering>(ctx);
+      AffineVectorStoreLowering>(patterns.getContext());
   // clang-format on
 }
 
 namespace {
 class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
   void runOnOperation() override {
-    OwningRewritePatternList patterns;
-    populateAffineToStdConversionPatterns(patterns, &getContext());
-    populateAffineToVectorConversionPatterns(patterns, &getContext());
+    OwningRewritePatternList patterns(&getContext());
+    populateAffineToStdConversionPatterns(patterns);
+    populateAffineToVectorConversionPatterns(patterns);
     ConversionTarget target(getContext());
     target.addLegalDialect<memref::MemRefDialect, scf::SCFDialect,
                            StandardOpsDialect, VectorDialect>();
index 3fe1c7f..23a826a 100644 (file)
@@ -875,7 +875,7 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
 
   // Convert async dialect types and operations to LLVM dialect.
   AsyncRuntimeTypeConverter converter;
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(ctx);
 
   // We use conversion to LLVM type to lower async.runtime load and store
   // operations.
@@ -883,8 +883,8 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
 
   // Convert async types in function signatures and function calls.
-  populateFuncOpTypeConversionPattern(patterns, ctx, converter);
-  populateCallOpTypeConversionPattern(patterns, ctx, converter);
+  populateFuncOpTypeConversionPattern(patterns, converter);
+  populateCallOpTypeConversionPattern(patterns, converter);
 
   // Convert return operations inside async.execute regions.
   patterns.insert<ReturnOpOpConversion>(converter, ctx);
@@ -985,8 +985,8 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
 }
 
 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
-    MLIRContext *context, TypeConverter &typeConverter,
-    OwningRewritePatternList &patterns, ConversionTarget &target) {
+    TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+    ConversionTarget &target) {
   typeConverter.addConversion([&](TokenType type) { return type; });
   typeConverter.addConversion([&](ValueType type) {
     return ValueType::get(typeConverter.convertType(type.getValueType()));
@@ -994,7 +994,7 @@ void mlir::populateAsyncStructuralTypeConversionsAndLegality(
 
   patterns
       .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
-          typeConverter, context);
+          typeConverter, patterns.getContext());
 
   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
       [&](Operation *op) { return typeConverter.isLegal(op); });
index 00ab637..71b2fc0 100644 (file)
@@ -284,7 +284,7 @@ void ConvertComplexToLLVMPass::runOnOperation() {
   auto module = getOperation();
 
   // Convert to the LLVM IR dialect using the converter defined above.
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   LLVMTypeConverter converter(&getContext());
   populateComplexToLLVMConversionPatterns(converter, patterns);
 
index d490c52..dde968c 100644 (file)
@@ -308,13 +308,13 @@ private:
 
 void GpuToLLVMConversionPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext());
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   LLVMConversionTarget target(getContext());
 
   populateVectorToLLVMConversionPatterns(converter, patterns);
   populateStdToLLVMConversionPatterns(converter, patterns);
-  populateAsyncStructuralTypeConversionsAndLegality(&getContext(), converter,
-                                                    patterns, target);
+  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
+                                                    target);
 
   converter.addConversion(
       [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
index 9e16712..3a6548b 100644 (file)
@@ -125,12 +125,13 @@ struct LowerGpuOpsToNVVMOpsPass
       return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
     });
 
-    OwningRewritePatternList patterns, llvmPatterns;
+    OwningRewritePatternList patterns(m.getContext());
+    OwningRewritePatternList llvmPatterns(m.getContext());
 
     // Apply in-dialect lowering first. In-dialect lowering will replace ops
     // which need to be lowered further, which is not supported by a single
     // conversion pass.
-    populateGpuRewritePatterns(m.getContext(), patterns);
+    populateGpuRewritePatterns(patterns);
     (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
 
     populateStdToLLVMConversionPatterns(converter, llvmPatterns);
index d61c047..21ae015 100644 (file)
@@ -60,9 +60,10 @@ struct LowerGpuOpsToROCDLOpsPass
                                   /*useAlignedAlloc =*/false};
     LLVMTypeConverter converter(m.getContext(), options);
 
-    OwningRewritePatternList patterns, llvmPatterns;
+    OwningRewritePatternList patterns(m.getContext());
+    OwningRewritePatternList llvmPatterns(m.getContext());
 
-    populateGpuRewritePatterns(m.getContext(), patterns);
+    populateGpuRewritePatterns(patterns);
     (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
 
     populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
index 1e0a766..2bb1543 100644 (file)
@@ -329,9 +329,9 @@ namespace {
 #include "GPUToSPIRV.cpp.inc"
 }
 
-void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
-                                      SPIRVTypeConverter &typeConverter,
+void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                       OwningRewritePatternList &patterns) {
+  auto *context = patterns.getContext();
   populateWithGenerated(context, patterns);
   patterns.insert<
       GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
index 8edb42e..a8644c8 100644 (file)
@@ -57,9 +57,9 @@ void GPUToSPIRVPass::runOnOperation() {
       spirv::SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
-  OwningRewritePatternList patterns;
-  populateGPUToSPIRVPatterns(context, typeConverter, patterns);
-  populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+  OwningRewritePatternList patterns(context);
+  populateGPUToSPIRVPatterns(typeConverter, patterns);
+  populateStandardToSPIRVPatterns(typeConverter, patterns);
 
   if (failed(applyFullConversion(kernelModules, *target, std::move(patterns))))
     return signalPassFailure();
index 5c0eb5e..e49d6b8 100644 (file)
@@ -221,7 +221,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
   auto module = getOperation();
 
   // Convert to the LLVM IR dialect using the converter defined above.
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   LLVMTypeConverter converter(&getContext());
   populateLinalgToLLVMConversionPatterns(converter, patterns);
 
index 0db760b..052dea4 100644 (file)
@@ -203,8 +203,8 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
 // Pattern population
 //===----------------------------------------------------------------------===//
 
-void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
-                                         SPIRVTypeConverter &typeConverter,
+void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
-  patterns.insert<SingleWorkgroupReduction>(typeConverter, context);
+  patterns.insert<SingleWorkgroupReduction>(typeConverter,
+                                            patterns.getContext());
 }
index ddcc97d..d9df551 100644 (file)
@@ -30,9 +30,9 @@ void LinalgToSPIRVPass::runOnOperation() {
       spirv::SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
-  OwningRewritePatternList patterns;
-  populateLinalgToSPIRVPatterns(context, typeConverter, patterns);
-  populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+  OwningRewritePatternList patterns(context);
+  populateLinalgToSPIRVPatterns(typeConverter, patterns);
+  populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
 
   // Allow builtin ops.
   target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
index bf947a4..ce4fe8a 100644 (file)
@@ -192,14 +192,14 @@ mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
 
 /// Populate the given list with patterns that convert from Linalg to Standard.
 void mlir::linalg::populateLinalgToStandardConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+    OwningRewritePatternList &patterns) {
   // TODO: ConvOp conversion needs to export a descriptor with relevant
   // attribute values such as kernel striding and dilation.
   // clang-format off
   patterns.insert<
       CopyOpToLibraryCallRewrite,
       CopyTransposeRewrite,
-      IndexedGenericOpToLibraryCallRewrite>(ctx);
+      IndexedGenericOpToLibraryCallRewrite>(patterns.getContext());
   patterns.insert<LinalgOpToLibraryCallRewrite>();
   // clang-format on
 }
@@ -218,8 +218,8 @@ void ConvertLinalgToStandardPass::runOnOperation() {
                          StandardOpsDialect>();
   target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
   target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
-  OwningRewritePatternList patterns;
-  populateLinalgToStandardConversionPatterns(patterns, &getContext());
+  OwningRewritePatternList patterns(&getContext());
+  populateLinalgToStandardConversionPatterns(patterns);
   if (failed(applyFullConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
index 7bc5100..833d51f 100644 (file)
@@ -58,7 +58,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
   auto module = getOperation();
 
   // Convert to OpenMP operations with LLVM IR dialect
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   LLVMTypeConverter converter(&getContext());
   populateStdToLLVMConversionPatterns(converter, patterns);
   populateOpenMPToLLVMConversionPatterns(converter, patterns);
index 9f5e4ab..b9602dd 100644 (file)
@@ -642,9 +642,9 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
   return success();
 }
 
-void mlir::populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
-                                             MLIRContext *ctx) {
-  patterns.insert<ParallelToGpuLaunchLowering>(ctx);
+void mlir::populateParallelLoopToGPUPatterns(
+    OwningRewritePatternList &patterns) {
+  patterns.insert<ParallelToGpuLaunchLowering>(patterns.getContext());
 }
 
 void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) {
index 15075b5..a6ab449 100644 (file)
@@ -47,8 +47,8 @@ struct ForLoopMapper : public ConvertAffineForToGPUBase<ForLoopMapper> {
 struct ParallelLoopToGpuPass
     : public ConvertParallelLoopToGpuBase<ParallelLoopToGpuPass> {
   void runOnOperation() override {
-    OwningRewritePatternList patterns;
-    populateParallelLoopToGPUPatterns(patterns, &getContext());
+    OwningRewritePatternList patterns(&getContext());
+    populateParallelLoopToGPUPatterns(patterns);
     ConversionTarget target(getContext());
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
     configureParallelLoopToGPULegality(target);
index 3adb02a..46e67e5 100644 (file)
@@ -90,7 +90,7 @@ static LogicalResult applyPatterns(FuncOp func) {
       [](scf::YieldOp op) { return !isa<scf::ParallelOp>(op->getParentOp()); });
   target.addLegalDialect<omp::OpenMPDialect>();
 
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(func.getContext());
   patterns.insert<ParallelOpLowering>(func.getContext());
   FrozenRewritePatternList frozen(std::move(patterns));
   return applyPartialConversion(func, target, frozen);
index 19837fe..344af68 100644 (file)
@@ -319,10 +319,9 @@ LogicalResult TerminatorOpConversion::matchAndRewrite(
 // Hooks
 //===----------------------------------------------------------------------===//
 
-void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
-                                      SPIRVTypeConverter &typeConverter,
+void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                       ScfToSPIRVContext &scfToSPIRVContext,
                                       OwningRewritePatternList &patterns) {
   patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
-      context, typeConverter, scfToSPIRVContext.getImpl());
+      patterns.getContext(), typeConverter, scfToSPIRVContext.getImpl());
 }
index b0d8799..024ff2c 100644 (file)
@@ -37,10 +37,10 @@ void SCFToSPIRVPass::runOnOperation() {
 
   SPIRVTypeConverter typeConverter(targetAttr);
   ScfToSPIRVContext scfContext;
-  OwningRewritePatternList patterns;
-  populateSCFToSPIRVPatterns(context, typeConverter, scfContext, patterns);
-  populateStandardToSPIRVPatterns(context, typeConverter, patterns);
-  populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+  OwningRewritePatternList patterns(context);
+  populateSCFToSPIRVPatterns(typeConverter, scfContext, patterns);
+  populateStandardToSPIRVPatterns(typeConverter, patterns);
+  populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
 
   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
     return signalPassFailure();
index b8f3140..5250d53 100644 (file)
@@ -569,15 +569,15 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
 }
 
 void mlir::populateLoopToStdConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
-      ctx);
-  patterns.insert<DoWhileLowering>(ctx, /*benefit=*/2);
+      patterns.getContext());
+  patterns.insert<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
 
 void SCFToStandardPass::runOnOperation() {
-  OwningRewritePatternList patterns;
-  populateLoopToStdConversionPatterns(patterns, &getContext());
+  OwningRewritePatternList patterns(&getContext());
+  populateLoopToStdConversionPatterns(patterns);
   // Configure conversion to lower out scf.for, scf.if, scf.parallel and
   // scf.while. Anything else is fine.
   ConversionTarget target(getContext());
index d152a73..7f3752f 100644 (file)
@@ -278,7 +278,7 @@ public:
         /*emitCWrappers=*/true,
         /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout};
     auto *context = module.getContext();
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
     LLVMTypeConverter typeConverter(context, options);
     populateStdToLLVMConversionPatterns(typeConverter, patterns);
     patterns.insert<GPULaunchLowering>(typeConverter);
index 3a139b4..6f6d56f 100644 (file)
@@ -1385,8 +1385,7 @@ void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
 }
 
 void mlir::populateSPIRVToLLVMConversionPatterns(
-    MLIRContext *context, LLVMTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns) {
+    LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
   patterns.insert<
       // Arithmetic ops
       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
@@ -1496,20 +1495,18 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
 
       // Return ops
-      ReturnPattern, ReturnValuePattern>(context, typeConverter);
+      ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
 }
 
 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
-    MLIRContext *context, LLVMTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns) {
-  patterns.insert<FuncConversionPattern>(context, typeConverter);
+    LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+  patterns.insert<FuncConversionPattern>(patterns.getContext(), typeConverter);
 }
 
 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
-    MLIRContext *context, LLVMTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns) {
+    LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
   patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
-      context, typeConverter);
+      patterns.getContext(), typeConverter);
 }
 
 //===----------------------------------------------------------------------===//
index 2a4113f..a807b31 100644 (file)
@@ -36,15 +36,15 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
   // Encode global variable's descriptor set and binding if they exist.
   encodeBindAttribute(module);
 
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(context);
 
   populateSPIRVToLLVMTypeConversion(converter);
 
-  populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
-  populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
-  populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
+  populateSPIRVToLLVMModuleConversionPatterns(converter, patterns);
+  populateSPIRVToLLVMConversionPatterns(converter, patterns);
+  populateSPIRVToLLVMFunctionConversionPatterns(converter, patterns);
 
-  ConversionTarget target(getContext());
+  ConversionTarget target(*context);
   target.addIllegalDialect<spirv::SPIRVDialect>();
   target.addLegalDialect<LLVM::LLVMDialect>();
 
index af97605..28697ba 100644 (file)
@@ -37,10 +37,10 @@ public:
 } // namespace
 
 void mlir::populateConvertShapeConstraintsConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<CstrBroadcastableToRequire>(ctx);
-  patterns.insert<CstrEqToRequire>(ctx);
-  patterns.insert<ConvertCstrRequireOp>(ctx);
+    OwningRewritePatternList &patterns) {
+  patterns.insert<CstrBroadcastableToRequire>(patterns.getContext());
+  patterns.insert<CstrEqToRequire>(patterns.getContext());
+  patterns.insert<ConvertCstrRequireOp>(patterns.getContext());
 }
 
 namespace {
@@ -54,8 +54,8 @@ class ConvertShapeConstraints
     auto func = getOperation();
     auto *context = &getContext();
 
-    OwningRewritePatternList patterns;
-    populateConvertShapeConstraintsConversionPatterns(patterns, context);
+    OwningRewritePatternList patterns(context);
+    populateConvertShapeConstraintsConversionPatterns(patterns);
 
     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
       return signalPassFailure();
index 2c06702..048e352 100644 (file)
@@ -678,8 +678,8 @@ void ConvertShapeToStandardPass::runOnOperation() {
   target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
 
   // Setup conversion patterns.
-  OwningRewritePatternList patterns;
-  populateShapeToStandardConversionPatterns(patterns, &ctx);
+  OwningRewritePatternList patterns(&ctx);
+  populateShapeToStandardConversionPatterns(patterns);
 
   // Apply conversion.
   auto module = getOperation();
@@ -688,9 +688,9 @@ void ConvertShapeToStandardPass::runOnOperation() {
 }
 
 void mlir::populateShapeToStandardConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+    OwningRewritePatternList &patterns) {
   // clang-format off
-  populateWithGenerated(ctx, patterns);
+  populateWithGenerated(patterns.getContext(), patterns);
   patterns.insert<
       AnyOpConversion,
       BinaryOpConversion<AddOp, AddIOp>,
@@ -705,7 +705,7 @@ void mlir::populateShapeToStandardConversionPatterns(
       ShapeEqOpConverter,
       ShapeOfOpConversion,
       SplitAtOpConversion,
-      ToExtentTensorOpConversion>(ctx);
+      ToExtentTensorOpConversion>(patterns.getContext());
   // clang-format on
 }
 
index 2490f35..63036c4 100644 (file)
@@ -4079,7 +4079,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
                                   llvm::DataLayout(this->dataLayout)};
     LLVMTypeConverter typeConverter(&getContext(), options);
 
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
     populateStdToLLVMConversionPatterns(typeConverter, patterns);
 
     LLVMConversionTarget target(getContext());
index 00bf6c0..57f1b17 100644 (file)
@@ -193,11 +193,12 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
 //===----------------------------------------------------------------------===//
 
 void mlir::populateStdLegalizationPatternsForSPIRVLowering(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<LoadOpOfSubViewFolder<memref::LoadOp>,
                   LoadOpOfSubViewFolder<vector::TransferReadOp>,
                   StoreOpOfSubViewFolder<memref::StoreOp>,
-                  StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
+                  StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -212,9 +213,8 @@ struct SPIRVLegalization final
 } // namespace
 
 void SPIRVLegalization::runOnOperation() {
-  OwningRewritePatternList patterns;
-  auto *context = &getContext();
-  populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
+  OwningRewritePatternList patterns(&getContext());
+  populateStdLegalizationPatternsForSPIRVLowering(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                      std::move(patterns));
 }
index 025029a..8552db4 100644 (file)
@@ -1224,9 +1224,10 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 namespace mlir {
-void populateStandardToSPIRVPatterns(MLIRContext *context,
-                                     SPIRVTypeConverter &typeConverter,
+void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns) {
+  MLIRContext *context = patterns.getContext();
+
   patterns.insert<
       // Math dialect operations.
       // TODO: Move to separate pass.
@@ -1293,11 +1294,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                           /*benefit=*/2);
 }
 
-void populateTensorToSPIRVPatterns(MLIRContext *context,
-                                   SPIRVTypeConverter &typeConverter,
+void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    int64_t byteCountThreshold,
                                    OwningRewritePatternList &patterns) {
-  patterns.insert<TensorExtractPattern>(typeConverter, context,
+  patterns.insert<TensorExtractPattern>(typeConverter, patterns.getContext(),
                                         byteCountThreshold);
 }
 
index ce8419b..a1c6f98 100644 (file)
@@ -35,11 +35,11 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
       spirv::SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
-  OwningRewritePatternList patterns;
-  populateStandardToSPIRVPatterns(context, typeConverter, patterns);
-  populateTensorToSPIRVPatterns(context, typeConverter,
+  OwningRewritePatternList patterns(context);
+  populateStandardToSPIRVPatterns(typeConverter, patterns);
+  populateTensorToSPIRVPatterns(typeConverter,
                                 /*byteCountThreshold=*/64, patterns);
-  populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+  populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
 
   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
     return signalPassFailure();
index fc83116..698fb5a 100644 (file)
@@ -989,7 +989,7 @@ public:
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList *patterns) {
+    OwningRewritePatternList *patterns) {
   patterns->insert<
       PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
       PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
@@ -1014,5 +1014,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
       ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
-      RescaleConverter, ReverseConverter, TransposeConverter>(context);
+      RescaleConverter, ReverseConverter, TransposeConverter>(
+      patterns->getContext());
 }
index e0f1369..7d6815e 100644 (file)
@@ -37,7 +37,7 @@ public:
   }
 
   void runOnFunction() override {
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
                            StandardOpsDialect>();
@@ -52,8 +52,7 @@ public:
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
     FuncOp func = getFunction();
-    mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
-        func.getContext(), &patterns);
+    mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(&patterns);
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();
   }
index 55ed64b..4fb06d1 100644 (file)
@@ -103,7 +103,7 @@ public:
 } // namespace
 
 void mlir::tosa::populateTosaToSCFConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<IfOpConverter>(context);
-  patterns->insert<WhileOpConverter>(context);
+    OwningRewritePatternList *patterns) {
+  patterns->insert<IfOpConverter>(patterns->getContext());
+  patterns->insert<WhileOpConverter>(patterns->getContext());
 }
index f403a46..9b562fa 100644 (file)
@@ -29,15 +29,14 @@ namespace {
 struct TosaToSCF : public TosaToSCFBase<TosaToSCF> {
 public:
   void runOnOperation() override {
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
     target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
     auto *op = getOperation();
-    mlir::tosa::populateTosaToSCFConversionPatterns(op->getContext(),
-                                                    &patterns);
+    mlir::tosa::populateTosaToSCFConversionPatterns(&patterns);
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();
   }
index 95f5c51..8db7868 100644 (file)
@@ -154,12 +154,12 @@ public:
 } // namespace
 
 void mlir::tosa::populateTosaToStandardConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList *patterns) {
+    OwningRewritePatternList *patterns) {
   patterns->insert<ApplyScaleOpConverter, ConstOpConverter, SliceOpConverter>(
-      context);
+      patterns->getContext());
 }
 
 void mlir::tosa::populateTosaRescaleToStandardConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<ApplyScaleOpConverter>(context);
+    OwningRewritePatternList *patterns) {
+  patterns->insert<ApplyScaleOpConverter>(patterns->getContext());
 }
index 14c800e..de8768b 100644 (file)
@@ -29,17 +29,16 @@ namespace {
 struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
 public:
   void runOnOperation() override {
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
     ConversionTarget target(getContext());
     target.addIllegalOp<tosa::ConstOp>();
     target.addIllegalOp<tosa::SliceOp>();
     target.addIllegalOp<tosa::ApplyScaleOp>();
     target.addLegalDialect<StandardOpsDialect>();
 
-    auto *op = getOperation();
-    mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(),
-                                                         &patterns);
-    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    mlir::tosa::populateTosaToStandardConversionPatterns(&patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
       signalPassFailure();
   }
 };
index 8565774..b8c43c8 100644 (file)
@@ -61,16 +61,16 @@ void LowerVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and
   // all contraction operations. Also applies folding and DCE.
   {
-    OwningRewritePatternList patterns;
-    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
-    populateVectorSlicesLoweringPatterns(patterns, &getContext());
-    populateVectorContractLoweringPatterns(patterns, &getContext());
+    OwningRewritePatternList patterns(&getContext());
+    populateVectorToVectorCanonicalizationPatterns(patterns);
+    populateVectorSlicesLoweringPatterns(patterns);
+    populateVectorContractLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
   // Convert to the LLVM IR dialect.
   LLVMTypeConverter converter(&getContext());
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
       converter, patterns, reassociateFPReductions, enableIndexOptimizations);
@@ -98,7 +98,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
       return false;
     };
     // Remove any ArmSVE-specific types from function signatures and results.
-    populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
+    populateFuncOpTypeConversionPattern(patterns, converter);
     target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
       return !hasScalableVectorType(op.getType().getInputs()) &&
              !hasScalableVectorType(op.getType().getResults());
index 42c0726..4b097c5 100644 (file)
@@ -158,7 +158,7 @@ struct LowerVectorToROCDLPass
 
 void LowerVectorToROCDLPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext());
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
 
   populateVectorToROCDLConversionPatterns(converter, patterns);
   populateStdToLLVMConversionPatterns(converter, patterns);
index dce5b64..3c7c457 100644 (file)
@@ -694,11 +694,11 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
 }
 
 void populateVectorToSCFConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context,
+    OwningRewritePatternList &patterns,
     const VectorTransferToSCFOptions &options) {
   patterns.insert<VectorTransferRewriter<vector::TransferReadOp>,
-                  VectorTransferRewriter<vector::TransferWriteOp>>(options,
-                                                                   context);
+                  VectorTransferRewriter<vector::TransferWriteOp>>(
+      options, patterns.getContext());
 }
 
 } // namespace mlir
@@ -713,10 +713,9 @@ struct ConvertVectorToSCFPass
   }
 
   void runOnFunction() override {
-    OwningRewritePatternList patterns;
-    auto *context = getFunction().getContext();
+    OwningRewritePatternList patterns(getFunction().getContext());
     populateVectorToSCFConversionPatterns(
-        patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll));
+        patterns, VectorTransferToSCFOptions().setUnroll(fullUnroll));
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
index 8d4fcba..2d8ffc0 100644 (file)
@@ -241,12 +241,12 @@ struct VectorInsertStridedSliceOpConvert final
 
 } // namespace
 
-void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
-                                         SPIRVTypeConverter &typeConverter,
+void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
   patterns.insert<VectorBitcastConvert, VectorBroadcastConvert,
                   VectorExtractElementOpConvert, VectorExtractOpConvert,
                   VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
                   VectorInsertElementOpConvert, VectorInsertOpConvert,
-                  VectorInsertStridedSliceOpConvert>(typeConverter, context);
+                  VectorInsertStridedSliceOpConvert>(typeConverter,
+                                                     patterns.getContext());
 }
index 9a4d09f..b3c6384 100644 (file)
@@ -37,8 +37,8 @@ void LowerVectorToSPIRVPass::runOnOperation() {
       spirv::SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
-  OwningRewritePatternList patterns;
-  populateVectorToSPIRVPatterns(context, typeConverter, patterns);
+  OwningRewritePatternList patterns(context);
+  populateVectorToSPIRVPatterns(typeConverter, patterns);
 
   target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
   target->addLegalOp<FuncOp>();
index e3834ea..62cad1f 100644 (file)
@@ -227,7 +227,7 @@ void AffineDataCopyGeneration::runOnFunction() {
   // Promoting single iteration loops could lead to simplification of
   // contained load's/store's, and the latter could anyway also be
   // canonicalized.
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
   AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
   FrozenRewritePatternList frozenPatterns(std::move(patterns));
index 918fec4..512ecd6 100644 (file)
@@ -79,7 +79,7 @@ mlir::createSimplifyAffineStructuresPass() {
 void SimplifyAffineStructures::runOnFunction() {
   auto func = getFunction();
   simplifiedAttributes.clear();
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(func.getContext());
   AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
   AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
   AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
index acd854d..12d3a73 100644 (file)
@@ -188,7 +188,7 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
   // effective (no unused operands). Since the pattern rewriter's folding is
   // entangled with application of patterns, we may fold/end up erasing the op,
   // in which case we return with `folded` being set.
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(ifOp.getContext());
   AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
   bool erased;
   FrozenRewritePatternList frozenPatterns(std::move(patterns));
index f4f6e0b..cb124e3 100644 (file)
@@ -270,7 +270,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
 void AsyncParallelForPass::runOnFunction() {
   MLIRContext *ctx = &getContext();
 
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(ctx);
   patterns.insert<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
 
   if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
index a17da42..99cc0b0 100644 (file)
@@ -485,7 +485,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
 
   // Lower async operations to async.runtime operations.
   MLIRContext *ctx = module->getContext();
-  OwningRewritePatternList asyncPatterns;
+  OwningRewritePatternList asyncPatterns(ctx);
 
   // Async lowering does not use type converter because it must preserve all
   // types for async.runtime operations.
index 8e9ec0b..3e4189d 100644 (file)
@@ -401,7 +401,6 @@ struct GpuAllReduceConversion : public RewritePattern {
 };
 } // namespace
 
-void mlir::populateGpuAllReducePatterns(MLIRContext *context,
-                                        OwningRewritePatternList &patterns) {
-  patterns.insert<GpuAllReduceConversion>(context);
+void mlir::populateGpuAllReducePatterns(OwningRewritePatternList &patterns) {
+  patterns.insert<GpuAllReduceConversion>(patterns.getContext());
 }
index 419226b..df195af 100644 (file)
@@ -323,8 +323,8 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
     target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
 
-    OwningRewritePatternList patterns;
-    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
+    OwningRewritePatternList patterns(&context);
+    populateLinalgBufferizePatterns(typeConverter, patterns);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -337,8 +337,7 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
 }
 
 void mlir::linalg::populateLinalgBufferizePatterns(
-    MLIRContext *context, BufferizeTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns) {
+    BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
   patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
   // TODO: Drop this once tensor constants work in standard.
   // clang-format off
@@ -347,6 +346,6 @@ void mlir::linalg::populateLinalgBufferizePatterns(
       BufferizeInitTensorOp,
       SubTensorOpConverter,
       SubTensorInsertOpConverter
-    >(typeConverter, context);
+    >(typeConverter, patterns.getContext());
   // clang-format on
 }
index cd7b481..a7e1332 100644 (file)
@@ -76,7 +76,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
 
   // Programmatic splitting of slow/fast path vector transfers.
   if (lateCodegenStrategyOptions.enableVectorTransferPartialRewrite) {
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
     patterns.insert<vector::VectorTransferFullPartialRewriter>(
         context, vectorTransformsOptions);
     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
@@ -84,7 +84,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
 
   // Programmatic controlled lowering of vector.contract only.
   if (lateCodegenStrategyOptions.enableVectorContractLowering) {
-    OwningRewritePatternList vectorContractLoweringPatterns;
+    OwningRewritePatternList vectorContractLoweringPatterns(context);
     vectorContractLoweringPatterns
         .insert<ContractionOpToOuterProductOpLowering,
                 ContractionOpToMatmulOpLowering, ContractionOpLowering>(
@@ -95,8 +95,8 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
 
   // Programmatic controlled lowering of vector.transfer only.
   if (lateCodegenStrategyOptions.enableVectorToSCFConversion) {
-    OwningRewritePatternList vectorToLoopsPatterns;
-    populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
+    OwningRewritePatternList vectorToLoopsPatterns(context);
+    populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
                                           vectorToSCFOptions);
     (void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
   }
index 2d34468..cc95218 100644 (file)
@@ -163,7 +163,7 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
   void runOnFunction() override {
     auto *context = &getContext();
     DetensorizeTypeConverter typeConverter;
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
     ConversionTarget target(*context);
 
     target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
@@ -199,13 +199,12 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
                                                      context, typeConverter);
     // Since non-entry block arguments get detensorized, we also need to update
     // the control flow inside the function to reflect the correct types.
-    populateBranchOpInterfaceTypeConversionPattern(patterns, context,
-                                                   typeConverter);
+    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
 
     if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
 
-    OwningRewritePatternList canonPatterns;
+    OwningRewritePatternList canonPatterns(context);
     canonPatterns.insert<ExtractFromReshapeFromElements>(context);
     if (failed(applyPatternsAndFoldGreedily(getFunction(),
                                             std::move(canonPatterns))))
index c7b7640..a8db840 100644 (file)
@@ -490,14 +490,15 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
 /// Patterns that are used to canonicalize the use of unit-extent dims for
 /// broadcasting.
 void mlir::populateLinalgFoldUnitExtentDimsPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
+    OwningRewritePatternList &patterns) {
+  auto *context = patterns.getContext();
   patterns
       .insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
               ReplaceUnitExtentTensors<GenericOp>,
               ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
   patterns.insert<FoldReshapeOpWithUnitExtent>(context);
-  populateFoldUnitDimsReshapeOpsByLinearizationPatterns(context, patterns);
+  populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
 }
 
 namespace {
@@ -505,14 +506,14 @@ namespace {
 struct LinalgFoldUnitExtentDimsPass
     : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
   void runOnFunction() override {
-    OwningRewritePatternList patterns;
     FuncOp funcOp = getFunction();
     MLIRContext *context = funcOp.getContext();
+    OwningRewritePatternList patterns(context);
     if (foldOneTripLoopsOnly)
       patterns.insert<FoldUnitDimLoops<GenericOp>,
                       FoldUnitDimLoops<IndexedGenericOp>>(context);
     else
-      populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
+      populateLinalgFoldUnitExtentDimsPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
   }
 };
index 1d50e06..48677df 100644 (file)
@@ -116,7 +116,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
 } // namespace
 
 void mlir::populateElementwiseToLinalgConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<ConvertAnyElementwiseMappableOpOnRankedTensors>();
 }
 
@@ -128,9 +128,9 @@ class ConvertElementwiseToLinalgPass
     auto func = getOperation();
     auto *context = &getContext();
     ConversionTarget target(*context);
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
 
-    populateElementwiseToLinalgConversionPatterns(patterns, context);
+    populateElementwiseToLinalgConversionPatterns(patterns);
     target.markUnknownOpDynamicallyLegal([](Operation *op) {
       return !isElementwiseMappableOpOnRankedTensors(op);
     });
index ad7ad11..a61102d 100644 (file)
@@ -1112,9 +1112,9 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
 struct FusionOfTensorOpsPass
     : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
   void runOnOperation() override {
-    OwningRewritePatternList patterns;
     Operation *op = getOperation();
-    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
+    OwningRewritePatternList patterns(op->getContext());
+    populateLinalgTensorOpsFusionPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1125,9 +1125,9 @@ struct FoldReshapeOpsByLinearizationPass
     : public LinalgFoldReshapeOpsByLinearizationBase<
           FoldReshapeOpsByLinearizationPass> {
   void runOnOperation() override {
-    OwningRewritePatternList patterns;
     Operation *op = getOperation();
-    populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns);
+    OwningRewritePatternList patterns(op->getContext());
+    populateFoldReshapeOpsByLinearizationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1135,33 +1135,36 @@ struct FoldReshapeOpsByLinearizationPass
 } // namespace
 
 void mlir::populateFoldReshapeOpsByLinearizationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, false>,
                   FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
-                  FoldConsumerReshapeOpByLinearization<false>>(context);
+                  FoldConsumerReshapeOpByLinearization<false>>(
+      patterns.getContext());
 }
 
 void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, true>,
                   FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
-                  FoldConsumerReshapeOpByLinearization<true>>(context);
+                  FoldConsumerReshapeOpByLinearization<true>>(
+      patterns.getContext());
 }
 
 void mlir::populateFoldReshapeOpsByExpansionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<FoldReshapeWithGenericOpByExpansion,
                   FoldWithProducerReshapeOpByExpansion<GenericOp>,
                   FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
-      context);
+      patterns.getContext());
 }
 
 void mlir::populateLinalgTensorOpsFusionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
+    OwningRewritePatternList &patterns) {
+  auto *context = patterns.getContext();
   patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
                   FoldSplatConstants<GenericOp>,
                   FoldSplatConstants<IndexedGenericOp>>(context);
-  populateFoldReshapeOpsByExpansionPatterns(context, patterns);
+  populateFoldReshapeOpsByExpansionPatterns(patterns);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
index 69de55c..3783ef5 100644 (file)
@@ -143,9 +143,9 @@ struct LinalgGeneralizationPass
 
 void LinalgGeneralizationPass::runOnFunction() {
   FuncOp func = getFunction();
-  OwningRewritePatternList patterns;
-  linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns);
-  linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns);
+  OwningRewritePatternList patterns(&getContext());
+  linalg::populateLinalgConvGeneralizationPatterns(patterns);
+  linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
 }
 
@@ -167,15 +167,16 @@ linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
 }
 
 void mlir::linalg::populateLinalgConvGeneralizationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns,
+    OwningRewritePatternList &patterns,
     linalg::LinalgTransformationFilter marker) {
-  patterns.insert<GeneralizeConvOp>(context, marker);
+  patterns.insert<GeneralizeConvOp>(patterns.getContext(), marker);
 }
 
 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns,
+    OwningRewritePatternList &patterns,
     linalg::LinalgTransformationFilter marker) {
-  patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
+  patterns.insert<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
+                                                      marker);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
index cc0cce7..635855f 100644 (file)
@@ -378,7 +378,7 @@ void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
     // Apply canonicalization so the newForOp + yield folds immediately, thus
     // cleaning up the IR and potentially enabling more hoisting.
     if (changed) {
-      OwningRewritePatternList patterns;
+      OwningRewritePatternList patterns(func->getContext());
       scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
       (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
     }
index d6423f4..10b4cac 100644 (file)
@@ -545,7 +545,7 @@ template <typename LoopType>
 static void lowerLinalgToLoopsImpl(FuncOp funcOp,
                                    ArrayRef<unsigned> interchangeVector) {
   MLIRContext *context = funcOp.getContext();
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(context);
   patterns.insert<LinalgRewritePattern<LoopType>>(interchangeVector);
   memref::DimOp::getCanonicalizationPatterns(patterns, context);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
index d9c2580..1fc82d5 100644 (file)
@@ -137,8 +137,8 @@ public:
 /// Populates the given patterns list with conversion rules required for
 /// the sparsification of linear algebra operations.
 void linalg::populateSparsificationConversionPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<TensorFromPointerConverter, TensorToDimSizeConverter,
                   TensorToPointersConverter, TensorToIndicesConverter,
-                  TensorToValuesConverter>(context);
+                  TensorToValuesConverter>(patterns.getContext());
 }
index a940bd6..c740241 100644 (file)
@@ -1361,7 +1361,6 @@ private:
 /// Populates the given patterns list with rewriting rules required for
 /// the sparsification of linear algebra operations.
 void linalg::populateSparsificationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns,
-    const SparsificationOptions &options) {
-  patterns.insert<GenericOpSparsifier>(context, options);
+    OwningRewritePatternList &patterns, const SparsificationOptions &options) {
+  patterns.insert<GenericOpSparsifier>(patterns.getContext(), options);
 }
index d638c60..3f4c698 100644 (file)
@@ -511,15 +511,15 @@ class CanonicalizationPatternList;
 template <>
 class CanonicalizationPatternList<> {
 public:
-  static void insert(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
+  static void insert(OwningRewritePatternList &patterns) {}
 };
 
 template <typename OpTy, typename... OpTypes>
 class CanonicalizationPatternList<OpTy, OpTypes...> {
 public:
-  static void insert(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-    OpTy::getCanonicalizationPatterns(patterns, ctx);
-    CanonicalizationPatternList<OpTypes...>::insert(patterns, ctx);
+  static void insert(OwningRewritePatternList &patterns) {
+    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
+    CanonicalizationPatternList<OpTypes...>::insert(patterns);
   }
 };
 
@@ -531,32 +531,34 @@ template <>
 class RewritePatternList<> {
 public:
   static void insert(OwningRewritePatternList &patterns,
-                     const LinalgTilingOptions &options, MLIRContext *ctx) {}
+                     const LinalgTilingOptions &options) {}
 };
 
 template <typename OpTy, typename... OpTypes>
 class RewritePatternList<OpTy, OpTypes...> {
 public:
   static void insert(OwningRewritePatternList &patterns,
-                     const LinalgTilingOptions &options, MLIRContext *ctx) {
+                     const LinalgTilingOptions &options) {
+    auto *ctx = patterns.getContext();
     patterns.insert<LinalgTilingPattern<OpTy>>(
         ctx, options,
         LinalgTransformationFilter(ArrayRef<Identifier>{},
                                    Identifier::get("tiled", ctx)));
-    RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
+    RewritePatternList<OpTypes...>::insert(patterns, options);
   }
 };
 } // namespace
 
 OwningRewritePatternList
 mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
-  OwningRewritePatternList patterns;
-  populateLinalgTilingCanonicalizationPatterns(patterns, ctx);
+  OwningRewritePatternList patterns(ctx);
+  populateLinalgTilingCanonicalizationPatterns(patterns);
   return patterns;
 }
 
 void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+    OwningRewritePatternList &patterns) {
+  auto *ctx = patterns.getContext();
   AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
   AffineForOp::getCanonicalizationPatterns(patterns, ctx);
   AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
@@ -571,17 +573,16 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
   CanonicalizationPatternList<
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-      >::insert(patterns, ctx);
+      >::insert(patterns);
 }
 
 /// Populate the given list with patterns that apply Linalg tiling.
 static void insertTilingPatterns(OwningRewritePatternList &patterns,
-                                 const LinalgTilingOptions &options,
-                                 MLIRContext *ctx) {
+                                 const LinalgTilingOptions &options) {
   RewritePatternList<GenericOp, IndexedGenericOp,
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-                     >::insert(patterns, options, ctx);
+                     >::insert(patterns, options);
 }
 
 static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
@@ -590,8 +591,8 @@ static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
   auto options =
       LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType);
   MLIRContext *ctx = funcOp.getContext();
-  OwningRewritePatternList patterns;
-  insertTilingPatterns(patterns, options, ctx);
+  OwningRewritePatternList patterns(ctx);
+  insertTilingPatterns(patterns, options);
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   (void)applyPatternsAndFoldGreedily(
       funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
index dab32d2..b56072c 100644 (file)
@@ -580,8 +580,8 @@ static void
 populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
                               OwningRewritePatternList &promotionPatterns,
                               OwningRewritePatternList &vectorizationPatterns,
-                              ArrayRef<int64_t> tileSizes,
-                              MLIRContext *context) {
+                              ArrayRef<int64_t> tileSizes) {
+  auto *context = tilingPatterns.getContext();
   if (tileSizes.size() < N)
     return;
 
@@ -608,45 +608,47 @@ populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
 void mlir::linalg::populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
     ArrayRef<int64_t> tileSizes) {
-  OwningRewritePatternList tiling, promotion, vectorization;
+  OwningRewritePatternList tiling(context);
+  OwningRewritePatternList promotion(context);
+  OwningRewritePatternList vectorization(context);
   populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
-                                            tileSizes, context);
+                                            tileSizes);
 
   populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
-                                              tileSizes, context);
+                                              tileSizes);
   populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
-      tiling, promotion, vectorization, tileSizes, context);
+      tiling, promotion, vectorization, tileSizes);
 
   populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
-                                              tileSizes, context);
+                                              tileSizes);
   populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
-      tiling, promotion, vectorization, tileSizes, context);
+      tiling, promotion, vectorization, tileSizes);
 
   populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
-                                             tileSizes, context);
+                                             tileSizes);
 
   populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
-                                               tileSizes, context);
+                                               tileSizes);
   populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
-      tiling, promotion, vectorization, tileSizes, context);
+      tiling, promotion, vectorization, tileSizes);
 
   populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
-                                               tileSizes, context);
+                                               tileSizes);
   populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
-      tiling, promotion, vectorization, tileSizes, context);
+      tiling, promotion, vectorization, tileSizes);
 
   populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
-                                              tileSizes, context);
+                                              tileSizes);
 
-  populateVectorizationPatterns<ConvNDHWCOp, 5>(
-      tiling, promotion, vectorization, tileSizes, context);
+  populateVectorizationPatterns<ConvNDHWCOp, 5>(tiling, promotion,
+                                                vectorization, tileSizes);
   populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
-      tiling, promotion, vectorization, tileSizes, context);
+      tiling, promotion, vectorization, tileSizes);
 
-  populateVectorizationPatterns<ConvNCDHWOp, 5>(
-      tiling, promotion, vectorization, tileSizes, context);
+  populateVectorizationPatterns<ConvNCDHWOp, 5>(tiling, promotion,
+                                                vectorization, tileSizes);
   populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
-      tiling, promotion, vectorization, tileSizes, context);
+      tiling, promotion, vectorization, tileSizes);
 
   patterns.push_back(std::move(tiling));
   patterns.push_back(std::move(promotion));
index 06d5158..d61dc31 100644 (file)
@@ -60,7 +60,6 @@ public:
 };
 } // namespace
 
-void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns,
-                                     MLIRContext *ctx) {
-  patterns.insert<TanhOpConverter>(ctx);
+void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns) {
+  patterns.insert<TanhOpConverter>(patterns.getContext());
 }
index f13e48e..6c5d74f 100644 (file)
@@ -10,6 +10,7 @@
 // that do not rely on any of the library functions.
 //
 //===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Transforms/Bufferize.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include <limits.h>
+#include <climits>
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -530,7 +532,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
 //----------------------------------------------------------------------------//
 
 void mlir::populateMathPolynomialApproximationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<TanhApproximation, LogApproximation, Log2Approximation,
-                  ExpApproximation>(ctx);
+                  ExpApproximation>(patterns.getContext());
 }
index f67020d..44d8be9 100644 (file)
@@ -91,7 +91,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
 }
 
 void ConvertConstPass::runOnFunction() {
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   auto func = getFunction();
   auto *context = &getContext();
   patterns.insert<QuantizedConstRewrite>(context);
index daa1cda..ac28ce6 100644 (file)
@@ -124,8 +124,8 @@ public:
 
 void ConvertSimulatedQuantPass::runOnFunction() {
   bool hadFailure = false;
-  OwningRewritePatternList patterns;
   auto func = getFunction();
+  OwningRewritePatternList patterns(func.getContext());
   auto ctx = func.getContext();
   patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
       ctx, &hadFailure);
index aa25f47..15a5aba 100644 (file)
@@ -25,12 +25,12 @@ struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
     auto *context = &getContext();
 
     BufferizeTypeConverter typeConverter;
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
     ConversionTarget target(*context);
 
     populateBufferizeMaterializationLegality(target);
-    populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
-                                                    patterns, target);
+    populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
+                                                    target);
     if (failed(applyPartialConversion(func, target, std::move(patterns))))
       return signalPassFailure();
   };
index 9197375..0029c3b 100644 (file)
@@ -134,10 +134,10 @@ public:
 } // namespace
 
 void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
-    MLIRContext *context, TypeConverter &typeConverter,
-    OwningRewritePatternList &patterns, ConversionTarget &target) {
+    TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+    ConversionTarget &target) {
   patterns.insert<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
-      typeConverter, context);
+      typeConverter, patterns.getContext());
   target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
     return typeConverter.isLegal(op->getResultTypes());
   });
index 0aa4139..c5eeb8a 100644 (file)
@@ -23,13 +23,14 @@ namespace {
 namespace mlir {
 namespace spirv {
 void populateSPIRVGLSLCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
+    OwningRewritePatternList &results) {
   results.insert<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
                  ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
                  ConvertComparisonIntoClampSPV_SLessThanOp,
                  ConvertComparisonIntoClampSPV_SLessThanEqualOp,
                  ConvertComparisonIntoClampSPV_ULessThanOp,
-                 ConvertComparisonIntoClampSPV_ULessThanEqualOp>(context);
+                 ConvertComparisonIntoClampSPV_ULessThanEqualOp>(
+      results.getContext());
 }
 } // namespace spirv
 } // namespace mlir
index c4954ca..afaadb0 100644 (file)
@@ -74,10 +74,10 @@ public:
 };
 } // namespace
 
-static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns,
-                                            MLIRContext *ctx) {
+static void
+populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns) {
   patterns.insert<SPIRVGlobalVariableOpLayoutInfoDecoration,
-                  SPIRVAddressOfOpLayoutInfoDecoration>(ctx);
+                  SPIRVAddressOfOpLayoutInfoDecoration>(patterns.getContext());
 }
 
 namespace {
@@ -90,8 +90,8 @@ class DecorateSPIRVCompositeTypeLayoutPass
 
 void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
   auto module = getOperation();
-  OwningRewritePatternList patterns;
-  populateSPIRVLayoutInfoPatterns(patterns, module.getContext());
+  OwningRewritePatternList patterns(module.getContext());
+  populateSPIRVLayoutInfoPatterns(patterns);
   ConversionTarget target(*(module.getContext()));
   target.addLegalDialect<spirv::SPIRVDialect>();
   target.addLegalOp<FuncOp>();
index d96892b..71ebf8c 100644 (file)
@@ -246,7 +246,7 @@ void LowerABIAttributesPass::runOnOperation() {
     return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
   });
 
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(context);
   patterns.insert<ProcessInterfaceVarABI>(typeConverter, context);
 
   ConversionTarget target(*context);
index c544512..4aa8bd4 100644 (file)
@@ -515,9 +515,8 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
 }
 
 void mlir::populateBuiltinFuncToSPIRVPatterns(
-    MLIRContext *context, SPIRVTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns) {
-  patterns.insert<FuncOpConversion>(typeConverter, context);
+    SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+  patterns.insert<FuncOpConversion>(typeConverter, patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
index 36b5eac..779993c 100644 (file)
@@ -19,13 +19,13 @@ struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
   void runOnFunction() override {
     MLIRContext &ctx = getContext();
 
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&ctx);
     BufferizeTypeConverter typeConverter;
-    ConversionTarget target(getContext());
+    ConversionTarget target(ctx);
 
     populateBufferizeMaterializationLegality(target);
-    populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter,
-                                                      patterns, target);
+    populateShapeStructuralTypeConversionsAndLegality(typeConverter, patterns,
+                                                      target);
 
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
index 492abce..b712264 100644 (file)
@@ -46,8 +46,8 @@ class RemoveShapeConstraintsPass
   void runOnFunction() override {
     MLIRContext &ctx = getContext();
 
-    OwningRewritePatternList patterns;
-    populateRemoveShapeConstraintsPatterns(patterns, &ctx);
+    OwningRewritePatternList patterns(&ctx);
+    populateRemoveShapeConstraintsPatterns(patterns);
 
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
@@ -56,8 +56,9 @@ class RemoveShapeConstraintsPass
 } // namespace
 
 void mlir::populateRemoveShapeConstraintsPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(ctx);
+    OwningRewritePatternList &patterns) {
+  patterns.insert<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
+      patterns.getContext());
 }
 
 std::unique_ptr<FunctionPass> mlir::createRemoveShapeConstraintsPass() {
index 6190ff3..479ce71 100644 (file)
@@ -61,8 +61,8 @@ struct ShapeToShapeLowering
 void ShapeToShapeLowering::runOnFunction() {
   MLIRContext &ctx = getContext();
 
-  OwningRewritePatternList patterns;
-  populateShapeRewritePatterns(&ctx, patterns);
+  OwningRewritePatternList patterns(&ctx);
+  populateShapeRewritePatterns(patterns);
 
   ConversionTarget target(getContext());
   target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
@@ -72,9 +72,8 @@ void ShapeToShapeLowering::runOnFunction() {
     signalPassFailure();
 }
 
-void mlir::populateShapeRewritePatterns(MLIRContext *context,
-                                        OwningRewritePatternList &patterns) {
-  patterns.insert<NumElementsOpConverter>(context);
+void mlir::populateShapeRewritePatterns(OwningRewritePatternList &patterns) {
+  patterns.insert<NumElementsOpConverter>(patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
index 041b54b..6ebf9fc 100644 (file)
@@ -57,10 +57,10 @@ public:
 } // namespace
 
 void mlir::populateShapeStructuralTypeConversionsAndLegality(
-    MLIRContext *context, TypeConverter &typeConverter,
-    OwningRewritePatternList &patterns, ConversionTarget &target) {
+    TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+    ConversionTarget &target) {
   patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
-      typeConverter, context);
+      typeConverter, patterns.getContext());
   target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
     return typeConverter.isLegal(op.getResultTypes());
   });
index c2b9c93..6eeb39e 100644 (file)
@@ -54,10 +54,10 @@ public:
 };
 } // namespace
 
-void mlir::populateStdBufferizePatterns(MLIRContext *context,
-                                        BufferizeTypeConverter &typeConverter,
+void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
-  patterns.insert<BufferizeDimOp, BufferizeSelectOp>(typeConverter, context);
+  patterns.insert<BufferizeDimOp, BufferizeSelectOp>(typeConverter,
+                                                     patterns.getContext());
 }
 
 namespace {
@@ -65,14 +65,14 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
   void runOnFunction() override {
     auto *context = &getContext();
     BufferizeTypeConverter typeConverter;
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
     ConversionTarget target(*context);
 
     target.addLegalDialect<memref::MemRefDialect>();
     target.addLegalDialect<StandardOpsDialect>();
     target.addLegalDialect<scf::SCFDialect>();
 
-    populateStdBufferizePatterns(context, typeConverter, patterns);
+    populateStdBufferizePatterns(typeConverter, patterns);
     // We only bufferize the case of tensor selected type and scalar condition,
     // as that boils down to a select over memref descriptors (don't need to
     // touch the data).
index 98b261c..3f2504e 100644 (file)
@@ -211,8 +211,8 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
   void runOnFunction() override {
     MLIRContext &ctx = getContext();
 
-    OwningRewritePatternList patterns;
-    populateStdExpandOpsPatterns(&ctx, patterns);
+    OwningRewritePatternList patterns(&ctx);
+    populateStdExpandOpsPatterns(patterns);
 
     ConversionTarget target(getContext());
 
@@ -234,11 +234,10 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
 
 } // namespace
 
-void mlir::populateStdExpandOpsPatterns(MLIRContext *context,
-                                        OwningRewritePatternList &patterns) {
+void mlir::populateStdExpandOpsPatterns(OwningRewritePatternList &patterns) {
   patterns.insert<AtomicRMWOpConverter, MemRefReshapeOpConverter,
                   SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
-      context);
+      patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {
index d38a564..04424c7 100644 (file)
@@ -28,21 +28,20 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
     auto *context = &getContext();
 
     BufferizeTypeConverter typeConverter;
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
     ConversionTarget target(*context);
 
-    populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
+    populateFuncOpTypeConversionPattern(patterns, typeConverter);
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
       return typeConverter.isSignatureLegal(op.getType()) &&
              typeConverter.isLegal(&op.getBody());
     });
-    populateCallOpTypeConversionPattern(patterns, context, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
     target.addDynamicallyLegalOp<CallOp>(
         [&](CallOp op) { return typeConverter.isLegal(op); });
 
-    populateBranchOpInterfaceTypeConversionPattern(patterns, context,
-                                                   typeConverter);
-    populateReturnOpTypeConversionPattern(patterns, context, typeConverter);
+    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
     target.addLegalOp<ModuleOp, ModuleTerminatorOp, memref::TensorLoadOp,
                       memref::BufferCastOp>();
 
index 4ba2069..4008676 100644 (file)
@@ -38,9 +38,8 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
 } // end anonymous namespace
 
 void mlir::populateCallOpTypeConversionPattern(
-    OwningRewritePatternList &patterns, MLIRContext *ctx,
-    TypeConverter &converter) {
-  patterns.insert<CallOpSignatureConversion>(converter, ctx);
+    OwningRewritePatternList &patterns, TypeConverter &converter) {
+  patterns.insert<CallOpSignatureConversion>(converter, patterns.getContext());
 }
 
 namespace {
@@ -103,9 +102,9 @@ public:
 } // end anonymous namespace
 
 void mlir::populateBranchOpInterfaceTypeConversionPattern(
-    OwningRewritePatternList &patterns, MLIRContext *ctx,
-    TypeConverter &typeConverter) {
-  patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter, ctx);
+    OwningRewritePatternList &patterns, TypeConverter &typeConverter) {
+  patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter,
+                                                   patterns.getContext());
 }
 
 bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
@@ -125,9 +124,8 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
 }
 
 void mlir::populateReturnOpTypeConversionPattern(
-    OwningRewritePatternList &patterns, MLIRContext *ctx,
-    TypeConverter &typeConverter) {
-  patterns.insert<ReturnOpTypeConversion>(typeConverter, ctx);
+    OwningRewritePatternList &patterns, TypeConverter &typeConverter) {
+  patterns.insert<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
 }
 
 bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
index 55d3405..625bdc1 100644 (file)
@@ -90,7 +90,7 @@ struct TensorConstantBufferizePass
 
     auto *context = &getContext();
     BufferizeTypeConverter typeConverter;
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
     ConversionTarget target(*context);
 
     target.addLegalDialect<memref::MemRefDialect>();
index 1ef742e..4c1d0b7 100644 (file)
@@ -138,10 +138,9 @@ public:
 } // namespace
 
 void mlir::populateTensorBufferizePatterns(
-    MLIRContext *context, BufferizeTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns) {
+    BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
   patterns.insert<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
-                  BufferizeGenerateOp>(typeConverter, context);
+                  BufferizeGenerateOp>(typeConverter, patterns.getContext());
 }
 
 namespace {
@@ -149,12 +148,12 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
   void runOnFunction() override {
     auto *context = &getContext();
     BufferizeTypeConverter typeConverter;
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
     ConversionTarget target(*context);
 
     populateBufferizeMaterializationLegality(target);
 
-    populateTensorBufferizePatterns(context, typeConverter, patterns);
+    populateTensorBufferizePatterns(typeConverter, patterns);
     target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                         tensor::FromElementsOp, tensor::GenerateOp>();
     target.addLegalDialect<memref::MemRefDialect>();
index 540a790..2ab1a64 100644 (file)
@@ -251,7 +251,7 @@ struct TosaMakeBroadcastable
 public:
   void runOnFunction() override {
     auto func = getFunction();
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(func.getContext());
     MLIRContext *ctx = func.getContext();
     // Add the generated patterns to the list.
     patterns.insert<ConvertTosaOp<tosa::AddOp>>(ctx);
index 08bf762..23b194d 100644 (file)
@@ -3534,11 +3534,11 @@ void CreateMaskOp::getCanonicalizationPatterns(
 }
 
 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder,
                   GatherFolder, ScatterFolder, ExpandLoadFolder,
                   CompressStoreFolder, StridedSliceConstantMaskFolder,
-                  TransposeFolder>(context);
+                  TransposeFolder>(patterns.getContext());
 }
 
 #define GET_OP_CLASSES
index 57602a5..16664b1 100644 (file)
@@ -2784,7 +2784,7 @@ struct TransferReadToVectorLoadLowering
       // If broadcasting is required and the number of loaded elements is 1 then
       // we can create `memref.load` instead of `vector.load`.
       loadOp = rewriter.create<memref::LoadOp>(read.getLoc(), read.source(),
-                                             read.indices());
+                                               read.indices());
     } else {
       // Otherwise create `vector.load`.
       loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
@@ -3263,43 +3263,43 @@ struct BubbleUpBitCastForStridedSliceInsert
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
                   TransferReadExtractPattern, TransferWriteInsertPattern>(
-      context);
+      patterns.getContext());
 }
 
 void mlir::vector::populateSplitVectorTransferPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context,
+    OwningRewritePatternList &patterns,
     std::function<bool(Operation *)> ignoreFilter) {
-  patterns.insert<SplitTransferReadOp, SplitTransferWriteOp>(context,
-                                                             ignoreFilter);
+  patterns.insert<SplitTransferReadOp, SplitTransferWriteOp>(
+      patterns.getContext(), ignoreFilter);
 }
 
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<CastAwayExtractStridedSliceLeadingOneDim,
                   CastAwayInsertStridedSliceLeadingOneDim,
                   CastAwayTransferReadLeadingOneDim,
                   CastAwayTransferWriteLeadingOneDim, ShapeCastOpFolder>(
-      context);
+      patterns.getContext());
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<BubbleDownVectorBitCastForExtract,
                   BubbleDownBitCastForStridedSliceExtract,
-                  BubbleUpBitCastForStridedSliceInsert>(context);
+                  BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
 }
 
 void mlir::vector::populateVectorSlicesLoweringPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
+    OwningRewritePatternList &patterns) {
+  patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(
+      patterns.getContext());
 }
 
 void mlir::vector::populateVectorContractLoweringPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context,
-    VectorTransformsOptions parameters) {
+    OwningRewritePatternList &patterns, VectorTransformsOptions parameters) {
   // clang-format off
   patterns.insert<BroadcastOpLowering,
                   CreateMaskOpLowering,
@@ -3307,16 +3307,16 @@ void mlir::vector::populateVectorContractLoweringPatterns(
                   OuterProductOpLowering,
                   ShapeCastOp2DDownCastRewritePattern,
                   ShapeCastOp2DUpCastRewritePattern,
-                  ShapeCastOpRewritePattern>(context);
+                  ShapeCastOpRewritePattern>(patterns.getContext());
   patterns.insert<TransposeOpLowering,
                   ContractionOpLowering,
                   ContractionOpToMatmulOpLowering,
-                  ContractionOpToOuterProductOpLowering>(parameters, context);
+                  ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
   // clang-format on
 }
 
 void mlir::vector::populateVectorTransferLoweringPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
+    OwningRewritePatternList &patterns) {
   patterns.insert<TransferReadToVectorLoadLowering,
-                  TransferWriteToVectorStoreLowering>(context);
+                  TransferWriteToVectorStoreLowering>(patterns.getContext());
 }
index 74de861..ba1f566 100644 (file)
@@ -84,10 +84,9 @@ public:
 } // namespace
 
 void mlir::populateEliminateBufferizeMaterializationsPatterns(
-    MLIRContext *context, BufferizeTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns) {
-  patterns.insert<BufferizeTensorLoadOp, BufferizeCastOp>(typeConverter,
-                                                          context);
+    BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+  patterns.insert<BufferizeTensorLoadOp, BufferizeCastOp>(
+      typeConverter, patterns.getContext());
 }
 
 namespace {
@@ -101,11 +100,10 @@ struct FinalizingBufferizePass
     auto *context = &getContext();
 
     BufferizeTypeConverter typeConverter;
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
     ConversionTarget target(*context);
 
-    populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
-                                                       patterns);
+    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
 
     // If all result types are legal, and all block arguments are legal (ensured
     // by func conversion above), then all types in the program are legal.
index cd99681..900d89c 100644 (file)
@@ -25,7 +25,7 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
-    OwningRewritePatternList owningPatterns;
+    OwningRewritePatternList owningPatterns(context);
     for (auto *op : context->getRegisteredOperations())
       op->getCanonicalizationPatterns(owningPatterns, context);
     patterns = std::move(owningPatterns);
index 5c99c58..113ba46 100644 (file)
@@ -75,7 +75,8 @@ computeConversionSet(iterator_range<Region::iterator> region,
 
 /// A utility function to log a successful result for the given reason.
 template <typename... Args>
-static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
+static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
+                       Args &&... args) {
   LLVM_DEBUG({
     os.unindent();
     os.startLine() << "} -> SUCCESS";
@@ -88,7 +89,8 @@ static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
 
 /// A utility function to log a failure result for the given reason.
 template <typename... Args>
-static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
+static void logFailure(llvm::ScopedPrinter &os, StringRef fmt,
+                       Args &&... args) {
   LLVM_DEBUG({
     os.unindent();
     os.startLine() << "} -> FAILURE : "
@@ -2611,15 +2613,14 @@ struct FunctionLikeSignatureConversion : public ConversionPattern {
 
 void mlir::populateFunctionLikeTypeConversionPattern(
     StringRef functionLikeOpName, OwningRewritePatternList &patterns,
-    MLIRContext *ctx, TypeConverter &converter) {
-  patterns.insert<FunctionLikeSignatureConversion>(functionLikeOpName, ctx,
-                                                   converter);
+    TypeConverter &converter) {
+  patterns.insert<FunctionLikeSignatureConversion>(
+      functionLikeOpName, patterns.getContext(), converter);
 }
 
 void mlir::populateFuncOpTypeConversionPattern(
-    OwningRewritePatternList &patterns, MLIRContext *ctx,
-    TypeConverter &converter) {
-  populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, ctx, converter);
+    OwningRewritePatternList &patterns, TypeConverter &converter) {
+  populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, converter);
 }
 
 //===----------------------------------------------------------------------===//
index a9b5979..cd58ec9 100644 (file)
@@ -403,7 +403,7 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
 
       if (res) {
         // Simplify/canonicalize the affine.for.
-        OwningRewritePatternList patterns;
+        OwningRewritePatternList patterns(res.getContext());
         AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
         bool erased;
         (void)applyOpPatternsAndFold(res, std::move(patterns), &erased);
index 4808557..b8aa7da 100644 (file)
@@ -110,7 +110,7 @@ void TestAffineDataCopy::runOnFunction() {
   // Promoting single iteration loops could lead to simplification of
   // generated load's/store's, and the latter could anyway also be
   // canonicalized.
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   for (auto op : copyOps) {
     patterns.clear();
     if (isa<AffineLoadOp>(op)) {
index 99a6022..f66ac8c 100644 (file)
@@ -139,7 +139,7 @@ void ConvertToTargetEnv::runOnFunction() {
 
   auto target = spirv::SPIRVConversionTarget::get(targetEnv);
 
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(context);
   patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
                   ConvertToGroupNonUniformBallot, ConvertToModule,
                   ConvertToSubgroupBallot>(context);
index d80f912..75bc52a 100644 (file)
@@ -25,8 +25,8 @@ public:
 } // namespace
 
 void TestGLSLCanonicalizationPass::runOnOperation() {
-  OwningRewritePatternList patterns;
-  spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns, &getContext());
+  OwningRewritePatternList patterns(&getContext());
+  spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
index 53651de..8c09406 100644 (file)
@@ -79,7 +79,7 @@ public:
 
 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
   void runOnFunction() override {
-    mlir::OwningRewritePatternList patterns;
+    mlir::OwningRewritePatternList patterns(&getContext());
     populateWithGenerated(&getContext(), patterns);
 
     // Verify named pattern is generated with expected name.
@@ -557,7 +557,7 @@ struct TestLegalizePatternDriver
 
   void runOnOperation() override {
     TestTypeConverter converter;
-    mlir::OwningRewritePatternList patterns;
+    mlir::OwningRewritePatternList patterns(&getContext());
     populateWithGenerated(&getContext(), patterns);
     patterns.insert<
         TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
@@ -568,10 +568,8 @@ struct TestLegalizePatternDriver
         TestNonRootReplacement, TestBoundedRecursiveRewrite,
         TestNestedOpCreationUndoRewrite>(&getContext());
     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
-    mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
-                                              converter);
-    mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
-                                              converter);
+    mlir::populateFuncOpTypeConversionPattern(patterns, converter);
+    mlir::populateCallOpTypeConversionPattern(patterns, converter);
 
     // Define the conversion target used for the test.
     ConversionTarget target(getContext());
@@ -700,7 +698,7 @@ struct OneVResOneVOperandOp1Converter
 struct TestRemappedValue
     : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
   void runOnFunction() override {
-    mlir::OwningRewritePatternList patterns;
+    mlir::OwningRewritePatternList patterns(&getContext());
     patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
 
     mlir::ConversionTarget target(getContext());
@@ -742,7 +740,7 @@ struct RemoveTestDialectOps : public RewritePattern {
 struct TestUnknownRootOpDriver
     : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
   void runOnFunction() override {
-    mlir::OwningRewritePatternList patterns;
+    mlir::OwningRewritePatternList patterns(&getContext());
     patterns.insert<RemoveTestDialectOps>();
 
     mlir::ConversionTarget target(getContext());
@@ -878,12 +876,11 @@ struct TestTypeConversionDriver
     });
 
     // Initialize the set of rewrite patterns.
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
     patterns.insert<TestTypeConsumerForward, TestTypeConversionProducer,
                     TestSignatureConversionUndo>(converter, &getContext());
     patterns.insert<TestTypeConversionAnotherProducer>(&getContext());
-    mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
-                                              converter);
+    mlir::populateFuncOpTypeConversionPattern(patterns, converter);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -966,8 +963,8 @@ struct TestMergeBlocksPatternDriver
     : public PassWrapper<TestMergeBlocksPatternDriver,
                          OperationPass<ModuleOp>> {
   void runOnOperation() override {
-    mlir::OwningRewritePatternList patterns;
     MLIRContext *context = &getContext();
+    mlir::OwningRewritePatternList patterns(context);
     patterns
         .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
             context);
@@ -1035,8 +1032,8 @@ struct TestSelectiveReplacementPatternDriver
     : public PassWrapper<TestSelectiveReplacementPatternDriver,
                          OperationPass<>> {
   void runOnOperation() override {
-    mlir::OwningRewritePatternList patterns;
     MLIRContext *context = &getContext();
+    mlir::OwningRewritePatternList patterns(context);
     patterns.insert<TestSelectiveOpReplacementPattern>(context);
     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                        std::move(patterns));
index 87bd782..e1f151f 100644 (file)
@@ -34,7 +34,7 @@ namespace {
 struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
   void runOnFunction() override {
     (void)applyPatternsAndFoldGreedily(getFunction(),
-                                       OwningRewritePatternList());
+                                       OwningRewritePatternList(&getContext()));
   }
 };
 } // end anonymous namespace
index 416bbca..06777ea 100644 (file)
@@ -183,8 +183,8 @@ struct TosaTestQuantUtilAPI
 };
 
 void TosaTestQuantUtilAPI::runOnFunction() {
-  OwningRewritePatternList patterns;
   auto *ctx = &getContext();
+  OwningRewritePatternList patterns(ctx);
   auto func = getFunction();
 
   patterns.insert<ConvertTosaNegateOp>(ctx);
index cda3542..cd741d0 100644 (file)
@@ -91,7 +91,7 @@ void TestConvVectorization::runOnOperation() {
   VectorTransformsOptions vectorTransformsOptions{
       VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
 
-  OwningRewritePatternList vectorTransferPatterns;
+  OwningRewritePatternList vectorTransferPatterns(context);
   // Pattern is not applied because rank-reducing vector transfer is not yet
   // supported as can be seen in splitFullAndPartialTransferPrecondition,
   // VectorTransforms.cpp
@@ -106,15 +106,15 @@ void TestConvVectorization::runOnOperation() {
     llvm_unreachable("Unexpected failure in linalg to loops pass.");
 
   // Programmatic controlled lowering of vector.contract only.
-  OwningRewritePatternList vectorContractLoweringPatterns;
+  OwningRewritePatternList vectorContractLoweringPatterns(context);
   populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
-                                         context, vectorTransformsOptions);
+                                         vectorTransformsOptions);
   (void)applyPatternsAndFoldGreedily(module,
                                      std::move(vectorContractLoweringPatterns));
 
   // Programmatic controlled lowering of vector.transfer only.
-  OwningRewritePatternList vectorToLoopsPatterns;
-  populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
+  OwningRewritePatternList vectorToLoopsPatterns(context);
+  populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
                                         VectorTransferToSCFOptions());
   (void)applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
 
index 2fe29b4..dbe1a31 100644 (file)
@@ -43,15 +43,15 @@ public:
     ModuleOp m = getOperation();
 
     // Populate type conversions.
-    LLVMTypeConverter type_converter(m.getContext());
-    type_converter.addConversion([&](test::TestType type) {
+    LLVMTypeConverter typeConverter(m.getContext());
+    typeConverter.addConversion([&](test::TestType type) {
       return LLVM::LLVMPointerType::get(IntegerType::get(m.getContext(), 8));
     });
 
     // Populate patterns.
-    OwningRewritePatternList patterns;
-    populateStdToLLVMConversionPatterns(type_converter, patterns);
-    patterns.insert<TestTypeProducerOpConverter>(type_converter);
+    OwningRewritePatternList patterns(m.getContext());
+    populateStdToLLVMConversionPatterns(typeConverter, patterns);
+    patterns.insert<TestTypeProducerOpConverter>(typeConverter);
 
     // Set target.
     ConversionTarget target(getContext());
index 2dd2c34..13c01a1 100644 (file)
@@ -33,7 +33,7 @@ struct TestDecomposeCallGraphTypes
     TypeConverter typeConverter;
     ConversionTarget target(*context);
     ValueDecomposer decomposer;
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(context);
 
     target.addLegalDialect<test::TestDialect>();
 
index e67e89b..dc54a4b 100644 (file)
@@ -24,8 +24,8 @@ struct TestExpandTanhPass
 } // end anonymous namespace
 
 void TestExpandTanhPass::runOnFunction() {
-  OwningRewritePatternList patterns;
-  populateExpandTanhPattern(patterns, &getContext());
+  OwningRewritePatternList patterns(&getContext());
+  populateExpandTanhPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
index 44ffd38..5f87a9f 100644 (file)
@@ -25,8 +25,8 @@ struct TestGpuRewritePass
     registry.insert<StandardOpsDialect, memref::MemRefDialect>();
   }
   void runOnOperation() override {
-    OwningRewritePatternList patterns;
-    populateGpuRewritePatterns(&getContext(), patterns);
+    OwningRewritePatternList patterns(&getContext());
+    populateGpuRewritePatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
index 1efc565..8cb7702 100644 (file)
@@ -109,7 +109,7 @@ struct TestLinalgFusionTransforms
   void runOnFunction() override {
     MLIRContext *context = &this->getContext();
     FuncOp funcOp = this->getFunction();
-    OwningRewritePatternList fusionPatterns;
+    OwningRewritePatternList fusionPatterns(context);
     Aliases alias;
     LinalgDependenceGraph dependenceGraph =
         LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
index 6cc390f..8e1cd2d 100644 (file)
@@ -92,7 +92,7 @@ struct TestLinalgTransforms
 
 static void applyPatterns(FuncOp funcOp) {
   MLIRContext *ctx = funcOp.getContext();
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(ctx);
 
   //===--------------------------------------------------------------------===//
   // Linalg tiling patterns.
@@ -237,21 +237,26 @@ static void fillL1TilingAndMatmulToVectorPatterns(
     FuncOp funcOp, StringRef startMarker,
     SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
   MLIRContext *ctx = funcOp.getContext();
-  patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
-      ctx,
-      LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
-      LinalgTransformationFilter(Identifier::get(startMarker, ctx),
-                                 Identifier::get("L1", ctx))));
+  patternsVector.emplace_back(
+      ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
+               ctx,
+               LinalgTilingOptions()
+                   .setTileSizes({8, 12, 16})
+                   .setInterchange({1, 0, 2}),
+               LinalgTransformationFilter(Identifier::get(startMarker, ctx),
+                                          Identifier::get("L1", ctx))));
 
   patternsVector.emplace_back(
+      ctx,
       std::make_unique<LinalgPromotionPattern<MatmulOp>>(
           ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
           LinalgTransformationFilter(Identifier::get("L1", ctx),
                                      Identifier::get("VEC", ctx))));
 
-  patternsVector.emplace_back(std::make_unique<LinalgVectorizationPattern>(
-      MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
-      LinalgTransformationFilter(Identifier::get("VEC", ctx))));
+  patternsVector.emplace_back(
+      ctx, std::make_unique<LinalgVectorizationPattern>(
+               MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
+               LinalgTransformationFilter(Identifier::get("VEC", ctx))));
   patternsVector.back().insert<LinalgVectorizationPattern>(
       LinalgTransformationFilter().addFilter(
           [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
@@ -462,13 +467,14 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
                                           stage1Patterns);
   } else if (testMatmulToVectorPatterns2dTiling) {
-    stage1Patterns.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
-        ctx,
-        LinalgTilingOptions()
-            .setTileSizes({768, 264, 768})
-            .setInterchange({1, 2, 0}),
-        LinalgTransformationFilter(Identifier::get("START", ctx),
-                                   Identifier::get("L2", ctx))));
+    stage1Patterns.emplace_back(
+        ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
+                 ctx,
+                 LinalgTilingOptions()
+                     .setTileSizes({768, 264, 768})
+                     .setInterchange({1, 2, 0}),
+                 LinalgTransformationFilter(Identifier::get("START", ctx),
+                                            Identifier::get("L2", ctx))));
     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                           stage1Patterns);
   }
@@ -481,14 +487,14 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
 }
 
 static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
-  OwningRewritePatternList forwardPattern;
+  OwningRewritePatternList forwardPattern(funcOp.getContext());
   forwardPattern.insert<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
   forwardPattern.insert<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
 }
 
 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(funcOp.getContext());
   patterns.insert<LinalgVectorizationPattern>(
       LinalgTransformationFilter()
           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
@@ -497,7 +503,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
 }
 
 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
-  OwningRewritePatternList foldPattern;
+  OwningRewritePatternList foldPattern(funcOp.getContext());
   foldPattern.insert<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
   FrozenRewritePatternList frozenPatterns(std::move(foldPattern));
 
@@ -517,7 +523,7 @@ static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
 
 static void applyTileAndPadPattern(FuncOp funcOp) {
   MLIRContext *context = funcOp.getContext();
-  OwningRewritePatternList tilingPattern;
+  OwningRewritePatternList tilingPattern(context);
   auto linalgTilingOptions =
       linalg::LinalgTilingOptions()
           .setTileSizes({2, 3, 4})
@@ -539,13 +545,13 @@ void TestLinalgTransforms::runOnFunction() {
   std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
 
   if (testPromotionOptions) {
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
     fillPromotionCallBackPatterns(&getContext(), patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
     return;
   }
   if (testTileAndDistributionOptions) {
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
     fillTileAndDistributePatterns(&getContext(), patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
     return;
index b4b8ac5..c702301 100644 (file)
@@ -32,8 +32,8 @@ struct TestMathPolynomialApproximationPass
 } // end anonymous namespace
 
 void TestMathPolynomialApproximationPass::runOnFunction() {
-  OwningRewritePatternList patterns;
-  populateMathPolynomialApproximationPatterns(patterns, &getContext());
+  OwningRewritePatternList patterns(&getContext());
+  populateMathPolynomialApproximationPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
index a76b866..8c58f6e 100644 (file)
@@ -101,25 +101,25 @@ struct TestSparsification
   /// Runs the test on a function.
   void runOnOperation() override {
     auto *ctx = &getContext();
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(ctx);
     // Translate strategy flags to strategy options.
     linalg::SparsificationOptions options(parallelOption(), vectorOption(),
                                           vectorLength, typeOption(ptrType),
                                           typeOption(indType), fastOutput);
     // Apply rewriting.
-    linalg::populateSparsificationPatterns(ctx, patterns, options);
-    vector::populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
+    linalg::populateSparsificationPatterns(patterns, options);
+    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     // Lower sparse primitives to calls into runtime support library.
     if (lower) {
-      OwningRewritePatternList conversionPatterns;
+      OwningRewritePatternList conversionPatterns(ctx);
       ConversionTarget target(*ctx);
       target.addIllegalOp<linalg::SparseTensorFromPointerOp,
                           linalg::SparseTensorToPointersMemRefOp,
                           linalg::SparseTensorToIndicesMemRefOp,
                           linalg::SparseTensorToValuesMemRefOp>();
       target.addLegalOp<CallOp>();
-      linalg::populateSparsificationConversionPatterns(ctx, conversionPatterns);
+      linalg::populateSparsificationConversionPatterns(conversionPatterns);
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(conversionPatterns))))
         signalPassFailure();
index f11ee13..ac0b099 100644 (file)
@@ -36,19 +36,19 @@ struct TestVectorToVectorConversion
                       llvm::cl::init(false)};
 
   void runOnFunction() override {
-    OwningRewritePatternList patterns;
     auto *ctx = &getContext();
+    OwningRewritePatternList patterns(ctx);
     if (unroll) {
       patterns.insert<UnrollVectorPattern>(
           ctx,
           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
               filter));
     }
-    populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
-    populateVectorToVectorTransformationPatterns(patterns, ctx);
-    populateBubbleVectorBitCastOpPatterns(patterns, ctx);
-    populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
-    populateSplitVectorTransferPatterns(patterns, ctx);
+    populateVectorToVectorCanonicalizationPatterns(patterns);
+    populateVectorToVectorTransformationPatterns(patterns);
+    populateBubbleVectorBitCastOpPatterns(patterns);
+    populateCastAwayVectorLeadingOneDimPatterns(patterns);
+    populateSplitVectorTransferPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 
@@ -70,8 +70,8 @@ private:
 struct TestVectorSlicesConversion
     : public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
   void runOnFunction() override {
-    OwningRewritePatternList patterns;
-    populateVectorSlicesLoweringPatterns(patterns, &getContext());
+    OwningRewritePatternList patterns(&getContext());
+    populateVectorSlicesLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
@@ -101,7 +101,7 @@ struct TestVectorContractionConversion
       llvm::cl::init(false)};
 
   void runOnFunction() override {
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
 
     // Test on one pattern in isolation.
     if (lowerToOuterProduct) {
@@ -138,7 +138,7 @@ struct TestVectorContractionConversion
     if (lowerToFlatTranspose)
       transposeLowering = VectorTransposeLowering::Flat;
     VectorTransformsOptions options{contractLowering, transposeLowering};
-    populateVectorContractLoweringPatterns(patterns, &getContext(), options);
+    populateVectorContractLoweringPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
@@ -149,7 +149,7 @@ struct TestVectorUnrollingPatterns
   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(ctx);
     patterns.insert<UnrollVectorPattern>(
         ctx, UnrollVectorOptions()
                  .setNativeShape(ArrayRef<int64_t>{2, 2})
@@ -185,8 +185,8 @@ struct TestVectorUnrollingPatterns
                      return success(isa<ContractionOp>(op));
                    }));
     }
-    populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
-    populateVectorToVectorTransformationPatterns(patterns, ctx);
+    populateVectorToVectorCanonicalizationPatterns(patterns);
+    populateVectorToVectorTransformationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 
@@ -210,7 +210,7 @@ struct TestVectorDistributePatterns
 
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(ctx);
     FuncOp func = getFunction();
     func.walk([&](AddFOp op) {
       OpBuilder builder(op);
@@ -241,7 +241,7 @@ struct TestVectorDistributePatterns
       }
     });
     patterns.insert<PointwiseExtractPattern>(ctx);
-    populateVectorToVectorTransformationPatterns(patterns, ctx);
+    populateVectorToVectorTransformationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
@@ -260,7 +260,7 @@ struct TestVectorToLoopPatterns
       llvm::cl::init(32)};
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(ctx);
     FuncOp func = getFunction();
     func.walk([&](AddFOp op) {
       // Check that the operation type can be broken down into a loop.
@@ -301,7 +301,7 @@ struct TestVectorToLoopPatterns
       return mlir::WalkResult::interrupt();
     });
     patterns.insert<PointwiseExtractPattern>(ctx);
-    populateVectorToVectorTransformationPatterns(patterns, ctx);
+    populateVectorToVectorTransformationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
@@ -313,7 +313,7 @@ struct TestVectorTransferUnrollingPatterns
   }
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(ctx);
     patterns.insert<UnrollVectorPattern>(
         ctx,
         UnrollVectorOptions()
@@ -322,8 +322,8 @@ struct TestVectorTransferUnrollingPatterns
               return success(
                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
             }));
-    populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
-    populateVectorToVectorTransformationPatterns(patterns, ctx);
+    populateVectorToVectorCanonicalizationPatterns(patterns);
+    populateVectorToVectorTransformationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
@@ -347,7 +347,7 @@ struct TestVectorTransferFullPartialSplitPatterns
       llvm::cl::init(false)};
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(ctx);
     VectorTransformsOptions options;
     if (useLinalgOps)
       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
@@ -369,8 +369,8 @@ struct TestVectorTransferLoweringPatterns
     registry.insert<memref::MemRefDialect>();
   }
   void runOnFunction() override {
-    OwningRewritePatternList patterns;
-    populateVectorTransferLoweringPatterns(patterns, &getContext());
+    OwningRewritePatternList patterns(&getContext());
+    populateVectorTransferLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
index 721ec5e..ee36c6a 100644 (file)
@@ -52,7 +52,7 @@ TEST(PatternBenefitTest, BenefitOrder) {
     bool *called;
   };
 
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&context);
 
   bool called1 = false;
   bool called2 = false;