namespace mlir {
class DialectConversion;
-class DialectOpConversion;
+class DialectConversionPattern;
class MLIRContext;
class Module;
class RewritePattern;
}
// RangeOp creates a new range descriptor.
-class RangeOpConversion : public DialectOpConversion {
+class RangeOpConversion : public DialectConversionPattern {
public:
explicit RangeOpConversion(MLIRContext *context)
- : DialectOpConversion(linalg::RangeOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(linalg::RangeOp::getOperationName(), 1,
+ context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
}
};
-class ViewOpConversion : public DialectOpConversion {
+class ViewOpConversion : public DialectConversionPattern {
public:
explicit ViewOpConversion(MLIRContext *context)
- : DialectOpConversion(linalg::ViewOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(linalg::ViewOp::getOperationName(), 1,
+ context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
}
};
-class SliceOpConversion : public DialectOpConversion {
+class SliceOpConversion : public DialectConversionPattern {
public:
explicit SliceOpConversion(MLIRContext *context)
- : DialectOpConversion(linalg::SliceOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(linalg::SliceOp::getOperationName(), 1,
+ context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// When converting the "some_consumer" operation, don't emit anything and
// effectively drop it.
-class DropConsumer : public DialectOpConversion {
+class DropConsumer : public DialectConversionPattern {
public:
explicit DropConsumer(MLIRContext *context)
- : DialectOpConversion("some_consumer", 1, context) {}
+ : DialectConversionPattern("some_consumer", 1, context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Common functionality for Linalg LoadOp and StoreOp conversion to the
// LLVM IR Dialect.
template <typename Op>
-class LoadStoreOpConversion : public DialectOpConversion {
+class LoadStoreOpConversion : public DialectConversionPattern {
public:
explicit LoadStoreOpConversion(MLIRContext *context)
- : DialectOpConversion(Op::getOperationName(), 1, context) {}
+ : DialectConversionPattern(Op::getOperationName(), 1, context) {}
using Base = LoadStoreOpConversion<Op>;
// Compute the pointer to an element of the buffer underlying the view given
/// Lower toy.mul to Linalg `matmul`.
///
-/// This class inherit from `DialectOpConversion` and override `rewrite`,
+/// This class inherit from `DialectConversionPattern` and override `rewrite`,
/// similarly to the PatternRewriter introduced in the previous chapter.
/// It will be called by the DialectConversion framework (see `LateLowering`
/// class below).
-class MulOpConversion : public DialectOpConversion {
+class MulOpConversion : public DialectConversionPattern {
public:
explicit MulOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(toy::MulOp::getOperationName(), 1, context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
/// Lower a toy.add to an affine loop nest.
///
-/// This class inherit from `DialectOpConversion` and override `rewrite`,
+/// This class inherit from `DialectConversionPattern` and override `rewrite`,
/// similarly to the PatternRewriter introduced in the previous chapter.
/// It will be called by the DialectConversion framework (see `LateLowering`
/// class below).
-class AddOpConversion : public DialectOpConversion {
+class AddOpConversion : public DialectConversionPattern {
public:
explicit AddOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(toy::AddOp::getOperationName(), 1, context) {}
/// Lower the `op` by generating IR using the `rewriter` builder. The builder
/// is setup with a new function, the `operands` array has been populated with
/// Lowers `toy.print` to a loop nest calling `printf` on every individual
/// elements of the array.
-class PrintOpConversion : public DialectOpConversion {
+class PrintOpConversion : public DialectConversionPattern {
public:
explicit PrintOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(toy::PrintOp::getOperationName(), 1, context) {
+ }
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
};
/// Lowers constant to a sequence of store in a buffer.
-class ConstantOpConversion : public DialectOpConversion {
+class ConstantOpConversion : public DialectConversionPattern {
public:
explicit ConstantOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(toy::ConstantOp::getOperationName(), 1,
+ context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
};
/// Lower transpose operation to an affine loop nest.
-class TransposeOpConversion : public DialectOpConversion {
+class TransposeOpConversion : public DialectConversionPattern {
public:
explicit TransposeOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(toy::TransposeOp::getOperationName(), 1,
+ context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
};
// Lower toy.return to standard return operation.
-class ReturnOpConversion : public DialectOpConversion {
+class ReturnOpConversion : public DialectConversionPattern {
public:
explicit ReturnOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(toy::ReturnOp::getOperationName(), 1,
+ context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
are available directly in the type signature.
An operation conversion is defined as special pattern by inheriting from
-`mlir::DialectOpConversion` and by reimplementing the matching and the rewriting
-functions:
+`mlir::DialectConversionPattern` and by reimplementing the matching and the
+rewriting functions:
```c++
-class ViewOpConversion : public DialectOpConversion {
+class ViewOpConversion : public DialectConversionPattern {
public:
// A conversion constructor, may take arbtirary operands but must be able
// to obtain an MLIRContext from them to call the parent constructor.
}
```
-The `DialectOpConversion` constructor takes, in addition to the context, the
-name of the main operation to be matched and the "benefit" of a match. These
+The `DialectConversionPattern` constructor takes, in addition to the context,
+the name of the main operation to be matched and the "benefit" of a match. These
operands are intended to be useful for defining an optimization problem across
multiple possible conversions but are currently ignored by the conversion
framework.
```c++
ViewOpConversion::ViewOpConversion(MLIRContext *context)
- : DialectOpConversion(linalg::ViewOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {}
```
The matching function can be used, for example, to apply different conversions
protected:
// Produce a set of operation conversion patterns. This is called once per
// conversion.
- llvm::DenseSet<DialectOpConversion *>
+ llvm::DenseSet<DialectConversionPattern *>
initConverter(MLIRContext *context) override {
allocator.Reset();
// Call a helper function provided by MLIR to build a set of operation
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) { /*...*/ }
// This gets called once to set up operation converters.
- llvm::DenseSet<DialectOpConversion *>
+ llvm::DenseSet<DialectConversionPattern *>
initConverters(MLIRContext *context) override {
RewriteListBuilder<MulOpConversion, PrintOpConversion,
TransposeOpConversion>::build(allocator, context);
```c++
/// Lower a toy.add to an affine loop nest.
///
-/// This class inherit from `DialectOpConversion` and override `rewrite`,
+/// This class inherit from `DialectConversionPattern` and override `rewrite`,
/// similarly to the PatternRewriter introduced in the previous chapter.
/// It will be called by the DialectConversion framework (see `LateLowering`
/// class below).
-class AddOpConversion : public DialectOpConversion {
+class AddOpConversion : public DialectConversionPattern {
public:
explicit AddOpConversion(MLIRContext *context)
- : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {}
+ : DialectConversionPattern(toy::AddOp::getOperationName(), 1, context) {}
/// Lower the `op` by generating IR using the `rewriter` builder. The builder
/// is setup with a new function, the `operands` array has been populated with
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
/// conversion patterns with an access to the containing LLVMLowering for the
/// purpose of type conversions.
-class LLVMOpLowering : public DialectOpConversion {
+class LLVMOpLowering : public DialectConversionPattern {
public:
LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
LLVMLowering &lowering);
class Type;
class Value;
-// Private implementation class.
-namespace impl {
-class FunctionConverter;
-}
-
-/// Base class for the dialect op conversion patterns. Specific conversions
-/// must derive this class and implement least one `rewrite` method. Optionally
-/// they can also override `PatternMatch match(Operation *)` to match more
-/// specific operations than the `rootName` provided in the constructor.
+/// Base class for the dialect conversion patterns that require type changes.
+/// Specific conversions must derive this class and implement least one
+/// `rewrite` method.
/// NOTE: These conversion patterns can only be used with the DialectConversion
/// class.
-class DialectOpConversion : public RewritePattern {
+class DialectConversionPattern : public RewritePattern {
public:
- /// Construct an DialectOpConversion. `rootName` must correspond to the
+ /// Construct an DialectConversionPattern. `rootName` must correspond to the
/// canonical name of the first operation matched by the pattern.
- DialectOpConversion(StringRef rootName, PatternBenefit benefit,
- MLIRContext *ctx)
+ DialectConversionPattern(StringRef rootName, PatternBenefit benefit,
+ MLIRContext *ctx)
: RewritePattern(rootName, benefit, ctx) {}
/// Hook for derived classes to implement matching. Dialect conversion
/// Hook for derived classes to implement rewriting. `op` is the (first)
/// operation matched by the pattern, `operands` is a list of rewritten values
/// that are passed to this operation, `rewriter` can be used to emit the new
- /// operations. This function must be reimplemented if the DialectOpConversion
- /// ever needs to replace an operation that does not have successors. This
- /// function should not fail. If some specific cases of the operation are not
- /// supported, these cases should not be matched.
+ /// operations. This function must be reimplemented if the
+ /// DialectConversionPattern ever needs to replace an operation that does not
+ /// have successors. This function should not fail. If some specific cases of
+ /// the operation are not supported, these cases should not be matched.
virtual void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
/// of (potentially rewritten) successor blocks, `operands` is a list of lists
/// of rewritten values passed to each of the successors, co-indexed with
/// `destinations`, `rewriter` can be used to emit the new operations. It must
- /// be reimplemented if the DialectOpConversion ever needs to replace a
+ /// be reimplemented if the DialectConversionPattern ever needs to replace a
/// terminator operation that has successors. This function should not fail
/// the pass. If some specific cases of the operation are not supported,
/// these cases should not be matched.
///
/// If the conversion fails, the module is not modified.
class DialectConversion {
- friend class impl::FunctionConverter;
-
public:
virtual ~DialectConversion() = default;
LLVM_NODISCARD
LogicalResult convert(Module *m);
-protected:
/// Derived classes must implement this hook to produce a set of conversion
/// patterns to apply. They may use `mlirContext` to obtain registered
/// dialects or operations. This will be called in the beginning of the
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
LLVMLowering &lowering_)
- : DialectOpConversion(rootOpName, /*benefit=*/1, context),
+ : DialectConversionPattern(rootOpName, /*benefit=*/1, context),
lowering(lowering_) {}
namespace {
using namespace mlir;
using namespace mlir::impl;
+//===----------------------------------------------------------------------===//
+// ProducerGenerator
+//===----------------------------------------------------------------------===//
namespace {
/// This class provides a simple interface for generating fake producers during
/// the conversion process. These fake producers are used when replacing the
UnknownLoc loc;
};
-/// This class implements a pattern rewriter for DialectOpConversion patterns.
-/// It automatically performs remapping of replaced operation values.
+//===----------------------------------------------------------------------===//
+// DialectConversionRewriter
+//===----------------------------------------------------------------------===//
+
+/// This class implements a pattern rewriter for DialectConversionPattern
+/// patterns. It automatically performs remapping of replaced operation values.
struct DialectConversionRewriter final : public PatternRewriter {
DialectConversionRewriter(Function *fn)
: PatternRewriter(fn), tempGenerator(fn->getContext()) {}
};
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// DialectConversionPattern
+//===----------------------------------------------------------------------===//
+
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
-void DialectOpConversion::rewrite(Operation *op,
- PatternRewriter &rewriter) const {
+void DialectConversionPattern::rewrite(Operation *op,
+ PatternRewriter &rewriter) const {
SmallVector<Value *, 4> operands;
auto &dialectRewriter = static_cast<DialectConversionRewriter &>(rewriter);
dialectRewriter.lookupValues(op->getOperands(), operands);
destinations, operandsPerDestination, rewriter);
}
-namespace mlir {
-namespace impl {
+//===----------------------------------------------------------------------===//
+// FunctionConverter
+//===----------------------------------------------------------------------===//
+namespace {
// Implementation detail class of the DialectConversion utility. Performs
// function-by-function conversions by creating new functions, filling them in
// with converted blocks, updating the function attributes, and replacing the
/// The matcher to use when converting operations.
RewritePatternMatcher &matcher;
};
-} // end namespace impl
-} // end namespace mlir
+} // end anonymous namespace
LogicalResult
FunctionConverter::convertArgument(DialectConversionRewriter &rewriter,
return newFunc;
}
+//===----------------------------------------------------------------------===//
+// DialectConversion
+//===----------------------------------------------------------------------===//
+
// Create a function type with arguments and results converted, and argument
// attributes passed through.
FunctionType DialectConversion::convertFunctionSignatureType(