class AffineParallelOp;
class Location;
struct LogicalResult;
-class MLIRContext;
class OpBuilder;
class Pass;
class RewritePattern;
/// Collect a set of patterns to convert from the Affine dialect to the Standard
/// dialect, in particular convert structured affine control flow into CFG
/// branch-based control flow.
-void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns);
/// Collect a set of patterns to convert vector-related Affine ops to the Vector
/// dialect.
void populateAffineToVectorConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
/// Emit code that computes the lower bound of the given affine loop using
/// standard arithmetic operations.
/// the TypeConverter, but otherwise don't care what type conversions are
/// happening.
void populateAsyncStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target);
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target);
} // namespace mlir
/// Appends to a pattern list additional patterns for translating GPU Ops to
/// SPIR-V ops. For a gpu.func to be converted, it should have a
/// spv.entry_point_abi attribute.
-void populateGPUToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
} // namespace mlir
/// Appends to a pattern list additional patterns for translating Linalg ops to
/// SPIR-V ops.
-void populateLinalgToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
} // namespace mlir
/// Populate the given list with patterns that convert from Linalg to Standard.
void populateLinalgToStandardConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
} // namespace linalg
/// Adds the conversion pattern from `scf.parallel` to `gpu.launch` to the
/// provided pattern list.
-void populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+void populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns);
/// Configures the rewrite target such that only `scf.parallel` operations that
/// are not rewritten by the provided patterns are legal.
#include <memory>
namespace mlir {
-class MLIRContext;
class Pass;
// Owning list of rewriting patterns.
/// Collects a set of patterns to lower from scf.for, scf.if, and
/// loop.terminator to CFG operations within the SPIR-V dialect.
-void populateSCFToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
OwningRewritePatternList &patterns);
} // namespace mlir
namespace mlir {
struct LogicalResult;
-class MLIRContext;
class Pass;
class RewritePattern;
/// Collect a set of patterns to lower from scf.for, scf.if, and
/// loop.terminator to CFG operations within the Standard dialect, in particular
/// convert structured control flow into CFG branch-based control flow.
-void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns);
/// Creates a pass to convert scf.for, scf.if and loop.terminator ops to CFG.
std::unique_ptr<Pass> createLowerToCFGPass();
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
/// Populates the given list with patterns that convert from SPIR-V to LLVM.
-void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
- LLVMTypeConverter &typeConverter,
+void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
/// Populates the given list with patterns for function conversion from SPIR-V
/// to LLVM.
void populateSPIRVToLLVMFunctionConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns);
/// Populates the given patterns for module conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMModuleConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns);
} // namespace mlir
namespace mlir {
class FuncOp;
-class MLIRContext;
class ModuleOp;
template <typename T>
class OperationPass;
class OwningRewritePatternList;
void populateShapeToStandardConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();
void populateConvertShapeConstraintsConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
std::unique_ptr<OperationPass<FuncOp>> createConvertShapeConstraintsPass();
/// Appends to a pattern list additional patterns for translating standard ops
/// to SPIR-V ops. Also adds the patterns to legalize ops not directly
/// translated to SPIR-V dialect.
-void populateStandardToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
/// Appends to a pattern list additional patterns for translating tensor ops
/// variables. SPIR-V consumers in GPU drivers may or may not optimize that
/// away. So this has implications over register pressure. Therefore, a
/// threshold is used to control when the patterns should kick in.
-void populateTensorToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
int64_t byteCountThreshold,
OwningRewritePatternList &patterns);
/// Appends to a pattern list patterns to legalize ops that are not directly
/// lowered to SPIR-V.
void populateStdLegalizationPatternsForSPIRVLowering(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
} // namespace mlir
/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgOnTensorsConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns);
+ OwningRewritePatternList *patterns);
} // namespace tosa
} // namespace mlir
std::unique_ptr<Pass> createTosaToSCF();
-void populateTosaToSCFConversionPatterns(MLIRContext *context,
- OwningRewritePatternList *patterns);
+void populateTosaToSCFConversionPatterns(OwningRewritePatternList *patterns);
/// Populates passes to convert from TOSA to SCF.
void addTosaToSCFPasses(OpPassManager &pm);
std::unique_ptr<Pass> createTosaToStandard();
void populateTosaToStandardConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns);
+ OwningRewritePatternList *patterns);
void populateTosaRescaleToStandardConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns);
+ OwningRewritePatternList *patterns);
/// Populates passes to convert from TOSA to Standard.
void addTosaToStandardPasses(OpPassManager &pm);
/// Collect a set of patterns to convert from the Vector dialect to SCF + std.
void populateVectorToSCFConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
+ OwningRewritePatternList &patterns,
const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
/// Create a pass to convert a subset of vector ops to SCF.
/// Appends to a pattern list additional patterns for translating Vector Ops to
/// SPIR-V ops.
-void populateVectorToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
} // namespace mlir
std::unique_ptr<OperationPass<FuncOp>> createGpuAsyncRegionPass();
/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
-void populateGpuAllReducePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+void populateGpuAllReducePatterns(OwningRewritePatternList &patterns);
/// Collect all patterns to rewrite ops within the GPU dialect.
// NOTE(review): diff hunk — the explicit MLIRContext* parameter is dropped;
// the context now travels with OwningRewritePatternList itself (see its new
// getContext() accessor introduced in this same change), so the all-reduce
// populate function is called with the list alone.
-inline void populateGpuRewritePatterns(MLIRContext *context,
-                                       OwningRewritePatternList &patterns) {
-  populateGpuAllReducePatterns(context, patterns);
+inline void populateGpuRewritePatterns(OwningRewritePatternList &patterns) {
+  populateGpuAllReducePatterns(patterns);
}
namespace gpu {
/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
/// Create a pass to conver named Linalg operations to Linalg generic
/// operations.
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
/// indexing map used to access the source (target) of the reshape operation in
/// the generic/indexed_generic operation.
void populateFoldReshapeOpsByLinearizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
/// the tensor reshape involved is collapsing (introducing) unit-extent
/// dimensions.
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
/// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+void populateLinalgTensorOpsFusionPatterns(OwningRewritePatternList &patterns);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors.
void populateLinalgFoldUnitExtentDimsPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
//===----------------------------------------------------------------------===//
// Registration
typename = std::enable_if_t<std::is_member_function_pointer<
decltype(&ConcreteOpType::getOperationName)>::value>>
void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
- MLIRContext *context, StringRef opName,
- linalg::LinalgTransformationFilter m) {
+ StringRef opName, linalg::LinalgTransformationFilter m) {
assert(opName == ConcreteOpType::getOperationName() &&
"explicit name must match ConcreteOpType::getOperationName");
- patternList.insert<PatternType<ConcreteOpType>>(context, options, m);
+ patternList.insert<PatternType<ConcreteOpType>>(patternList.getContext(),
+ options, m);
}
/// SFINAE: Enqueue helper for an OpType that does not have a
/// `getOperationName` method; the pattern is registered under the explicit
/// `opName` string instead.
// NOTE(review): diff hunk — the MLIRContext* parameter is removed; the
// context is now obtained from the list via patternList.getContext().
template <template <typename> class PatternType, typename OpType,
          typename OptionsType>
void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
-                    MLIRContext *context, StringRef opName,
-                    linalg::LinalgTransformationFilter m) {
+                    StringRef opName, linalg::LinalgTransformationFilter m) {
  // opName is mandatory here because OpType cannot supply its own name.
  assert(!opName.empty() && "opName must not be empty");
-  patternList.insert<PatternType<OpType>>(opName, context, options, m);
+  patternList.insert<PatternType<OpType>>(opName, patternList.getContext(),
+                                          options, m);
}
// Enqueue `PatternType` into `patternList`: registered under `opName` when a
// name is provided, otherwise constrained to `OpType` through the
// transformation filter's op filter.
// NOTE(review): diff hunk — context now comes from patternList.getContext()
// instead of a separate MLIRContext* parameter.
template <typename PatternType, typename OpType, typename OptionsType>
void enqueue(OwningRewritePatternList &patternList, OptionsType options,
-             MLIRContext *context, StringRef opName,
-             linalg::LinalgTransformationFilter m) {
+             StringRef opName, linalg::LinalgTransformationFilter m) {
  if (!opName.empty())
-    patternList.insert<PatternType>(opName, context, options, m);
+    patternList.insert<PatternType>(opName, patternList.getContext(), options,
+                                    m);
  else
    patternList.insert<PatternType>(m.addOpFilter<OpType>(), options);
}
/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Tile<LinalgOpType>`with the appropriate `options`.
-template <typename LinalgOpType> struct Tile : public Transformation {
+template <typename LinalgOpType>
+struct Tile : public Transformation {
explicit Tile(linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(LinalgOpType::getOperationName()),
OwningRewritePatternList
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
- OwningRewritePatternList tilingPatterns;
+ OwningRewritePatternList tilingPatterns(context);
sfinae_enqueue<linalg::LinalgTilingPattern, LinalgOpType>(
- tilingPatterns, options, context, opName, m);
+ tilingPatterns, options, opName, m);
return tilingPatterns;
}
/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Promote<LinalgOpType>`with the appropriate `options`.
-template <typename LinalgOpType> struct Promote : public Transformation {
+template <typename LinalgOpType>
+struct Promote : public Transformation {
explicit Promote(
linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
OwningRewritePatternList
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
- OwningRewritePatternList promotionPatterns;
+ OwningRewritePatternList promotionPatterns(context);
sfinae_enqueue<linalg::LinalgPromotionPattern, LinalgOpType>(
- promotionPatterns, options, context, opName, m);
+ promotionPatterns, options, opName, m);
return promotionPatterns;
}
OwningRewritePatternList
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
- OwningRewritePatternList vectorizationPatterns;
+ OwningRewritePatternList vectorizationPatterns(context);
enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
- vectorizationPatterns, options, context, opName, m);
+ vectorizationPatterns, options, opName, m);
vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(
context, /*benefit=*/2);
ArrayRef<int64_t> tileSizes);
/// Populates the given list with patterns to bufferize linalg ops.
-void populateLinalgBufferizePatterns(MLIRContext *context,
- BufferizeTypeConverter &converter,
+void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
OwningRewritePatternList &patterns);
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
OwningRewritePatternList
getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
/// Base pattern that applied the tiling transformation specified by `options`.
/// Abort and return failure in 2 cases:
typename = std::enable_if_t<detect_has_get_operation_name<OpType>::value>,
typename = void>
void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
- MLIRContext *context,
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter f) {
patternList.insert<linalg::LinalgVectorizationPattern>(
- OpType::getOperationName(), context, options, f);
+ OpType::getOperationName(), patternList.getContext(), options, f);
}
/// SFINAE helper for single C++ class without a `getOperationName` method (e.g.
template <typename OpType, typename = std::enable_if_t<
!detect_has_get_operation_name<OpType>::value>>
void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
- MLIRContext *context,
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter f) {
patternList.insert<linalg::LinalgVectorizationPattern>(
/// Variadic helper function to insert vectorization patterns for C++ ops.
// NOTE(review): diff hunk — MLIRContext* is removed from the public
// signature. However, the new call below still passes
// patternList.getContext() as the second argument to
// insertVectorizationPatternImpl, whose MLIRContext* parameter was removed
// in this same change — this looks like an arity mismatch; verify against
// the impl overloads before landing.
template <typename... OpTypes>
void insertVectorizationPatterns(OwningRewritePatternList &patternList,
-                                 MLIRContext *context,
                                 linalg::LinalgVectorizationOptions options,
                                 linalg::LinalgTransformationFilter f =
                                     linalg::LinalgTransformationFilter()) {
  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-  (void)std::initializer_list<int>{0, (insertVectorizationPatternImpl<OpTypes>(
-                                           patternList, context, options, f),
-                                       0)...};
+  (void)std::initializer_list<int>{
+      0, (insertVectorizationPatternImpl<OpTypes>(
+             patternList, patternList.getContext(), options, f),
+         0)...};
}
///
/// Populates `patterns` with patterns to convert spec-generated named ops to
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ OwningRewritePatternList &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter());
/// Populates `patterns` with patterns to convert linalg.conv ops to
/// linalg.generic ops.
void populateLinalgConvGeneralizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ OwningRewritePatternList &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter());
//===----------------------------------------------------------------------===//
PatternRewriter &rewriter) const override;
};
- /// Helper struct to return the results of `substituteMin`.
+/// Helper struct to return the results of `substituteMin`.
struct AffineMapAndOperands {
AffineMap map;
SmallVector<Value> dims;
/// Return a new AffineMap, dims and symbols that have been canonicalized and
/// simplified.
AffineMapAndOperands substituteMin(
- AffineMinOp affineMinOp,
- llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
+ AffineMinOp affineMinOp,
+ llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
/// Converts Convolution op into vector contraction.
///
/// Sets up sparsification rewriting rules with the given options.
void populateSparsificationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ OwningRewritePatternList &patterns,
const SparsificationOptions &options = SparsificationOptions());
/// Sets up sparsification conversion rules with the given options.
void populateSparsificationConversionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
} // namespace linalg
} // namespace mlir
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/Bufferize.h"
-
namespace mlir {
class OwningRewritePatternList;
-void populateExpandTanhPattern(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+void populateExpandTanhPattern(OwningRewritePatternList &patterns);
void populateMathPolynomialApproximationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
} // namespace mlir
/// corresponding scf.yield ops need to update their types accordingly to the
/// TypeConverter, but otherwise don't care what type conversions are happening.
void populateSCFStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target);
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target);
} // namespace scf
} // namespace mlir
namespace mlir {
namespace spirv {
void populateSPIRVGLSLCanonicalizationPatterns(
- mlir::OwningRewritePatternList &results, mlir::MLIRContext *context);
+ mlir::OwningRewritePatternList &results);
} // namespace spirv
} // namespace mlir
/// `func` op to the SPIR-V dialect. These patterns do not handle shader
/// interface/ABI; they convert function parameters to be of SPIR-V allowed
/// types.
-void populateBuiltinFuncToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
namespace spirv {
std::unique_ptr<Pass> createShapeToShapeLowering();
/// Collects a set of patterns to rewrite ops within the Shape dialect.
-void populateShapeRewritePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+void populateShapeRewritePatterns(OwningRewritePatternList &patterns);
// Collects a set of patterns to replace all constraints with passing witnesses.
// This is intended to then allow all ShapeConstraint related ops and data to
// canonicalization and dead code elimination.
//
// After this pass, no cstr_ operations exist.
-void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns);
std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
/// Populates patterns for shape dialect structural type conversions and sets up
/// do for a structural type conversion is to update both of their types
/// consistently to the new types prescribed by the TypeConverter.
void populateShapeStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target);
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target);
// Bufferizes shape dialect ops.
//
/// Add a pattern to the given pattern list to convert the operand and result
/// types of a CallOp with the given type converter.
void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
- MLIRContext *ctx,
TypeConverter &converter);
/// Add a pattern to the given pattern list to rewrite branch operations to use
/// be done if the branch operation implements the BranchOpInterface. Only
/// needed for partial conversions.
void populateBranchOpInterfaceTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &converter);
+ OwningRewritePatternList &patterns, TypeConverter &converter);
/// Return true if op is a BranchOpInterface op whose operands are all legal
/// according to converter.
/// Add a pattern to the given pattern list to rewrite `return` ops to use
/// operands that have been legalized by the conversion framework.
void populateReturnOpTypeConversionPattern(OwningRewritePatternList &patterns,
- MLIRContext *ctx,
TypeConverter &converter);
/// For ReturnLike ops (except `return`), return True. If op is a `return` &&
class OwningRewritePatternList;
-void populateStdBufferizePatterns(MLIRContext *context,
- BufferizeTypeConverter &typeConverter,
+void populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
/// Creates an instance of std bufferization pass.
std::unique_ptr<Pass> createStdExpandOpsPass();
/// Collects a set of patterns to rewrite ops within the Std dialect.
-void populateStdExpandOpsPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+void populateStdExpandOpsPatterns(OwningRewritePatternList &patterns);
//===----------------------------------------------------------------------===//
// Registration
class OwningRewritePatternList;
-void populateTensorBufferizePatterns(MLIRContext *context,
- BufferizeTypeConverter &typeConverter,
+void populateTensorBufferizePatterns(BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
/// Creates an instance of `tensor` dialect bufferization pass.
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context);
+ OwningRewritePatternList &patterns);
/// Collect a set of vector-to-vector transformation patterns.
void populateVectorToVectorTransformationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context);
+ OwningRewritePatternList &patterns);
/// Collect a set of patterns to split transfer read/write ops.
///
/// of being generic canonicalization patterns. Also one can let the
/// `ignoreFilter` to return true to fail matching for fine-grained control.
void populateSplitVectorTransferPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
+ OwningRewritePatternList &patterns,
std::function<bool(Operation *)> ignoreFilter = nullptr);
/// Collect a set of leading one dimension removal patterns.
/// With them, there are more chances that we can cancel out extract-insert
/// pairs or forward write-read pairs.
void populateCastAwayVectorLeadingOneDimPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context);
+ OwningRewritePatternList &patterns);
/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcast will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
-void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
- MLIRContext *context);
+void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns);
/// Collect a set of vector slices transformation patterns:
/// ExtractSlicesOpLowering, InsertSlicesOpLowering
/// use for "slices" ops), this lowering removes all tuple related
/// operations as well (through DCE and folding). If tuple values
/// "leak" coming in, however, some tuple related ops will remain.
-void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
- MLIRContext *context);
+void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns);
/// Collect a set of transfer read/write lowering patterns.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store` and `vector.broadcast`.
-void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns,
- MLIRContext *context);
+void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns);
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
/// These transformation express higher level vector ops in terms of more
/// elementary extraction, insertion, reduction, product, and broadcast ops.
void populateVectorContractLoweringPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
+ OwningRewritePatternList &patterns,
VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
/// Returns the integer type required for subscripts in the vector dialect.
PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
/// Returns true if the type of the held value is `T`.
// Formatting-only hunk: the template header moves to its own line; the
// kind-tag comparison is unchanged.
-  template <typename T> bool isa() const {
+  template <typename T>
+  bool isa() const {
    assert(value && "isa<> used on a null value");
    return kind == getKindOf<T>();
  }
  /// Cast this value to type `T`, asserts if this value is not an instance of
  /// `T`.
  // Formatting-only hunk: template header moved to its own line.
-  template <typename T> T cast() const {
+  template <typename T>
+  T cast() const {
    assert(isa<T>() && "expected value to be of type `T`");
    return castImpl<T>();
  }
private:
/// Find the index of a given type in a range of other types.
- template <typename...> struct index_of_t;
+ template <typename...>
+ struct index_of_t;
template <typename T, typename... R>
struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
template <typename T, typename F, typename... R>
: std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
  /// Return the kind used for the given T.
  // The index into this type list is cast directly to Kind, so the list's
  // order must line up with the Kind enumerators. Formatting-only hunk.
-  template <typename T> static Kind getKindOf() {
+  template <typename T>
+  static Kind getKindOf() {
    return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
                             TypeRange, Value, ValueRange>::value);
  }
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
  // NOTE(review): diff hunk — OwningRewritePatternList now requires an
  // MLIRContext at construction and exposes it via getContext(); the default
  // constructor is removed, which is a source-breaking change for any caller
  // still writing `OwningRewritePatternList patterns;`.
-  OwningRewritePatternList() = default;
+  OwningRewritePatternList(MLIRContext *context) : context(context) {}
  /// Construct a OwningRewritePatternList populated with the given pattern.
-  OwningRewritePatternList(std::unique_ptr<RewritePattern> pattern) {
+  OwningRewritePatternList(MLIRContext *context,
+                           std::unique_ptr<RewritePattern> pattern)
+      : context(context) {
    nativePatterns.emplace_back(std::move(pattern));
  }
  // For PDL modules, the context is recovered from the module op itself.
  OwningRewritePatternList(PDLPatternModule &&pattern)
-      : pdlPatterns(std::move(pattern)) {}
+      : context(pattern.getModule()->getContext()),
+        pdlPatterns(std::move(pattern)) {}
+
+  MLIRContext *getContext() const { return context; }
  /// Return the native patterns held in this list.
  NativePatternListT &getNativePatterns() { return nativePatterns; }
typename... ConstructorArgs,
typename = std::enable_if_t<sizeof...(Ts) != 0>>
OwningRewritePatternList &insert(ConstructorArg &&arg,
- ConstructorArgs &&...args) {
+ ConstructorArgs &&... args) {
// The following expands a call to emplace_back for each of the pattern
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
  /// `this` for chaining insertions.
  // Formatting-only hunk: template header moved to its own line. The
  // initializer_list trick expands insertImpl for each Ts (pre-C++17 fold).
-  template <typename... Ts> OwningRewritePatternList &insert() {
+  template <typename... Ts>
+  OwningRewritePatternList &insert() {
    (void)std::initializer_list<int>{0, (insertImpl<Ts>(), 0)...};
    return *this;
  }
  /// Construct a `T` in place on the native pattern list; enables
  /// chaining insertions.
  template <typename T, typename... Args>
  std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
-  insertImpl(Args &&...args) {
+  insertImpl(Args &&... args) {
    nativePatterns.emplace_back(
        std::make_unique<T>(std::forward<Args>(args)...));
  }
  // PDL overload: constructs the module and merges it into the PDL set.
  template <typename T, typename... Args>
  std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
-  insertImpl(Args &&...args) {
+  insertImpl(Args &&... args) {
    pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
  }
  // NOTE(review): `context` is new in this change; const — set exactly once
  // at construction, never reseated.
+  MLIRContext *const context;
  NativePatternListT nativePatterns;
  PDLPatternModule pdlPatterns;
};
///
/// In particular, these are the tensor_load/buffer_cast ops.
void populateEliminateBufferizeMaterializationsPatterns(
- MLIRContext *context, BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns);
} // end namespace mlir
/// FunctionLike ops which use FunctionType to represent their type.
void populateFunctionLikeTypeConversionPattern(
StringRef functionLikeOpName, OwningRewritePatternList &patterns,
- MLIRContext *ctx, TypeConverter &converter);
+ TypeConverter &converter);
// Convenience overload: derives the op name from FuncOpT and forwards to the
// string-based populateFunctionLikeTypeConversionPattern declared above.
// NOTE(review): diff hunk — the MLIRContext* parameter is dropped from both
// this template and the forwarded call.
template <typename FuncOpT>
void populateFunctionLikeTypeConversionPattern(
-    OwningRewritePatternList &patterns, MLIRContext *ctx,
-    TypeConverter &converter) {
+    OwningRewritePatternList &patterns, TypeConverter &converter) {
  populateFunctionLikeTypeConversionPattern(FuncOpT::getOperationName(),
-                                            patterns, ctx, converter);
+                                            patterns, converter);
}
/// Add a pattern to the given pattern list to convert the signature of a FuncOp
/// with the given type converter.
void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
- MLIRContext *ctx,
TypeConverter &converter);
//===----------------------------------------------------------------------===//
/// Register a legality action for the given operation.
void setOpAction(OperationName op, LegalizationAction action);
// Formatting-only hunks below: template headers moved onto their own lines;
// no behavioral change.
-  template <typename OpT> void setOpAction(LegalizationAction action) {
+  template <typename OpT>
+  void setOpAction(LegalizationAction action) {
    setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
  }
  /// Register the given operations as legal.
-  template <typename OpT> void addLegalOp() {
+  template <typename OpT>
+  void addLegalOp() {
    setOpAction<OpT>(LegalizationAction::Legal);
  }
  // Variadic form: registers each listed op as legal, one at a time.
-  template <typename OpT, typename OpT2, typename... OpTs> void addLegalOp() {
+  template <typename OpT, typename OpT2, typename... OpTs>
+  void addLegalOp() {
    addLegalOp<OpT>();
    addLegalOp<OpT2, OpTs...>();
  }
/// Register the given operation as dynamically legal, i.e. requiring custom
/// handling by the target via 'isDynamicallyLegal'.
- template <typename OpT> void addDynamicallyLegalOp() {
+ template <typename OpT>
+ void addDynamicallyLegalOp() {
setOpAction<OpT>(LegalizationAction::Dynamic);
}
template <typename OpT, typename OpT2, typename... OpTs>
/// Register the given operation as illegal, i.e. this operation is known to
/// not be supported by this target.
- template <typename OpT> void addIllegalOp() {
+ template <typename OpT>
+ void addIllegalOp() {
setOpAction<OpT>(LegalizationAction::Illegal);
}
- template <typename OpT, typename OpT2, typename... OpTs> void addIllegalOp() {
+ template <typename OpT, typename OpT2, typename... OpTs>
+ void addIllegalOp() {
addIllegalOp<OpT>();
addIllegalOp<OpT2, OpTs...>();
}
SmallVector<StringRef, 2> dialectNames({name, names...});
setDialectAction(dialectNames, LegalizationAction::Legal);
}
- template <typename... Args> void addLegalDialect() {
+ template <typename... Args>
+ void addLegalDialect() {
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
setDialectAction(dialectNames, LegalizationAction::Legal);
}
SmallVector<StringRef, 2> dialectNames({name, names...});
setDialectAction(dialectNames, LegalizationAction::Illegal);
}
- template <typename... Args> void addIllegalDialect() {
+ template <typename... Args>
+ void addIllegalDialect() {
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
setDialectAction(dialectNames, LegalizationAction::Illegal);
}
} // end namespace
void mlir::populateAffineToStdConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
// clang-format off
patterns.insert<
AffineApplyLowering,
AffineStoreLowering,
AffineForLowering,
AffineIfLowering,
- AffineYieldOpLowering>(ctx);
+ AffineYieldOpLowering>(patterns.getContext());
// clang-format on
}
void mlir::populateAffineToVectorConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
// clang-format off
patterns.insert<
AffineVectorLoadLowering,
- AffineVectorStoreLowering>(ctx);
+ AffineVectorStoreLowering>(patterns.getContext());
// clang-format on
}
namespace {
class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
void runOnOperation() override {
- OwningRewritePatternList patterns;
- populateAffineToStdConversionPatterns(patterns, &getContext());
- populateAffineToVectorConversionPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateAffineToStdConversionPatterns(patterns);
+ populateAffineToVectorConversionPatterns(patterns);
ConversionTarget target(getContext());
target.addLegalDialect<memref::MemRefDialect, scf::SCFDialect,
StandardOpsDialect, VectorDialect>();
// Convert async dialect types and operations to LLVM dialect.
AsyncRuntimeTypeConverter converter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
// We use conversion to LLVM type to lower async.runtime load and store
// operations.
llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
// Convert async types in function signatures and function calls.
- populateFuncOpTypeConversionPattern(patterns, ctx, converter);
- populateCallOpTypeConversionPattern(patterns, ctx, converter);
+ populateFuncOpTypeConversionPattern(patterns, converter);
+ populateCallOpTypeConversionPattern(patterns, converter);
// Convert return operations inside async.execute regions.
patterns.insert<ReturnOpOpConversion>(converter, ctx);
}
void mlir::populateAsyncStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target) {
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target) {
typeConverter.addConversion([&](TokenType type) { return type; });
typeConverter.addConversion([&](ValueType type) {
return ValueType::get(typeConverter.convertType(type.getValueType()));
patterns
.insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
- typeConverter, context);
+ typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
auto module = getOperation();
// Convert to the LLVM IR dialect using the converter defined above.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateComplexToLLVMConversionPatterns(converter, patterns);
void GpuToLLVMConversionPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
LLVMConversionTarget target(getContext());
populateVectorToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
- populateAsyncStructuralTypeConversionsAndLegality(&getContext(), converter,
- patterns, target);
+ populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
+ target);
converter.addConversion(
[context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
});
- OwningRewritePatternList patterns, llvmPatterns;
+ OwningRewritePatternList patterns(m.getContext());
+ OwningRewritePatternList llvmPatterns(m.getContext());
// Apply in-dialect lowering first. In-dialect lowering will replace ops
// which need to be lowered further, which is not supported by a single
// conversion pass.
- populateGpuRewritePatterns(m.getContext(), patterns);
+ populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
/*useAlignedAlloc =*/false};
LLVMTypeConverter converter(m.getContext(), options);
- OwningRewritePatternList patterns, llvmPatterns;
+ OwningRewritePatternList patterns(m.getContext());
+ OwningRewritePatternList llvmPatterns(m.getContext());
- populateGpuRewritePatterns(m.getContext(), patterns);
+ populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
#include "GPUToSPIRV.cpp.inc"
}
-void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
+ auto *context = patterns.getContext();
populateWithGenerated(context, patterns);
patterns.insert<
GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns;
- populateGPUToSPIRVPatterns(context, typeConverter, patterns);
- populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+ OwningRewritePatternList patterns(context);
+ populateGPUToSPIRVPatterns(typeConverter, patterns);
+ populateStandardToSPIRVPatterns(typeConverter, patterns);
if (failed(applyFullConversion(kernelModules, *target, std::move(patterns))))
return signalPassFailure();
auto module = getOperation();
// Convert to the LLVM IR dialect using the converter defined above.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateLinalgToLLVMConversionPatterns(converter, patterns);
// Pattern population
//===----------------------------------------------------------------------===//
-void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<SingleWorkgroupReduction>(typeConverter, context);
+ patterns.insert<SingleWorkgroupReduction>(typeConverter,
+ patterns.getContext());
}
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns;
- populateLinalgToSPIRVPatterns(context, typeConverter, patterns);
- populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+ OwningRewritePatternList patterns(context);
+ populateLinalgToSPIRVPatterns(typeConverter, patterns);
+ populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
// Allow builtin ops.
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
/// Populate the given list with patterns that convert from Linalg to Standard.
void mlir::linalg::populateLinalgToStandardConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
// TODO: ConvOp conversion needs to export a descriptor with relevant
// attribute values such as kernel striding and dilation.
// clang-format off
patterns.insert<
CopyOpToLibraryCallRewrite,
CopyTransposeRewrite,
- IndexedGenericOpToLibraryCallRewrite>(ctx);
+ IndexedGenericOpToLibraryCallRewrite>(patterns.getContext());
patterns.insert<LinalgOpToLibraryCallRewrite>();
// clang-format on
}
StandardOpsDialect>();
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
- OwningRewritePatternList patterns;
- populateLinalgToStandardConversionPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateLinalgToStandardConversionPatterns(patterns);
if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}
auto module = getOperation();
// Convert to OpenMP operations with LLVM IR dialect
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
return success();
}
-void mlir::populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns.insert<ParallelToGpuLaunchLowering>(ctx);
+void mlir::populateParallelLoopToGPUPatterns(
+ OwningRewritePatternList &patterns) {
+ patterns.insert<ParallelToGpuLaunchLowering>(patterns.getContext());
}
void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) {
struct ParallelLoopToGpuPass
: public ConvertParallelLoopToGpuBase<ParallelLoopToGpuPass> {
void runOnOperation() override {
- OwningRewritePatternList patterns;
- populateParallelLoopToGPUPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateParallelLoopToGPUPatterns(patterns);
ConversionTarget target(getContext());
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
configureParallelLoopToGPULegality(target);
[](scf::YieldOp op) { return !isa<scf::ParallelOp>(op->getParentOp()); });
target.addLegalDialect<omp::OpenMPDialect>();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(func.getContext());
patterns.insert<ParallelOpLowering>(func.getContext());
FrozenRewritePatternList frozen(std::move(patterns));
return applyPartialConversion(func, target, frozen);
// Hooks
//===----------------------------------------------------------------------===//
-void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
OwningRewritePatternList &patterns) {
patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
- context, typeConverter, scfToSPIRVContext.getImpl());
+ patterns.getContext(), typeConverter, scfToSPIRVContext.getImpl());
}
SPIRVTypeConverter typeConverter(targetAttr);
ScfToSPIRVContext scfContext;
- OwningRewritePatternList patterns;
- populateSCFToSPIRVPatterns(context, typeConverter, scfContext, patterns);
- populateStandardToSPIRVPatterns(context, typeConverter, patterns);
- populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+ OwningRewritePatternList patterns(context);
+ populateSCFToSPIRVPatterns(typeConverter, scfContext, patterns);
+ populateStandardToSPIRVPatterns(typeConverter, patterns);
+ populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(module, *target, std::move(patterns))))
return signalPassFailure();
}
void mlir::populateLoopToStdConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
patterns.insert<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
- ctx);
- patterns.insert<DoWhileLowering>(ctx, /*benefit=*/2);
+ patterns.getContext());
+ patterns.insert<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}
void SCFToStandardPass::runOnOperation() {
- OwningRewritePatternList patterns;
- populateLoopToStdConversionPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateLoopToStdConversionPatterns(patterns);
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
// scf.while. Anything else is fine.
ConversionTarget target(getContext());
/*emitCWrappers=*/true,
/*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout};
auto *context = module.getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
LLVMTypeConverter typeConverter(context, options);
populateStdToLLVMConversionPatterns(typeConverter, patterns);
patterns.insert<GPULaunchLowering>(typeConverter);
}
void mlir::populateSPIRVToLLVMConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
patterns.insert<
// Arithmetic ops
DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
// Return ops
- ReturnPattern, ReturnValuePattern>(context, typeConverter);
+ ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
}
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.insert<FuncConversionPattern>(context, typeConverter);
+ LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+ patterns.insert<FuncConversionPattern>(patterns.getContext(), typeConverter);
}
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
- context, typeConverter);
+ patterns.getContext(), typeConverter);
}
//===----------------------------------------------------------------------===//
// Encode global variable's descriptor set and binding if they exist.
encodeBindAttribute(module);
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
populateSPIRVToLLVMTypeConversion(converter);
- populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
- populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
- populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
+ populateSPIRVToLLVMModuleConversionPatterns(converter, patterns);
+ populateSPIRVToLLVMConversionPatterns(converter, patterns);
+ populateSPIRVToLLVMFunctionConversionPatterns(converter, patterns);
- ConversionTarget target(getContext());
+ ConversionTarget target(*context);
target.addIllegalDialect<spirv::SPIRVDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
} // namespace
void mlir::populateConvertShapeConstraintsConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<CstrBroadcastableToRequire>(ctx);
- patterns.insert<CstrEqToRequire>(ctx);
- patterns.insert<ConvertCstrRequireOp>(ctx);
+ OwningRewritePatternList &patterns) {
+ patterns.insert<CstrBroadcastableToRequire>(patterns.getContext());
+ patterns.insert<CstrEqToRequire>(patterns.getContext());
+ patterns.insert<ConvertCstrRequireOp>(patterns.getContext());
}
namespace {
auto func = getOperation();
auto *context = &getContext();
- OwningRewritePatternList patterns;
- populateConvertShapeConstraintsConversionPatterns(patterns, context);
+ OwningRewritePatternList patterns(context);
+ populateConvertShapeConstraintsConversionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
return signalPassFailure();
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
// Setup conversion patterns.
- OwningRewritePatternList patterns;
- populateShapeToStandardConversionPatterns(patterns, &ctx);
+ OwningRewritePatternList patterns(&ctx);
+ populateShapeToStandardConversionPatterns(patterns);
// Apply conversion.
auto module = getOperation();
}
void mlir::populateShapeToStandardConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
// clang-format off
- populateWithGenerated(ctx, patterns);
+ populateWithGenerated(patterns.getContext(), patterns);
patterns.insert<
AnyOpConversion,
BinaryOpConversion<AddOp, AddIOp>,
ShapeEqOpConverter,
ShapeOfOpConversion,
SplitAtOpConversion,
- ToExtentTensorOpConversion>(ctx);
+ ToExtentTensorOpConversion>(patterns.getContext());
// clang-format on
}
llvm::DataLayout(this->dataLayout)};
LLVMTypeConverter typeConverter(&getContext(), options);
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
LLVMConversionTarget target(getContext());
//===----------------------------------------------------------------------===//
void mlir::populateStdLegalizationPatternsForSPIRVLowering(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
patterns.insert<LoadOpOfSubViewFolder<memref::LoadOp>,
LoadOpOfSubViewFolder<vector::TransferReadOp>,
StoreOpOfSubViewFolder<memref::StoreOp>,
- StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
+ StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
+ patterns.getContext());
}
//===----------------------------------------------------------------------===//
} // namespace
void SPIRVLegalization::runOnOperation() {
- OwningRewritePatternList patterns;
- auto *context = &getContext();
- populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
+ OwningRewritePatternList patterns(&getContext());
+ populateStdLegalizationPatternsForSPIRVLowering(patterns);
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
}
//===----------------------------------------------------------------------===//
namespace mlir {
-void populateStandardToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
+ MLIRContext *context = patterns.getContext();
+
patterns.insert<
// Math dialect operations.
// TODO: Move to separate pass.
/*benefit=*/2);
}
-void populateTensorToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
int64_t byteCountThreshold,
OwningRewritePatternList &patterns) {
- patterns.insert<TensorExtractPattern>(typeConverter, context,
+ patterns.insert<TensorExtractPattern>(typeConverter, patterns.getContext(),
byteCountThreshold);
}
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns;
- populateStandardToSPIRVPatterns(context, typeConverter, patterns);
- populateTensorToSPIRVPatterns(context, typeConverter,
+ OwningRewritePatternList patterns(context);
+ populateStandardToSPIRVPatterns(typeConverter, patterns);
+ populateTensorToSPIRVPatterns(typeConverter,
/*byteCountThreshold=*/64, patterns);
- populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+ populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(module, *target, std::move(patterns))))
return signalPassFailure();
} // namespace
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns) {
+ OwningRewritePatternList *patterns) {
patterns->insert<
PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
- RescaleConverter, ReverseConverter, TransposeConverter>(context);
+ RescaleConverter, ReverseConverter, TransposeConverter>(
+ patterns->getContext());
}
}
void runOnFunction() override {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
StandardOpsDialect>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
FuncOp func = getFunction();
- mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
- func.getContext(), &patterns);
+ mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(&patterns);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
} // namespace
void mlir::tosa::populateTosaToSCFConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns) {
- patterns->insert<IfOpConverter>(context);
- patterns->insert<WhileOpConverter>(context);
+ OwningRewritePatternList *patterns) {
+ patterns->insert<IfOpConverter>(patterns->getContext());
+ patterns->insert<WhileOpConverter>(patterns->getContext());
}
struct TosaToSCF : public TosaToSCFBase<TosaToSCF> {
public:
void runOnOperation() override {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
auto *op = getOperation();
- mlir::tosa::populateTosaToSCFConversionPatterns(op->getContext(),
- &patterns);
+ mlir::tosa::populateTosaToSCFConversionPatterns(&patterns);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
} // namespace
void mlir::tosa::populateTosaToStandardConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns) {
+ OwningRewritePatternList *patterns) {
patterns->insert<ApplyScaleOpConverter, ConstOpConverter, SliceOpConverter>(
- context);
+ patterns->getContext());
}
void mlir::tosa::populateTosaRescaleToStandardConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns) {
- patterns->insert<ApplyScaleOpConverter>(context);
+ OwningRewritePatternList *patterns) {
+ patterns->insert<ApplyScaleOpConverter>(patterns->getContext());
}
struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
public:
void runOnOperation() override {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext());
target.addIllegalOp<tosa::ConstOp>();
target.addIllegalOp<tosa::SliceOp>();
target.addIllegalOp<tosa::ApplyScaleOp>();
target.addLegalDialect<StandardOpsDialect>();
- auto *op = getOperation();
- mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(),
- &patterns);
- if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ mlir::tosa::populateTosaToStandardConversionPatterns(&patterns);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
};
// Perform progressive lowering of operations on slices and
// all contraction operations. Also applies folding and DCE.
{
- OwningRewritePatternList patterns;
- populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
- populateVectorSlicesLoweringPatterns(patterns, &getContext());
- populateVectorContractLoweringPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorSlicesLoweringPatterns(patterns);
+ populateVectorContractLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
// Convert to the LLVM IR dialect.
LLVMTypeConverter converter(&getContext());
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
converter, patterns, reassociateFPReductions, enableIndexOptimizations);
return false;
};
// Remove any ArmSVE-specific types from function signatures and results.
- populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
+ populateFuncOpTypeConversionPattern(patterns, converter);
target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
return !hasScalableVectorType(op.getType().getInputs()) &&
!hasScalableVectorType(op.getType().getResults());
void LowerVectorToROCDLPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
populateVectorToROCDLConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
}
void populateVectorToSCFConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
+ OwningRewritePatternList &patterns,
const VectorTransferToSCFOptions &options) {
patterns.insert<VectorTransferRewriter<vector::TransferReadOp>,
- VectorTransferRewriter<vector::TransferWriteOp>>(options,
- context);
+ VectorTransferRewriter<vector::TransferWriteOp>>(
+ options, patterns.getContext());
}
} // namespace mlir
}
void runOnFunction() override {
- OwningRewritePatternList patterns;
- auto *context = getFunction().getContext();
+ OwningRewritePatternList patterns(getFunction().getContext());
populateVectorToSCFConversionPatterns(
- patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll));
+ patterns, VectorTransferToSCFOptions().setUnroll(fullUnroll));
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
} // namespace
-void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
patterns.insert<VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
VectorInsertElementOpConvert, VectorInsertOpConvert,
- VectorInsertStridedSliceOpConvert>(typeConverter, context);
+ VectorInsertStridedSliceOpConvert>(typeConverter,
+ patterns.getContext());
}
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns;
- populateVectorToSPIRVPatterns(context, typeConverter, patterns);
+ OwningRewritePatternList patterns(context);
+ populateVectorToSPIRVPatterns(typeConverter, patterns);
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
target->addLegalOp<FuncOp>();
// Promoting single iteration loops could lead to simplification of
// contained load's/store's, and the latter could anyway also be
// canonicalized.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternList frozenPatterns(std::move(patterns));
void SimplifyAffineStructures::runOnFunction() {
auto func = getFunction();
simplifiedAttributes.clear();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(func.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
// effective (no unused operands). Since the pattern rewriter's folding is
// entangled with application of patterns, we may fold/end up erasing the op,
// in which case we return with `folded` being set.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ifOp.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
bool erased;
FrozenRewritePatternList frozenPatterns(std::move(patterns));
void AsyncParallelForPass::runOnFunction() {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
patterns.insert<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
// Lower async operations to async.runtime operations.
MLIRContext *ctx = module->getContext();
- OwningRewritePatternList asyncPatterns;
+ OwningRewritePatternList asyncPatterns(ctx);
// Async lowering does not use type converter because it must preserve all
// types for async.runtime operations.
};
} // namespace
-void mlir::populateGpuAllReducePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
- patterns.insert<GpuAllReduceConversion>(context);
+void mlir::populateGpuAllReducePatterns(OwningRewritePatternList &patterns) {
+ patterns.insert<GpuAllReduceConversion>(patterns.getContext());
}
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
- OwningRewritePatternList patterns;
- populateLinalgBufferizePatterns(&context, typeConverter, patterns);
+ OwningRewritePatternList patterns(&context);
+ populateLinalgBufferizePatterns(typeConverter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
void mlir::linalg::populateLinalgBufferizePatterns(
- MLIRContext *context, BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
// TODO: Drop this once tensor constants work in standard.
// clang-format off
BufferizeInitTensorOp,
SubTensorOpConverter,
SubTensorInsertOpConverter
- >(typeConverter, context);
+ >(typeConverter, patterns.getContext());
// clang-format on
}
// Programmatic splitting of slow/fast path vector transfers.
if (lateCodegenStrategyOptions.enableVectorTransferPartialRewrite) {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
patterns.insert<vector::VectorTransferFullPartialRewriter>(
context, vectorTransformsOptions);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
// Programmatic controlled lowering of vector.contract only.
if (lateCodegenStrategyOptions.enableVectorContractLowering) {
- OwningRewritePatternList vectorContractLoweringPatterns;
+ OwningRewritePatternList vectorContractLoweringPatterns(context);
vectorContractLoweringPatterns
.insert<ContractionOpToOuterProductOpLowering,
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
// Programmatic controlled lowering of vector.transfer only.
if (lateCodegenStrategyOptions.enableVectorToSCFConversion) {
- OwningRewritePatternList vectorToLoopsPatterns;
- populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
+ OwningRewritePatternList vectorToLoopsPatterns(context);
+ populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
vectorToSCFOptions);
(void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
}
void runOnFunction() override {
auto *context = &getContext();
DetensorizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
context, typeConverter);
// Since non-entry block arguments get detensorized, we also need to update
// the control flow inside the function to reflect the correct types.
- populateBranchOpInterfaceTypeConversionPattern(patterns, context,
- typeConverter);
+ populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
- OwningRewritePatternList canonPatterns;
+ OwningRewritePatternList canonPatterns(context);
canonPatterns.insert<ExtractFromReshapeFromElements>(context);
if (failed(applyPatternsAndFoldGreedily(getFunction(),
std::move(canonPatterns))))
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
+ auto *context = patterns.getContext();
patterns
.insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
ReplaceUnitExtentTensors<GenericOp>,
ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldReshapeOpWithUnitExtent>(context);
- populateFoldUnitDimsReshapeOpsByLinearizationPatterns(context, patterns);
+ populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
}
namespace {
struct LinalgFoldUnitExtentDimsPass
: public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
void runOnFunction() override {
- OwningRewritePatternList patterns;
FuncOp funcOp = getFunction();
MLIRContext *context = funcOp.getContext();
+ OwningRewritePatternList patterns(context);
if (foldOneTripLoopsOnly)
patterns.insert<FoldUnitDimLoops<GenericOp>,
FoldUnitDimLoops<IndexedGenericOp>>(context);
else
- populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
+ populateLinalgFoldUnitExtentDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
};
} // namespace
void mlir::populateElementwiseToLinalgConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *) {
+ OwningRewritePatternList &patterns) {
patterns.insert<ConvertAnyElementwiseMappableOpOnRankedTensors>();
}
auto func = getOperation();
auto *context = &getContext();
ConversionTarget target(*context);
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
- populateElementwiseToLinalgConversionPatterns(patterns, context);
+ populateElementwiseToLinalgConversionPatterns(patterns);
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return !isElementwiseMappableOpOnRankedTensors(op);
});
struct FusionOfTensorOpsPass
: public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
void runOnOperation() override {
- OwningRewritePatternList patterns;
Operation *op = getOperation();
- populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
+ OwningRewritePatternList patterns(op->getContext());
+ populateLinalgTensorOpsFusionPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
: public LinalgFoldReshapeOpsByLinearizationBase<
FoldReshapeOpsByLinearizationPass> {
void runOnOperation() override {
- OwningRewritePatternList patterns;
Operation *op = getOperation();
- populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns);
+ OwningRewritePatternList patterns(op->getContext());
+ populateFoldReshapeOpsByLinearizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
} // namespace
void mlir::populateFoldReshapeOpsByLinearizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, false>,
FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
- FoldConsumerReshapeOpByLinearization<false>>(context);
+ FoldConsumerReshapeOpByLinearization<false>>(
+ patterns.getContext());
}
void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, true>,
FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
- FoldConsumerReshapeOpByLinearization<true>>(context);
+ FoldConsumerReshapeOpByLinearization<true>>(
+ patterns.getContext());
}
void mlir::populateFoldReshapeOpsByExpansionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
patterns.insert<FoldReshapeWithGenericOpByExpansion,
FoldWithProducerReshapeOpByExpansion<GenericOp>,
FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
- context);
+ patterns.getContext());
}
void mlir::populateLinalgTensorOpsFusionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
+ auto *context = patterns.getContext();
patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
FoldSplatConstants<GenericOp>,
FoldSplatConstants<IndexedGenericOp>>(context);
- populateFoldReshapeOpsByExpansionPatterns(context, patterns);
+ populateFoldReshapeOpsByExpansionPatterns(patterns);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
void LinalgGeneralizationPass::runOnFunction() {
FuncOp func = getFunction();
- OwningRewritePatternList patterns;
- linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns);
- linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns);
+ OwningRewritePatternList patterns(&getContext());
+ linalg::populateLinalgConvGeneralizationPatterns(patterns);
+ linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
}
}
void mlir::linalg::populateLinalgConvGeneralizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ OwningRewritePatternList &patterns,
linalg::LinalgTransformationFilter marker) {
- patterns.insert<GeneralizeConvOp>(context, marker);
+ patterns.insert<GeneralizeConvOp>(patterns.getContext(), marker);
}
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ OwningRewritePatternList &patterns,
linalg::LinalgTransformationFilter marker) {
- patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
+ patterns.insert<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
+ marker);
}
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
// Apply canonicalization so the newForOp + yield folds immediately, thus
// cleaning up the IR and potentially enabling more hoisting.
if (changed) {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(func->getContext());
scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
static void lowerLinalgToLoopsImpl(FuncOp funcOp,
ArrayRef<unsigned> interchangeVector) {
MLIRContext *context = funcOp.getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
patterns.insert<LinalgRewritePattern<LoopType>>(interchangeVector);
memref::DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void linalg::populateSparsificationConversionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
patterns.insert<TensorFromPointerConverter, TensorToDimSizeConverter,
TensorToPointersConverter, TensorToIndicesConverter,
- TensorToValuesConverter>(context);
+ TensorToValuesConverter>(patterns.getContext());
}
/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void linalg::populateSparsificationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
- const SparsificationOptions &options) {
- patterns.insert<GenericOpSparsifier>(context, options);
+ OwningRewritePatternList &patterns, const SparsificationOptions &options) {
+ patterns.insert<GenericOpSparsifier>(patterns.getContext(), options);
}
template <>
class CanonicalizationPatternList<> {
public:
- static void insert(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
+ static void insert(OwningRewritePatternList &patterns) {}
};
template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
- static void insert(OwningRewritePatternList &patterns, MLIRContext *ctx) {
- OpTy::getCanonicalizationPatterns(patterns, ctx);
- CanonicalizationPatternList<OpTypes...>::insert(patterns, ctx);
+ static void insert(OwningRewritePatternList &patterns) {
+ OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
+ CanonicalizationPatternList<OpTypes...>::insert(patterns);
}
};
class RewritePatternList<> {
public:
static void insert(OwningRewritePatternList &patterns,
- const LinalgTilingOptions &options, MLIRContext *ctx) {}
+ const LinalgTilingOptions &options) {}
};
template <typename OpTy, typename... OpTypes>
class RewritePatternList<OpTy, OpTypes...> {
public:
static void insert(OwningRewritePatternList &patterns,
- const LinalgTilingOptions &options, MLIRContext *ctx) {
+ const LinalgTilingOptions &options) {
+ auto *ctx = patterns.getContext();
patterns.insert<LinalgTilingPattern<OpTy>>(
ctx, options,
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("tiled", ctx)));
- RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
+ RewritePatternList<OpTypes...>::insert(patterns, options);
}
};
} // namespace
OwningRewritePatternList
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
- OwningRewritePatternList patterns;
- populateLinalgTilingCanonicalizationPatterns(patterns, ctx);
+ OwningRewritePatternList patterns(ctx);
+ populateLinalgTilingCanonicalizationPatterns(patterns);
return patterns;
}
void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
+ auto *ctx = patterns.getContext();
AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
AffineForOp::getCanonicalizationPatterns(patterns, ctx);
AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
- >::insert(patterns, ctx);
+ >::insert(patterns);
}
/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(OwningRewritePatternList &patterns,
- const LinalgTilingOptions &options,
- MLIRContext *ctx) {
+ const LinalgTilingOptions &options) {
RewritePatternList<GenericOp, IndexedGenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
- >::insert(patterns, options, ctx);
+ >::insert(patterns, options);
}
static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
auto options =
LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType);
MLIRContext *ctx = funcOp.getContext();
- OwningRewritePatternList patterns;
- insertTilingPatterns(patterns, options, ctx);
+ OwningRewritePatternList patterns(ctx);
+ insertTilingPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsAndFoldGreedily(
funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
OwningRewritePatternList &promotionPatterns,
OwningRewritePatternList &vectorizationPatterns,
- ArrayRef<int64_t> tileSizes,
- MLIRContext *context) {
+ ArrayRef<int64_t> tileSizes) {
+ auto *context = tilingPatterns.getContext();
if (tileSizes.size() < N)
return;
void mlir::linalg::populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
ArrayRef<int64_t> tileSizes) {
- OwningRewritePatternList tiling, promotion, vectorization;
+ OwningRewritePatternList tiling(context);
+ OwningRewritePatternList promotion(context);
+ OwningRewritePatternList vectorization(context);
populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
- populateVectorizationPatterns<ConvNDHWCOp, 5>(
- tiling, promotion, vectorization, tileSizes, context);
+ populateVectorizationPatterns<ConvNDHWCOp, 5>(tiling, promotion,
+ vectorization, tileSizes);
populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
- populateVectorizationPatterns<ConvNCDHWOp, 5>(
- tiling, promotion, vectorization, tileSizes, context);
+ populateVectorizationPatterns<ConvNCDHWOp, 5>(tiling, promotion,
+ vectorization, tileSizes);
populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
patterns.push_back(std::move(tiling));
patterns.push_back(std::move(promotion));
};
} // namespace
-void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns.insert<TanhOpConverter>(ctx);
+void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns) {
+ patterns.insert<TanhOpConverter>(patterns.getContext());
}
// that do not rely on any of the library functions.
//
//===----------------------------------------------------------------------===//
+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include <limits.h>
+#include <climits>
using namespace mlir;
using namespace mlir::vector;
//----------------------------------------------------------------------------//
void mlir::populateMathPolynomialApproximationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
patterns.insert<TanhApproximation, LogApproximation, Log2Approximation,
- ExpApproximation>(ctx);
+ ExpApproximation>(patterns.getContext());
}
}
void ConvertConstPass::runOnFunction() {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
auto func = getFunction();
auto *context = &getContext();
patterns.insert<QuantizedConstRewrite>(context);
void ConvertSimulatedQuantPass::runOnFunction() {
bool hadFailure = false;
- OwningRewritePatternList patterns;
auto func = getFunction();
+ OwningRewritePatternList patterns(func.getContext());
auto ctx = func.getContext();
patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
ctx, &hadFailure);
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
populateBufferizeMaterializationLegality(target);
- populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
- patterns, target);
+ populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
+ target);
if (failed(applyPartialConversion(func, target, std::move(patterns))))
return signalPassFailure();
};
} // namespace
void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target) {
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target) {
patterns.insert<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
- typeConverter, context);
+ typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
return typeConverter.isLegal(op->getResultTypes());
});
namespace mlir {
namespace spirv {
void populateSPIRVGLSLCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
+ OwningRewritePatternList &results) {
results.insert<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
ConvertComparisonIntoClampSPV_SLessThanOp,
ConvertComparisonIntoClampSPV_SLessThanEqualOp,
ConvertComparisonIntoClampSPV_ULessThanOp,
- ConvertComparisonIntoClampSPV_ULessThanEqualOp>(context);
+ ConvertComparisonIntoClampSPV_ULessThanEqualOp>(
+ results.getContext());
}
} // namespace spirv
} // namespace mlir
};
} // namespace
-static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
+static void
+populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns) {
patterns.insert<SPIRVGlobalVariableOpLayoutInfoDecoration,
- SPIRVAddressOfOpLayoutInfoDecoration>(ctx);
+ SPIRVAddressOfOpLayoutInfoDecoration>(patterns.getContext());
}
namespace {
void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
auto module = getOperation();
- OwningRewritePatternList patterns;
- populateSPIRVLayoutInfoPatterns(patterns, module.getContext());
+ OwningRewritePatternList patterns(module.getContext());
+ populateSPIRVLayoutInfoPatterns(patterns);
ConversionTarget target(*(module.getContext()));
target.addLegalDialect<spirv::SPIRVDialect>();
target.addLegalOp<FuncOp>();
return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
});
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
patterns.insert<ProcessInterfaceVarABI>(typeConverter, context);
ConversionTarget target(*context);
}
void mlir::populateBuiltinFuncToSPIRVPatterns(
- MLIRContext *context, SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.insert<FuncOpConversion>(typeConverter, context);
+ SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+ patterns.insert<FuncOpConversion>(typeConverter, patterns.getContext());
}
//===----------------------------------------------------------------------===//
void runOnFunction() override {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&ctx);
BufferizeTypeConverter typeConverter;
- ConversionTarget target(getContext());
+ ConversionTarget target(ctx);
populateBufferizeMaterializationLegality(target);
- populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter,
- patterns, target);
+ populateShapeStructuralTypeConversionsAndLegality(typeConverter, patterns,
+ target);
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
void runOnFunction() override {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns;
- populateRemoveShapeConstraintsPatterns(patterns, &ctx);
+ OwningRewritePatternList patterns(&ctx);
+ populateRemoveShapeConstraintsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
} // namespace
void mlir::populateRemoveShapeConstraintsPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(ctx);
+ OwningRewritePatternList &patterns) {
+ patterns.insert<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
+ patterns.getContext());
}
std::unique_ptr<FunctionPass> mlir::createRemoveShapeConstraintsPass() {
void ShapeToShapeLowering::runOnFunction() {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns;
- populateShapeRewritePatterns(&ctx, patterns);
+ OwningRewritePatternList patterns(&ctx);
+ populateShapeRewritePatterns(patterns);
ConversionTarget target(getContext());
target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
signalPassFailure();
}
-void mlir::populateShapeRewritePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
- patterns.insert<NumElementsOpConverter>(context);
+void mlir::populateShapeRewritePatterns(OwningRewritePatternList &patterns) {
+ patterns.insert<NumElementsOpConverter>(patterns.getContext());
}
std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
} // namespace
void mlir::populateShapeStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target) {
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target) {
patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
- typeConverter, context);
+ typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
return typeConverter.isLegal(op.getResultTypes());
});
};
} // namespace
-void mlir::populateStdBufferizePatterns(MLIRContext *context,
- BufferizeTypeConverter &typeConverter,
+void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeDimOp, BufferizeSelectOp>(typeConverter, context);
+ patterns.insert<BufferizeDimOp, BufferizeSelectOp>(typeConverter,
+ patterns.getContext());
}
namespace {
void runOnFunction() override {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<scf::SCFDialect>();
- populateStdBufferizePatterns(context, typeConverter, patterns);
+ populateStdBufferizePatterns(typeConverter, patterns);
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).
void runOnFunction() override {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns;
- populateStdExpandOpsPatterns(&ctx, patterns);
+ OwningRewritePatternList patterns(&ctx);
+ populateStdExpandOpsPatterns(patterns);
ConversionTarget target(getContext());
} // namespace
-void mlir::populateStdExpandOpsPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
+void mlir::populateStdExpandOpsPatterns(OwningRewritePatternList &patterns) {
patterns.insert<AtomicRMWOpConverter, MemRefReshapeOpConverter,
SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
- context);
+ patterns.getContext());
}
std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
- populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
+ populateFuncOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter.isSignatureLegal(op.getType()) &&
typeConverter.isLegal(&op.getBody());
});
- populateCallOpTypeConversionPattern(patterns, context, typeConverter);
+ populateCallOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<CallOp>(
[&](CallOp op) { return typeConverter.isLegal(op); });
- populateBranchOpInterfaceTypeConversionPattern(patterns, context,
- typeConverter);
- populateReturnOpTypeConversionPattern(patterns, context, typeConverter);
+ populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
target.addLegalOp<ModuleOp, ModuleTerminatorOp, memref::TensorLoadOp,
memref::BufferCastOp>();
} // end anonymous namespace
void mlir::populateCallOpTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &converter) {
- patterns.insert<CallOpSignatureConversion>(converter, ctx);
+ OwningRewritePatternList &patterns, TypeConverter &converter) {
+ patterns.insert<CallOpSignatureConversion>(converter, patterns.getContext());
}
namespace {
} // end anonymous namespace
void mlir::populateBranchOpInterfaceTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &typeConverter) {
- patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter, ctx);
+ OwningRewritePatternList &patterns, TypeConverter &typeConverter) {
+ patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter,
+ patterns.getContext());
}
bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
}
void mlir::populateReturnOpTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &typeConverter) {
- patterns.insert<ReturnOpTypeConversion>(typeConverter, ctx);
+ OwningRewritePatternList &patterns, TypeConverter &typeConverter) {
+ patterns.insert<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
}
bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<memref::MemRefDialect>();
} // namespace
void mlir::populateTensorBufferizePatterns(
- MLIRContext *context, BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
patterns.insert<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
- BufferizeGenerateOp>(typeConverter, context);
+ BufferizeGenerateOp>(typeConverter, patterns.getContext());
}
namespace {
void runOnFunction() override {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
populateBufferizeMaterializationLegality(target);
- populateTensorBufferizePatterns(context, typeConverter, patterns);
+ populateTensorBufferizePatterns(typeConverter, patterns);
target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
tensor::FromElementsOp, tensor::GenerateOp>();
target.addLegalDialect<memref::MemRefDialect>();
public:
void runOnFunction() override {
auto func = getFunction();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(func.getContext());
MLIRContext *ctx = func.getContext();
// Add the generated patterns to the list.
patterns.insert<ConvertTosaOp<tosa::AddOp>>(ctx);
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns) {
patterns.insert<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder,
GatherFolder, ScatterFolder, ExpandLoadFolder,
CompressStoreFolder, StridedSliceConstantMaskFolder,
- TransposeFolder>(context);
+ TransposeFolder>(patterns.getContext());
}
#define GET_OP_CLASSES
// If broadcasting is required and the number of loaded elements is 1 then
// we can create `memref.load` instead of `vector.load`.
loadOp = rewriter.create<memref::LoadOp>(read.getLoc(), read.source(),
- read.indices());
+ read.indices());
} else {
// Otherwise create `vector.load`.
loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns) {
patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
TransferReadExtractPattern, TransferWriteInsertPattern>(
- context);
+ patterns.getContext());
}
void mlir::vector::populateSplitVectorTransferPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
+ OwningRewritePatternList &patterns,
std::function<bool(Operation *)> ignoreFilter) {
- patterns.insert<SplitTransferReadOp, SplitTransferWriteOp>(context,
- ignoreFilter);
+ patterns.insert<SplitTransferReadOp, SplitTransferWriteOp>(
+ patterns.getContext(), ignoreFilter);
}
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns) {
patterns.insert<CastAwayExtractStridedSliceLeadingOneDim,
CastAwayInsertStridedSliceLeadingOneDim,
CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, ShapeCastOpFolder>(
- context);
+ patterns.getContext());
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns) {
patterns.insert<BubbleDownVectorBitCastForExtract,
BubbleDownBitCastForStridedSliceExtract,
- BubbleUpBitCastForStridedSliceInsert>(context);
+ BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
}
void mlir::vector::populateVectorSlicesLoweringPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
+ OwningRewritePatternList &patterns) {
+ patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(
+ patterns.getContext());
}
void mlir::vector::populateVectorContractLoweringPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
- VectorTransformsOptions parameters) {
+ OwningRewritePatternList &patterns, VectorTransformsOptions parameters) {
// clang-format off
patterns.insert<BroadcastOpLowering,
CreateMaskOpLowering,
OuterProductOpLowering,
ShapeCastOp2DDownCastRewritePattern,
ShapeCastOp2DUpCastRewritePattern,
- ShapeCastOpRewritePattern>(context);
+ ShapeCastOpRewritePattern>(patterns.getContext());
patterns.insert<TransposeOpLowering,
ContractionOpLowering,
ContractionOpToMatmulOpLowering,
- ContractionOpToOuterProductOpLowering>(parameters, context);
+ ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
// clang-format on
}
void mlir::vector::populateVectorTransferLoweringPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns) {
patterns.insert<TransferReadToVectorLoadLowering,
- TransferWriteToVectorStoreLowering>(context);
+ TransferWriteToVectorStoreLowering>(patterns.getContext());
}
} // namespace
void mlir::populateEliminateBufferizeMaterializationsPatterns(
- MLIRContext *context, BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeTensorLoadOp, BufferizeCastOp>(typeConverter,
- context);
+ BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+ patterns.insert<BufferizeTensorLoadOp, BufferizeCastOp>(
+ typeConverter, patterns.getContext());
}
namespace {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
- populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
- patterns);
+ populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
// If all result types are legal, and all block arguments are legal (ensured
// by func conversion above), then all types in the program are legal.
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.
LogicalResult initialize(MLIRContext *context) override {
- OwningRewritePatternList owningPatterns;
+ OwningRewritePatternList owningPatterns(context);
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(owningPatterns, context);
patterns = std::move(owningPatterns);
/// A utility function to log a successful result for the given reason.
template <typename... Args>
-static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
+static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
+ Args &&... args) {
LLVM_DEBUG({
os.unindent();
os.startLine() << "} -> SUCCESS";
/// A utility function to log a failure result for the given reason.
template <typename... Args>
-static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
+static void logFailure(llvm::ScopedPrinter &os, StringRef fmt,
+ Args &&... args) {
LLVM_DEBUG({
os.unindent();
os.startLine() << "} -> FAILURE : "
void mlir::populateFunctionLikeTypeConversionPattern(
StringRef functionLikeOpName, OwningRewritePatternList &patterns,
- MLIRContext *ctx, TypeConverter &converter) {
- patterns.insert<FunctionLikeSignatureConversion>(functionLikeOpName, ctx,
- converter);
+ TypeConverter &converter) {
+ patterns.insert<FunctionLikeSignatureConversion>(
+ functionLikeOpName, patterns.getContext(), converter);
}
void mlir::populateFuncOpTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &converter) {
- populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, ctx, converter);
+ OwningRewritePatternList &patterns, TypeConverter &converter) {
+ populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, converter);
}
//===----------------------------------------------------------------------===//
if (res) {
// Simplify/canonicalize the affine.for.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
bool erased;
(void)applyOpPatternsAndFold(res, std::move(patterns), &erased);
// Promoting single iteration loops could lead to simplification of
// generated load's/store's, and the latter could anyway also be
// canonicalized.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
for (auto op : copyOps) {
patterns.clear();
if (isa<AffineLoadOp>(op)) {
auto target = spirv::SPIRVConversionTarget::get(targetEnv);
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
ConvertToGroupNonUniformBallot, ConvertToModule,
ConvertToSubgroupBallot>(context);
} // namespace
void TestGLSLCanonicalizationPass::runOnOperation() {
- OwningRewritePatternList patterns;
- spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
void runOnFunction() override {
- mlir::OwningRewritePatternList patterns;
+ mlir::OwningRewritePatternList patterns(&getContext());
populateWithGenerated(&getContext(), patterns);
// Verify named pattern is generated with expected name.
void runOnOperation() override {
TestTypeConverter converter;
- mlir::OwningRewritePatternList patterns;
+ mlir::OwningRewritePatternList patterns(&getContext());
populateWithGenerated(&getContext(), patterns);
patterns.insert<
TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite>(&getContext());
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
- mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
- converter);
- mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
- converter);
+ mlir::populateFuncOpTypeConversionPattern(patterns, converter);
+ mlir::populateCallOpTypeConversionPattern(patterns, converter);
// Define the conversion target used for the test.
ConversionTarget target(getContext());
struct TestRemappedValue
: public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
void runOnFunction() override {
- mlir::OwningRewritePatternList patterns;
+ mlir::OwningRewritePatternList patterns(&getContext());
patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
mlir::ConversionTarget target(getContext());
struct TestUnknownRootOpDriver
: public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
void runOnFunction() override {
- mlir::OwningRewritePatternList patterns;
+ mlir::OwningRewritePatternList patterns(&getContext());
patterns.insert<RemoveTestDialectOps>();
mlir::ConversionTarget target(getContext());
});
// Initialize the set of rewrite patterns.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
patterns.insert<TestTypeConsumerForward, TestTypeConversionProducer,
TestSignatureConversionUndo>(converter, &getContext());
patterns.insert<TestTypeConversionAnotherProducer>(&getContext());
- mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
- converter);
+ mlir::populateFuncOpTypeConversionPattern(patterns, converter);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
: public PassWrapper<TestMergeBlocksPatternDriver,
OperationPass<ModuleOp>> {
void runOnOperation() override {
- mlir::OwningRewritePatternList patterns;
MLIRContext *context = &getContext();
+ mlir::OwningRewritePatternList patterns(context);
patterns
.insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
context);
: public PassWrapper<TestSelectiveReplacementPatternDriver,
OperationPass<>> {
void runOnOperation() override {
- mlir::OwningRewritePatternList patterns;
MLIRContext *context = &getContext();
+ mlir::OwningRewritePatternList patterns(context);
patterns.insert<TestSelectiveOpReplacementPattern>(context);
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
void runOnFunction() override {
(void)applyPatternsAndFoldGreedily(getFunction(),
- OwningRewritePatternList());
+ OwningRewritePatternList(&getContext()));
}
};
} // end anonymous namespace
};
void TosaTestQuantUtilAPI::runOnFunction() {
- OwningRewritePatternList patterns;
auto *ctx = &getContext();
+ OwningRewritePatternList patterns(ctx);
auto func = getFunction();
patterns.insert<ConvertTosaNegateOp>(ctx);
VectorTransformsOptions vectorTransformsOptions{
VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
- OwningRewritePatternList vectorTransferPatterns;
+ OwningRewritePatternList vectorTransferPatterns(context);
// Pattern is not applied because rank-reducing vector transfer is not yet
// supported as can be seen in splitFullAndPartialTransferPrecondition,
// VectorTransforms.cpp
llvm_unreachable("Unexpected failure in linalg to loops pass.");
// Programmatic controlled lowering of vector.contract only.
- OwningRewritePatternList vectorContractLoweringPatterns;
+ OwningRewritePatternList vectorContractLoweringPatterns(context);
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
- context, vectorTransformsOptions);
+ vectorTransformsOptions);
(void)applyPatternsAndFoldGreedily(module,
std::move(vectorContractLoweringPatterns));
// Programmatic controlled lowering of vector.transfer only.
- OwningRewritePatternList vectorToLoopsPatterns;
- populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
+ OwningRewritePatternList vectorToLoopsPatterns(context);
+ populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
VectorTransferToSCFOptions());
(void)applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
ModuleOp m = getOperation();
// Populate type conversions.
- LLVMTypeConverter type_converter(m.getContext());
- type_converter.addConversion([&](test::TestType type) {
+ LLVMTypeConverter typeConverter(m.getContext());
+ typeConverter.addConversion([&](test::TestType type) {
return LLVM::LLVMPointerType::get(IntegerType::get(m.getContext(), 8));
});
// Populate patterns.
- OwningRewritePatternList patterns;
- populateStdToLLVMConversionPatterns(type_converter, patterns);
- patterns.insert<TestTypeProducerOpConverter>(type_converter);
+ OwningRewritePatternList patterns(m.getContext());
+ populateStdToLLVMConversionPatterns(typeConverter, patterns);
+ patterns.insert<TestTypeProducerOpConverter>(typeConverter);
// Set target.
ConversionTarget target(getContext());
TypeConverter typeConverter;
ConversionTarget target(*context);
ValueDecomposer decomposer;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
target.addLegalDialect<test::TestDialect>();
} // end anonymous namespace
void TestExpandTanhPass::runOnFunction() {
- OwningRewritePatternList patterns;
- populateExpandTanhPattern(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateExpandTanhPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
registry.insert<StandardOpsDialect, memref::MemRefDialect>();
}
void runOnOperation() override {
- OwningRewritePatternList patterns;
- populateGpuRewritePatterns(&getContext(), patterns);
+ OwningRewritePatternList patterns(&getContext());
+ populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
void runOnFunction() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getFunction();
- OwningRewritePatternList fusionPatterns;
+ OwningRewritePatternList fusionPatterns(context);
Aliases alias;
LinalgDependenceGraph dependenceGraph =
LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
static void applyPatterns(FuncOp funcOp) {
MLIRContext *ctx = funcOp.getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
//===--------------------------------------------------------------------===//
// Linalg tiling patterns.
FuncOp funcOp, StringRef startMarker,
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
MLIRContext *ctx = funcOp.getContext();
- patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
- ctx,
- LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
- LinalgTransformationFilter(Identifier::get(startMarker, ctx),
- Identifier::get("L1", ctx))));
+ patternsVector.emplace_back(
+ ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
+ ctx,
+ LinalgTilingOptions()
+ .setTileSizes({8, 12, 16})
+ .setInterchange({1, 0, 2}),
+ LinalgTransformationFilter(Identifier::get(startMarker, ctx),
+ Identifier::get("L1", ctx))));
patternsVector.emplace_back(
+ ctx,
std::make_unique<LinalgPromotionPattern<MatmulOp>>(
ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
LinalgTransformationFilter(Identifier::get("L1", ctx),
Identifier::get("VEC", ctx))));
- patternsVector.emplace_back(std::make_unique<LinalgVectorizationPattern>(
- MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
- LinalgTransformationFilter(Identifier::get("VEC", ctx))));
+ patternsVector.emplace_back(
+ ctx, std::make_unique<LinalgVectorizationPattern>(
+ MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
+ LinalgTransformationFilter(Identifier::get("VEC", ctx))));
patternsVector.back().insert<LinalgVectorizationPattern>(
LinalgTransformationFilter().addFilter(
[](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
- stage1Patterns.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
- ctx,
- LinalgTilingOptions()
- .setTileSizes({768, 264, 768})
- .setInterchange({1, 2, 0}),
- LinalgTransformationFilter(Identifier::get("START", ctx),
- Identifier::get("L2", ctx))));
+ stage1Patterns.emplace_back(
+ ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
+ ctx,
+ LinalgTilingOptions()
+ .setTileSizes({768, 264, 768})
+ .setInterchange({1, 2, 0}),
+ LinalgTransformationFilter(Identifier::get("START", ctx),
+ Identifier::get("L2", ctx))));
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
stage1Patterns);
}
}
static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
- OwningRewritePatternList forwardPattern;
+ OwningRewritePatternList forwardPattern(funcOp.getContext());
forwardPattern.insert<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
forwardPattern.insert<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
}
static void applyLinalgToVectorPatterns(FuncOp funcOp) {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(funcOp.getContext());
patterns.insert<LinalgVectorizationPattern>(
LinalgTransformationFilter()
.addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
}
static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
- OwningRewritePatternList foldPattern;
+ OwningRewritePatternList foldPattern(funcOp.getContext());
foldPattern.insert<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
FrozenRewritePatternList frozenPatterns(std::move(foldPattern));
static void applyTileAndPadPattern(FuncOp funcOp) {
MLIRContext *context = funcOp.getContext();
- OwningRewritePatternList tilingPattern;
+ OwningRewritePatternList tilingPattern(context);
auto linalgTilingOptions =
linalg::LinalgTilingOptions()
.setTileSizes({2, 3, 4})
std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
if (testPromotionOptions) {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
fillPromotionCallBackPatterns(&getContext(), patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
return;
}
if (testTileAndDistributionOptions) {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
fillTileAndDistributePatterns(&getContext(), patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
return;
} // end anonymous namespace
void TestMathPolynomialApproximationPass::runOnFunction() {
- OwningRewritePatternList patterns;
- populateMathPolynomialApproximationPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateMathPolynomialApproximationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
/// Runs the test on a function.
void runOnOperation() override {
auto *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
// Translate strategy flags to strategy options.
linalg::SparsificationOptions options(parallelOption(), vectorOption(),
vectorLength, typeOption(ptrType),
typeOption(indType), fastOutput);
// Apply rewriting.
- linalg::populateSparsificationPatterns(ctx, patterns, options);
- vector::populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
+ linalg::populateSparsificationPatterns(patterns, options);
+ vector::populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// Lower sparse primitives to calls into runtime support library.
if (lower) {
- OwningRewritePatternList conversionPatterns;
+ OwningRewritePatternList conversionPatterns(ctx);
ConversionTarget target(*ctx);
target.addIllegalOp<linalg::SparseTensorFromPointerOp,
linalg::SparseTensorToPointersMemRefOp,
linalg::SparseTensorToIndicesMemRefOp,
linalg::SparseTensorToValuesMemRefOp>();
target.addLegalOp<CallOp>();
- linalg::populateSparsificationConversionPatterns(ctx, conversionPatterns);
+ linalg::populateSparsificationConversionPatterns(conversionPatterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(conversionPatterns))))
signalPassFailure();
llvm::cl::init(false)};
void runOnFunction() override {
- OwningRewritePatternList patterns;
auto *ctx = &getContext();
+ OwningRewritePatternList patterns(ctx);
if (unroll) {
patterns.insert<UnrollVectorPattern>(
ctx,
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
filter));
}
- populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
- populateVectorToVectorTransformationPatterns(patterns, ctx);
- populateBubbleVectorBitCastOpPatterns(patterns, ctx);
- populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
- populateSplitVectorTransferPatterns(patterns, ctx);
+ populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorToVectorTransformationPatterns(patterns);
+ populateBubbleVectorBitCastOpPatterns(patterns);
+ populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ populateSplitVectorTransferPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
struct TestVectorSlicesConversion
: public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
void runOnFunction() override {
- OwningRewritePatternList patterns;
- populateVectorSlicesLoweringPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateVectorSlicesLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
llvm::cl::init(false)};
void runOnFunction() override {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
// Test on one pattern in isolation.
if (lowerToOuterProduct) {
if (lowerToFlatTranspose)
transposeLowering = VectorTransposeLowering::Flat;
VectorTransformsOptions options{contractLowering, transposeLowering};
- populateVectorContractLoweringPatterns(patterns, &getContext(), options);
+ populateVectorContractLoweringPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
patterns.insert<UnrollVectorPattern>(
ctx, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})
return success(isa<ContractionOp>(op));
}));
}
- populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
- populateVectorToVectorTransformationPatterns(patterns, ctx);
+ populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
FuncOp func = getFunction();
func.walk([&](AddFOp op) {
OpBuilder builder(op);
}
});
patterns.insert<PointwiseExtractPattern>(ctx);
- populateVectorToVectorTransformationPatterns(patterns, ctx);
+ populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
llvm::cl::init(32)};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
FuncOp func = getFunction();
func.walk([&](AddFOp op) {
// Check that the operation type can be broken down into a loop.
return mlir::WalkResult::interrupt();
});
patterns.insert<PointwiseExtractPattern>(ctx);
- populateVectorToVectorTransformationPatterns(patterns, ctx);
+ populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
patterns.insert<UnrollVectorPattern>(
ctx,
UnrollVectorOptions()
return success(
isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
}));
- populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
- populateVectorToVectorTransformationPatterns(patterns, ctx);
+ populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
llvm::cl::init(false)};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
VectorTransformsOptions options;
if (useLinalgOps)
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
registry.insert<memref::MemRefDialect>();
}
void runOnFunction() override {
- OwningRewritePatternList patterns;
- populateVectorTransferLoweringPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateVectorTransferLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
bool *called;
};
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&context);
bool called1 = false;
bool called2 = false;