void linalg::getDescriptorConverters(mlir::OwningRewritePatternList &patterns,
mlir::MLIRContext *context) {
- ConversionListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion,
- ViewOpConversion>::build(patterns, context);
+ RewriteListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion,
+ ViewOpConversion>::build(patterns, context);
}
namespace {
static void getConversions(mlir::OwningRewritePatternList &patterns,
mlir::MLIRContext *context) {
linalg::getDescriptorConverters(patterns, context);
- ConversionListBuilder<LoadOpConversion, StoreOpConversion>::build(patterns,
- context);
+ RewriteListBuilder<LoadOpConversion, StoreOpConversion>::build(patterns,
+ context);
}
void linalg::convertLinalg3ToLLVM(Module &module) {
// Register our patterns for rewrite by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyReshapeConstant>(context));
- results.push_back(llvm::make_unique<SimplifyReshapeReshape>(context));
- results.push_back(llvm::make_unique<SimplifyNullReshape>(context));
+ mlir::RewriteListBuilder<SimplifyReshapeConstant, SimplifyReshapeReshape,
+ SimplifyNullReshape>::build(results, context);
}
} // namespace toy
// Initialize the list of converters.
void initConverters(OwningRewritePatternList &patterns,
MLIRContext *context) override {
- ConversionListBuilder<MulOpConversion>::build(patterns, context);
+ RewriteListBuilder<MulOpConversion>::build(patterns, context);
}
};
/// Initialize the list of converters.
void initConverters(OwningRewritePatternList &patterns,
MLIRContext *context) override {
- ConversionListBuilder<AddOpConversion, PrintOpConversion,
- ConstantOpConversion, TransposeOpConversion,
- ReturnOpConversion>::build(patterns, context);
+ RewriteListBuilder<AddOpConversion, PrintOpConversion, ConstantOpConversion,
+ TransposeOpConversion,
+ ReturnOpConversion>::build(patterns, context);
}
/// Convert a Toy type, this gets called for block and region arguments, and
// Register our patterns for rewrite by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyReshapeConstant>(context));
- results.push_back(llvm::make_unique<SimplifyReshapeReshape>(context));
- results.push_back(llvm::make_unique<SimplifyNullReshape>(context));
+ mlir::RewriteListBuilder<SimplifyReshapeConstant, SimplifyReshapeReshape,
+ SimplifyNullReshape>::build(results, context);
}
namespace {
// conversion instances given a list of classes as template parameters.
// These instances will be allocated within `allocator` and their lifetime
// is managed by the Lowering class.
- return ConversionListBuilder<
+ return RewriteListBuilder<
LoadOpConversion, SliceOpConversion, StoreOpConversion,
ViewOpConversion>::build(allocator, context);
}
// This gets called once to set up operation converters.
llvm::DenseSet<DialectOpConversion *>
initConverters(MLIRContext *context) override {
- return ConversionListBuilder<MulOpConversion,
- PrintOpConversion,
- TransposeOpConversion>::build(allocator, context);
+ RewriteListBuilder<MulOpConversion, PrintOpConversion,
+ TransposeOpConversion>::build(allocator, context);
}
private:
///
bool applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns);
+/// Helper class to create a list of rewrite patterns given a list of their
+/// types and a list of attributes perfect-forwarded to each of the conversion
+/// constructors.
+template <typename Arg, typename... Args> struct RewriteListBuilder {
+ template <typename... ConstructorArgs>
+ static void build(OwningRewritePatternList &patterns,
+ ConstructorArgs &&... constructorArgs) {
+ RewriteListBuilder<Args...>::build(
+ patterns, std::forward<ConstructorArgs>(constructorArgs)...);
+ RewriteListBuilder<Arg>::build(
+ patterns, std::forward<ConstructorArgs>(constructorArgs)...);
+ }
+};
+
+// Template specialization to stop recursion.
+template <typename Arg> struct RewriteListBuilder<Arg> {
+ template <typename... ConstructorArgs>
+ static void build(OwningRewritePatternList &patterns,
+ ConstructorArgs &&... constructorArgs) {
+ patterns.emplace_back(llvm::make_unique<Arg>(
+ std::forward<ConstructorArgs>(constructorArgs)...));
+ }
+};
} // end namespace mlir
#endif // MLIR_PATTERN_MATCH_H
using RewritePattern::rewrite;
};
-// Helper class to create a list of dialect conversion patterns given a list of
-// their types and a list of attributes perfect-forwarded to each of the
-// conversion constructors.
-template <typename Arg, typename... Args> struct ConversionListBuilder {
- template <typename... ConstructorArgs>
- static void build(OwningRewritePatternList &patterns,
- ConstructorArgs &&... constructorArgs) {
- ConversionListBuilder<Args...>::build(
- patterns, std::forward<ConstructorArgs>(constructorArgs)...);
- ConversionListBuilder<Arg>::build(
- patterns, std::forward<ConstructorArgs>(constructorArgs)...);
- }
-};
-
-// Template specialization to stop recursion.
-template <typename Arg> struct ConversionListBuilder<Arg> {
- template <typename... ConstructorArgs>
- static void build(OwningRewritePatternList &patterns,
- ConstructorArgs &&... constructorArgs) {
- patterns.emplace_back(llvm::make_unique<Arg>(
- std::forward<ConstructorArgs>(constructorArgs)...));
- }
-};
-
/// Base class for dialect conversion interface. Specific converters must
/// derive this class and implement the pure virtual functions.
///
module = &llvmDialect->getLLVMModule();
// FIXME: this should be tablegen'ed
- ConversionListBuilder<
+ RewriteListBuilder<
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
class Lowering : public LLVMLowering {
protected:
void initAdditionalConverters(OwningRewritePatternList &patterns) override {
- return ConversionListBuilder<
- BufferAllocOpConversion, BufferDeallocOpConversion,
- BufferSizeOpConversion, DimOpConversion, DotOpConversion,
- LoadOpConversion, RangeOpConversion, RangeIntersectOpConversion,
- SliceOpConversion, StoreOpConversion,
- ViewOpConversion>::build(patterns, llvmDialect->getContext(), *this);
+ RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion,
+ BufferSizeOpConversion, DimOpConversion, DotOpConversion,
+ LoadOpConversion, RangeOpConversion,
+ RangeIntersectOpConversion, SliceOpConversion,
+ StoreOpConversion,
+ ViewOpConversion>::build(patterns,
+ llvmDialect->getContext(),
+ *this);
}
Type convertAdditionalType(Type t) override {
void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyAllocConst>(context));
- results.push_back(llvm::make_unique<SimplifyDeadAlloc>(context));
+ RewriteListBuilder<SimplifyAllocConst, SimplifyDeadAlloc>::build(results,
+ context);
}
//===----------------------------------------------------------------------===//