From 68250edbfacb3b0d355ee633752272cb140848ac Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sun, 19 May 2019 20:54:13 -0700 Subject: [PATCH] NFC: Tidy up DialectConversion.cpp and rename DialectOpConversion to DialectConversionPattern. -- PiperOrigin-RevId: 248980810 --- .../Linalg1/include/linalg1/ConvertToLLVMDialect.h | 2 +- .../Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 19 +++++++------ .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 4 +-- mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp | 6 ++-- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 26 +++++++++-------- mlir/g3doc/Tutorials/Linalg/LLVMConversion.md | 14 ++++----- mlir/g3doc/Tutorials/Toy/Ch-5.md | 8 +++--- mlir/include/mlir/LLVMIR/LLVMLowering.h | 2 +- mlir/include/mlir/Transforms/DialectConversion.h | 33 ++++++++-------------- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 2 +- mlir/lib/Transforms/DialectConversion.cpp | 32 +++++++++++++++------ 11 files changed, 81 insertions(+), 67 deletions(-) diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h index 2abe6ba..cf5d7ef 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h @@ -25,7 +25,7 @@ namespace mlir { class DialectConversion; -class DialectOpConversion; +class DialectConversionPattern; class MLIRContext; class Module; class RewritePattern; diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 0b8d6bd..a23828d 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -142,10 +142,11 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder, } // RangeOp creates a new range descriptor. -class RangeOpConversion : public DialectOpConversion { +class RangeOpConversion : public DialectConversionPattern { public: explicit RangeOpConversion(MLIRContext *context) - : DialectOpConversion(linalg::RangeOp::getOperationName(), 1, context) {} + : DialectConversionPattern(linalg::RangeOp::getOperationName(), 1, + context) {} void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { @@ -168,10 +169,11 @@ public: } }; -class ViewOpConversion : public DialectOpConversion { +class ViewOpConversion : public DialectConversionPattern { public: explicit ViewOpConversion(MLIRContext *context) - : DialectOpConversion(linalg::ViewOp::getOperationName(), 1, context) {} + : DialectConversionPattern(linalg::ViewOp::getOperationName(), 1, + context) {} void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { @@ -294,10 +296,11 @@ public: } }; -class SliceOpConversion : public DialectOpConversion { +class SliceOpConversion : public DialectConversionPattern { public: explicit SliceOpConversion(MLIRContext *context) - : DialectOpConversion(linalg::SliceOp::getOperationName(), 1, context) {} + : DialectConversionPattern(linalg::SliceOp::getOperationName(), 1, + context) {} void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { @@ -387,10 +390,10 @@ public: // When converting the "some_consumer" operation, don't emit anything and // effectively drop it. -class DropConsumer : public DialectOpConversion { +class DropConsumer : public DialectConversionPattern { public: explicit DropConsumer(MLIRContext *context) - : DialectOpConversion("some_consumer", 1, context) {} + : DialectConversionPattern("some_consumer", 1, context) {} void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index b475ed2..98bd867 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -52,10 +52,10 @@ namespace { // Common functionality for Linalg LoadOp and StoreOp conversion to the // LLVM IR Dialect. template -class LoadStoreOpConversion : public DialectOpConversion { +class LoadStoreOpConversion : public DialectConversionPattern { public: explicit LoadStoreOpConversion(MLIRContext *context) - : DialectOpConversion(Op::getOperationName(), 1, context) {} + : DialectConversionPattern(Op::getOperationName(), 1, context) {} using Base = LoadStoreOpConversion; // Compute the pointer to an element of the buffer underlying the view given diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index 093a595..b6e0703 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -77,14 +77,14 @@ Value *memRefTypeCast(FuncBuilder &builder, Value *val) { /// Lower toy.mul to Linalg `matmul`. /// -/// This class inherit from `DialectOpConversion` and override `rewrite`, +/// This class inherit from `DialectConversionPattern` and override `rewrite`, /// similarly to the PatternRewriter introduced in the previous chapter. /// It will be called by the DialectConversion framework (see `LateLowering` /// class below). -class MulOpConversion : public DialectOpConversion { +class MulOpConversion : public DialectConversionPattern { public: explicit MulOpConversion(MLIRContext *context) - : DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {} + : DialectConversionPattern(toy::MulOp::getOperationName(), 1, context) {} void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 00b3a74..d504865 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -77,14 +77,14 @@ Value *memRefTypeCast(FuncBuilder &builder, Value *val) { /// Lower a toy.add to an affine loop nest. /// -/// This class inherit from `DialectOpConversion` and override `rewrite`, +/// This class inherit from `DialectConversionPattern` and override `rewrite`, /// similarly to the PatternRewriter introduced in the previous chapter. /// It will be called by the DialectConversion framework (see `LateLowering` /// class below). -class AddOpConversion : public DialectOpConversion { +class AddOpConversion : public DialectConversionPattern { public: explicit AddOpConversion(MLIRContext *context) - : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {} + : DialectConversionPattern(toy::AddOp::getOperationName(), 1, context) {} /// Lower the `op` by generating IR using the `rewriter` builder. The builder /// is setup with a new function, the `operands` array has been populated with @@ -125,10 +125,11 @@ public: /// Lowers `toy.print` to a loop nest calling `printf` on every individual /// elements of the array. -class PrintOpConversion : public DialectOpConversion { +class PrintOpConversion : public DialectConversionPattern { public: explicit PrintOpConversion(MLIRContext *context) - : DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {} + : DialectConversionPattern(toy::PrintOp::getOperationName(), 1, context) { + } void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { @@ -227,10 +228,11 @@ private: }; /// Lowers constant to a sequence of store in a buffer. -class ConstantOpConversion : public DialectOpConversion { +class ConstantOpConversion : public DialectConversionPattern { public: explicit ConstantOpConversion(MLIRContext *context) - : DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {} + : DialectConversionPattern(toy::ConstantOp::getOperationName(), 1, + context) {} void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { @@ -270,10 +272,11 @@ public: }; /// Lower transpose operation to an affine loop nest. -class TransposeOpConversion : public DialectOpConversion { +class TransposeOpConversion : public DialectConversionPattern { public: explicit TransposeOpConversion(MLIRContext *context) - : DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {} + : DialectConversionPattern(toy::TransposeOp::getOperationName(), 1, + context) {} void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { @@ -302,10 +305,11 @@ public: }; // Lower toy.return to standard return operation. -class ReturnOpConversion : public DialectOpConversion { +class ReturnOpConversion : public DialectConversionPattern { public: explicit ReturnOpConversion(MLIRContext *context) - : DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {} + : DialectConversionPattern(toy::ReturnOp::getOperationName(), 1, + context) {} void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { diff --git a/mlir/g3doc/Tutorials/Linalg/LLVMConversion.md b/mlir/g3doc/Tutorials/Linalg/LLVMConversion.md index 72ca5a7..9f5c552 100644 --- a/mlir/g3doc/Tutorials/Linalg/LLVMConversion.md +++ b/mlir/g3doc/Tutorials/Linalg/LLVMConversion.md @@ -212,11 +212,11 @@ offset, the min/max values or the strides. Their static (constant) dimensions are available directly in the type signature. An operation conversion is defined as special pattern by inheriting from -`mlir::DialectOpConversion` and by reimplementing the matching and the rewriting -functions: +`mlir::DialectConversionPattern` and by reimplementing the matching and the +rewriting functions: ```c++ -class ViewOpConversion : public DialectOpConversion { +class ViewOpConversion : public DialectConversionPattern { public: // A conversion constructor, may take arbtirary operands but must be able // to obtain an MLIRContext from them to call the parent constructor. @@ -237,15 +237,15 @@ public: } ``` -The `DialectOpConversion` constructor takes, in addition to the context, the -name of the main operation to be matched and the "benefit" of a match. These +The `DialectConversionPattern` constructor takes, in addition to the context, +the name of the main operation to be matched and the "benefit" of a match. These operands are intended to be useful for defining an optimization problem across multiple possible conversions but are currently ignored by the conversion framework. ```c++ ViewOpConversion::ViewOpConversion(MLIRContext *context) - : DialectOpConversion(linalg::ViewOp::getOperationName(), 1, context) {} + : DialectConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {} ``` The matching function can be used, for example, to apply different conversions @@ -621,7 +621,7 @@ class Lowering : public DialectConversion { protected: // Produce a set of operation conversion patterns. This is called once per // conversion. - llvm::DenseSet + llvm::DenseSet initConverter(MLIRContext *context) override { allocator.Reset(); // Call a helper function provided by MLIR to build a set of operation diff --git a/mlir/g3doc/Tutorials/Toy/Ch-5.md b/mlir/g3doc/Tutorials/Toy/Ch-5.md index 8124c79..37ac637 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-5.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-5.md @@ -49,7 +49,7 @@ public: SmallVectorImpl &convertedArgAttrs) { /*...*/ } // This gets called once to set up operation converters. - llvm::DenseSet + llvm::DenseSet initConverters(MLIRContext *context) override { RewriteListBuilder::build(allocator, context); @@ -65,14 +65,14 @@ Individual operation converters are following this pattern: ```c++ /// Lower a toy.add to an affine loop nest. /// -/// This class inherit from `DialectOpConversion` and override `rewrite`, +/// This class inherit from `DialectConversionPattern` and override `rewrite`, /// similarly to the PatternRewriter introduced in the previous chapter. /// It will be called by the DialectConversion framework (see `LateLowering` /// class below). -class AddOpConversion : public DialectOpConversion { +class AddOpConversion : public DialectConversionPattern { public: explicit AddOpConversion(MLIRContext *context) - : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {} + : DialectConversionPattern(toy::AddOp::getOperationName(), 1, context) {} /// Lower the `op` by generating IR using the `rewriter` builder. The builder /// is setup with a new function, the `operands` array has been populated with diff --git a/mlir/include/mlir/LLVMIR/LLVMLowering.h b/mlir/include/mlir/LLVMIR/LLVMLowering.h index 6cec38c..c2bf040 100644 --- a/mlir/include/mlir/LLVMIR/LLVMLowering.h +++ b/mlir/include/mlir/LLVMIR/LLVMLowering.h @@ -131,7 +131,7 @@ private: /// Base class for operation conversions targeting the LLVM IR dialect. Provides /// conversion patterns with an access to the containing LLVMLowering for the /// purpose of type conversions. -class LLVMOpLowering : public DialectOpConversion { +class LLVMOpLowering : public DialectConversionPattern { public: LLVMOpLowering(StringRef rootOpName, MLIRContext *context, LLVMLowering &lowering); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 48376e5..8bbce1d 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -36,23 +36,17 @@ class Operation; class Type; class Value; -// Private implementation class. -namespace impl { -class FunctionConverter; -} - -/// Base class for the dialect op conversion patterns. Specific conversions -/// must derive this class and implement least one `rewrite` method. Optionally -/// they can also override `PatternMatch match(Operation *)` to match more -/// specific operations than the `rootName` provided in the constructor. +/// Base class for the dialect conversion patterns that require type changes. +/// Specific conversions must derive this class and implement least one +/// `rewrite` method. /// NOTE: These conversion patterns can only be used with the DialectConversion /// class. -class DialectOpConversion : public RewritePattern { +class DialectConversionPattern : public RewritePattern { public: - /// Construct an DialectOpConversion. `rootName` must correspond to the + /// Construct an DialectConversionPattern. `rootName` must correspond to the /// canonical name of the first operation matched by the pattern. - DialectOpConversion(StringRef rootName, PatternBenefit benefit, - MLIRContext *ctx) + DialectConversionPattern(StringRef rootName, PatternBenefit benefit, + MLIRContext *ctx) : RewritePattern(rootName, benefit, ctx) {} /// Hook for derived classes to implement matching. Dialect conversion @@ -65,10 +59,10 @@ public: /// Hook for derived classes to implement rewriting. `op` is the (first) /// operation matched by the pattern, `operands` is a list of rewritten values /// that are passed to this operation, `rewriter` can be used to emit the new - /// operations. This function must be reimplemented if the DialectOpConversion - /// ever needs to replace an operation that does not have successors. This - /// function should not fail. If some specific cases of the operation are not - /// supported, these cases should not be matched. + /// operations. This function must be reimplemented if the + /// DialectConversionPattern ever needs to replace an operation that does not + /// have successors. This function should not fail. If some specific cases of + /// the operation are not supported, these cases should not be matched. virtual void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite"); @@ -80,7 +74,7 @@ public: /// of (potentially rewritten) successor blocks, `operands` is a list of lists /// of rewritten values passed to each of the successors, co-indexed with /// `destinations`, `rewriter` can be used to emit the new operations. It must - /// be reimplemented if the DialectOpConversion ever needs to replace a + /// be reimplemented if the DialectConversionPattern ever needs to replace a /// terminator operation that has successors. This function should not fail /// the pass. If some specific cases of the operation are not supported, /// these cases should not be matched. @@ -126,8 +120,6 @@ private: /// /// If the conversion fails, the module is not modified. class DialectConversion { - friend class impl::FunctionConverter; - public: virtual ~DialectConversion() = default; @@ -135,7 +127,6 @@ public: LLVM_NODISCARD LogicalResult convert(Module *m); -protected: /// Derived classes must implement this hook to produce a set of conversion /// patterns to apply. They may use `mlirContext` to obtain registered /// dialects or operations. This will be called in the beginning of the diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index f1c43ef..36267e9 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -195,7 +195,7 @@ static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) { LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, LLVMLowering &lowering_) - : DialectOpConversion(rootOpName, /*benefit=*/1, context), + : DialectConversionPattern(rootOpName, /*benefit=*/1, context), lowering(lowering_) {} namespace { diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 3b16323..9c428d4 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -25,6 +25,9 @@ using namespace mlir; using namespace mlir::impl; +//===----------------------------------------------------------------------===// +// ProducerGenerator +//===----------------------------------------------------------------------===// namespace { /// This class provides a simple interface for generating fake producers during /// the conversion process. These fake producers are used when replacing the @@ -87,8 +90,12 @@ struct ProducerGenerator { UnknownLoc loc; }; -/// This class implements a pattern rewriter for DialectOpConversion patterns. -/// It automatically performs remapping of replaced operation values. +//===----------------------------------------------------------------------===// +// DialectConversionRewriter +//===----------------------------------------------------------------------===// + +/// This class implements a pattern rewriter for DialectConversionPattern +/// patterns. It automatically performs remapping of replaced operation values. struct DialectConversionRewriter final : public PatternRewriter { DialectConversionRewriter(Function *fn) : PatternRewriter(fn), tempGenerator(fn->getContext()) {} @@ -130,13 +137,17 @@ struct DialectConversionRewriter final : public PatternRewriter { }; } // end anonymous namespace +//===----------------------------------------------------------------------===// +// DialectConversionPattern +//===----------------------------------------------------------------------===// + /// Rewrite the IR rooted at the specified operation with the result of /// this pattern, generating any new operations with the specified /// builder. If an unexpected error is encountered (an internal /// compiler error), it is emitted through the normal MLIR diagnostic /// hooks and the IR is left in a valid state. -void DialectOpConversion::rewrite(Operation *op, - PatternRewriter &rewriter) const { +void DialectConversionPattern::rewrite(Operation *op, + PatternRewriter &rewriter) const { SmallVector operands; auto &dialectRewriter = static_cast(rewriter); dialectRewriter.lookupValues(op->getOperands(), operands); @@ -168,8 +179,10 @@ void DialectOpConversion::rewrite(Operation *op, destinations, operandsPerDestination, rewriter); } -namespace mlir { -namespace impl { +//===----------------------------------------------------------------------===// +// FunctionConverter +//===----------------------------------------------------------------------===// +namespace { // Implementation detail class of the DialectConversion utility. Performs // function-by-function conversions by creating new functions, filling them in // with converted blocks, updating the function attributes, and replacing the @@ -211,8 +224,7 @@ public: /// The matcher to use when converting operations. RewritePatternMatcher &matcher; }; -} // end namespace impl -} // end namespace mlir +} // end anonymous namespace LogicalResult FunctionConverter::convertArgument(DialectConversionRewriter &rewriter, @@ -326,6 +338,10 @@ Function *FunctionConverter::convertFunction(Function *f) { return newFunc; } +//===----------------------------------------------------------------------===// +// DialectConversion +//===----------------------------------------------------------------------===// + // Create a function type with arguments and results converted, and argument // attributes passed through. FunctionType DialectConversion::convertFunctionSignatureType( -- 2.7.4