explicit RangeOpConversion(MLIRContext *context)
: ConversionPattern(linalg::RangeOp::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto rangeOp = cast<linalg::RangeOp>(op);
auto rangeDescriptorType =
linalg::convertLinalgType(rangeOp.getResult()->getType());
explicit ViewOpConversion(MLIRContext *context)
: ConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto viewOp = cast<linalg::ViewOp>(op);
auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());
auto memrefType =
explicit SliceOpConversion(MLIRContext *context)
: ConversionPattern(linalg::SliceOp::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto sliceOp = cast<linalg::SliceOp>(op);
auto newViewDescriptorType =
linalg::convertLinalgType(sliceOp.getViewType());
explicit DropConsumer(MLIRContext *context)
: ConversionPattern("some_consumer", 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}
// an LLVM IR load.
class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
using Base::Base;
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
auto elementType = linalg::convertLinalgType(*op->result_type_begin());
Value *viewDescriptor = operands[0];
// an LLVM IR store.
class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
using Base::Base;
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
Value *viewDescriptor = operands[1];
Value *data = operands[0];
/// time both side of the cast (producer and consumer) will be lowered to a
/// dialect like LLVM and end up with the same LLVM representation, at which
/// point this becomes a no-op and is eliminated.
-Value *typeCast(PatternRewriter &builder, Value *val, Type destTy) {
+Value *typeCast(ConversionPatternRewriter &builder, Value *val, Type destTy) {
if (val->getType() == destTy)
return val;
return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
/// Create a type cast to turn a toy.array into a memref. The Toy Array will be
/// lowered to a memref during buffer allocation, at which point the type cast
/// becomes useless.
-Value *memRefTypeCast(PatternRewriter &builder, Value *val) {
+Value *memRefTypeCast(ConversionPatternRewriter &builder, Value *val) {
if (val->getType().isa<MemRefType>())
return val;
auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
explicit MulOpConversion(MLIRContext *context)
: ConversionPattern(toy::MulOp::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
using namespace edsc;
using intrinsics::constant_index;
using linalg::intrinsics::range;
/// the rewritten operands for `op` in the new function.
/// The results created by the new IR with the builder are returned, and their
/// number must match the number of result of `op`.
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto add = cast<toy::AddOp>(op);
auto loc = add.getLoc();
// Create a `toy.alloc` operation to allocate the output buffer for this op.
explicit PrintOpConversion(MLIRContext *context)
: ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
// Get or create the declaration of the printf function in the module.
FuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
explicit ConstantOpConversion(MLIRContext *context)
: ConversionPattern(toy::ConstantOp::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
auto loc = cstOp.getLoc();
auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
explicit TransposeOpConversion(MLIRContext *context)
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto transpose = cast<toy::TransposeOp>(op);
auto loc = transpose.getLoc();
Value *result = memRefTypeCast(
explicit ReturnOpConversion(MLIRContext *context)
: ConversionPattern(toy::ReturnOp::getOperationName(), 1, context) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
// Argument is optional, handle both cases.
if (op->getNumOperands())
rewriter.replaceOpWithNewOp<ReturnOp>(op, operands[0]);
/// (perhaps transitively) dead. If any of those values are dead, this will
/// remove them as well.
virtual void replaceOp(Operation *op, ArrayRef<Value *> newValues,
- ArrayRef<Value *> valuesToRemoveIfDead = {});
+ ArrayRef<Value *> valuesToRemoveIfDead);
+ void replaceOp(Operation *op, ArrayRef<Value *> newValues) {
+ replaceOp(op, newValues, llvm::None);
+ }
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types.
// Forward declarations.
class Block;
+class ConversionPatternRewriter;
class FuncOp;
class MLIRContext;
class Operation;
/// have successors. This function should not fail. If some specific cases of
/// the operation are not supported, these cases should not be matched.
virtual void rewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const {
+ ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}
virtual void rewrite(Operation *op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
- PatternRewriter &rewriter) const {
+ ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite for terminators");
}
matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
- PatternRewriter &rewriter) const {
+ ConversionPatternRewriter &rewriter) const {
if (!match(op))
return matchFailure();
rewrite(op, properOperands, destinations, operands, rewriter);
}
/// Hook for derived classes to implement combined matching and rewriting.
- virtual PatternMatchResult matchAndRewrite(Operation *op,
- ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const {
+ virtual PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const {
if (!match(op))
return matchFailure();
rewrite(op, operands, rewriter);
};
//===----------------------------------------------------------------------===//
+// Conversion PatternRewriter
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+struct ConversionPatternRewriterImpl;
+} // end namespace detail
+
+/// This class implements a pattern rewriter for use with ConversionPatterns.
+class ConversionPatternRewriter final : public PatternRewriter {
+public:
+ ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter);
+ ~ConversionPatternRewriter() override;
+
+ //===--------------------------------------------------------------------===//
+ // PatternRewriter Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// PatternRewriter hook for replacing the results of an operation.
+ void replaceOp(Operation *op, ArrayRef<Value *> newValues,
+ ArrayRef<Value *> valuesToRemoveIfDead) override;
+ using PatternRewriter::replaceOp;
+
+ /// PatternRewriter hook for splitting a block into two parts.
+ Block *splitBlock(Block *block, Block::iterator before) override;
+
+ /// PatternRewriter hook for moving blocks out of a region.
+ void inlineRegionBefore(Region ®ion, Region &parent,
+ Region::iterator before) override;
+ using PatternRewriter::inlineRegionBefore;
+
+ /// PatternRewriter hook for creating a new operation.
+ Operation *createOperation(const OperationState &state) override;
+
+ /// PatternRewriter hook for updating the root operation in-place.
+ void notifyRootUpdated(Operation *op) override;
+
+ /// Return a reference to the internal implementation.
+ detail::ConversionPatternRewriterImpl &getImpl();
+
+private:
+ std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
+};
+
+//===----------------------------------------------------------------------===//
// ConversionTarget
//===----------------------------------------------------------------------===//
/// by the target.
Dynamic,
- /// This target explicitly does not support this operation.
+ /// The target explicitly does not support this operation.
Illegal,
};
ForLowering(MLIRContext *ctx)
: ConversionPattern(ForOp::getOperationName(), 1, ctx) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override;
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override;
};
// Create a CFG subgraph for the loop.if operation (including its "then" and
IfLowering(MLIRContext *ctx)
: ConversionPattern(IfOp::getOperationName(), 1, ctx) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override;
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override;
};
struct TerminatorLowering : public ConversionPattern {
TerminatorLowering(MLIRContext *ctx)
: ConversionPattern(TerminatorOp::getOperationName(), 1, ctx) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, {});
return matchSuccess();
}
PatternMatchResult
ForLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const {
+ ConversionPatternRewriter &rewriter) const {
auto forOp = cast<ForOp>(op);
Location loc = op->getLoc();
PatternMatchResult
IfLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const {
+ ConversionPatternRewriter &rewriter) const {
auto ifOp = cast<IfOp>(op);
auto loc = op->getLoc();
}
// Create an LLVM IR pseudo-operation defining the given index constant.
- Value *createIndexConstant(PatternRewriter &builder, Location loc,
+ Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
uint64_t value) const {
auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
// Get the array attribute named "position" containing the given list of
// integers as integer attribute elements.
- static ArrayAttr getIntegerArrayAttr(PatternRewriter &builder,
+ static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder,
ArrayRef<int64_t> values) {
SmallVector<Attribute, 4> attrs;
attrs.reserve(values.size());
}
// Extract raw data pointer value from a value representing a memref.
- static Value *extractMemRefElementPtr(PatternRewriter &builder, Location loc,
+ static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
+ Location loc,
Value *convertedMemRefValue,
Type elementTypePtr,
bool hasStaticShape) {
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
unsigned numResults = op->getNumResults();
Type packedType;
}
void rewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ ConversionPatternRewriter &rewriter) const override {
auto allocOp = cast<AllocOp>(op);
MemRefType type = allocOp.getType();
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);
}
void rewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ ConversionPatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
OperandAdaptor<MemRefCastOp> transformed(operands);
auto targetType = memRefCastOp.getType();
}
void rewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ ConversionPatternRewriter &rewriter) const override {
auto dimOp = cast<DimOp>(op);
OperandAdaptor<DimOp> transformed(operands);
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
// by accumulating the running linearized value.
// Note that `indices` and `allocSizes` are passed in the same order as they
// appear in load/store operations and memref type declarations.
- Value *linearizeSubscripts(PatternRewriter &builder, Location loc,
+ Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
ArrayRef<Value *> indices,
ArrayRef<Value *> allocSizes) const {
assert(indices.size() == allocSizes.size() &&
Value *getElementPtr(Location loc, Type elementTypePtr,
ArrayRef<int64_t> shape, Value *memRefDescriptor,
ArrayRef<Value *> indices,
- PatternRewriter &rewriter) const {
+ ConversionPatternRewriter &rewriter) const {
// Get the list of MemRef sizes. Static sizes are defined as constants.
// Dynamic sizes are extracted from the MemRef descriptor, where they start
// from the position 1 (the buffer is at position 0).
Value *getRawElementPtr(Location loc, Type elementTypePtr,
ArrayRef<int64_t> shape, Value *rawDataPtr,
ArrayRef<Value *> indices,
- PatternRewriter &rewriter) const {
+ ConversionPatternRewriter &rewriter) const {
if (shape.empty())
return rawDataPtr;
}
Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
- ArrayRef<Value *> indices, PatternRewriter &rewriter,
+ ArrayRef<Value *> indices,
+ ConversionPatternRewriter &rewriter,
llvm::Module &module) const {
auto ptrType = getMemRefElementPtrType(type, this->lowering);
auto shape = type.getShape();
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
using Base::Base;
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto loadOp = cast<LoadOp>(op);
OperandAdaptor<LoadOp> transformed(operands);
auto type = loadOp.getMemRefType();
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
using Base::Base;
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto type = cast<StoreOp>(op).getMemRefType();
OperandAdaptor<StoreOp> transformed(operands);
struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
IndexCastOpOperandAdaptor transformed(operands);
auto indexCastOp = cast<IndexCastOp>(op);
struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto cmpiOp = cast<CmpIOp>(op);
CmpIOpOperandAdaptor transformed(operands);
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
- PatternMatchResult matchAndRewrite(Operation *op,
- ArrayRef<Value *> properOperands,
- ArrayRef<Block *> destinations,
- ArrayRef<ArrayRef<Value *>> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
+ ArrayRef<Block *> destinations,
+ ArrayRef<ArrayRef<Value *>> operands,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
operands, op->getAttrs());
return this->matchSuccess();
struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op->getNumOperands();
// If ReturnOp has 0 or 1 operand, create it and return immediately.
LLVMTypeConverter &lowering_)
: LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto indexType = IndexType::get(op->getContext());
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
: LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
lowering_) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
edsc::ScopedContext context(rewriter, op->getLoc());
rewriter.replaceOp(
explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto dimOp = cast<linalg::DimOp>(op);
auto indexTy = lowering.convertType(rewriter.getIndexType());
edsc::ScopedContext context(rewriter, op->getLoc());
// a getelementptr. This must be called under an edsc::ScopedContext.
Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
ArrayRef<Value *> indices,
- PatternRewriter &rewriter) const {
+ ConversionPatternRewriter &rewriter) const {
auto loadOp = cast<Op>(op);
auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering);
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// an LLVM IR load.
class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
using Base::Base;
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
auto elementTy = lowering.convertType(*op->result_type_begin());
Value *viewDescriptor = operands[0];
explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto rangeOp = cast<RangeOp>(op);
auto rangeDescriptorTy =
convertLinalgType(rangeOp.getResult()->getType(), lowering);
: LLVMOpLowering(RangeIntersectOp::getOperationName(), context,
lowering_) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto rangeIntersectOp = cast<RangeIntersectOp>(op);
auto rangeDescriptorTy =
convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering);
explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto sliceOp = cast<SliceOp>(op);
auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
auto viewType = sliceOp.getBaseViewType();
// an LLVM IR store.
class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
using Base::Base;
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
Value *data = operands[0];
Value *viewDescriptor = operands[1];
explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
auto viewOp = cast<ViewOp>(op);
auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
// Get function definition for the LinalgOp. If it doesn't exist, insert a
// definition.
template <typename LinalgOp>
-static FuncOp getLLVMLibraryCallDeclaration(Operation *op,
- LLVMTypeConverter &lowering,
- PatternRewriter &rewriter) {
+static FuncOp
+getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
+ ConversionPatternRewriter &rewriter) {
assert(isa<LinalgOp>(op));
auto fnName = LinalgOp::getLibraryCallName();
auto module = op->getParentOfType<ModuleOp>();
LinalgTypeConverter &lowering_)
: LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
// Only emit library call declaration. Fill in the body later.
auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
+using namespace mlir::detail;
#define DEBUG_TYPE "dialect-conversion"
/// The pattern rewriter to use when materializing conversions.
PatternRewriter &rewriter;
};
+} // end anonymous namespace
constexpr StringLiteral ArgConverter::kCastName;
}
//===----------------------------------------------------------------------===//
-// DialectConversionRewriter
+// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
-
+namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
/// The current number of type conversion actions performed.
unsigned numTypeConversions;
};
+} // end anonymous namespace
-/// This class implements a pattern rewriter for ConversionPattern
-/// patterns. It automatically performs remapping of replaced operation values.
-struct DialectConversionRewriter final : public PatternRewriter {
+namespace mlir {
+namespace detail {
+struct ConversionPatternRewriterImpl {
/// This class represents one requested operation replacement via 'replaceOp'.
struct OpReplacement {
OpReplacement() = default;
NamedAttributeList originalParentAttributes;
};
- DialectConversionRewriter(MLIRContext *ctx, TypeConverter *converter)
- : PatternRewriter(ctx), argConverter(converter, *this) {}
- ~DialectConversionRewriter() = default;
+ ConversionPatternRewriterImpl(PatternRewriter &rewriter,
+ TypeConverter *converter)
+ : argConverter(converter, rewriter) {}
/// Return the current state of the rewriter.
- RewriterState getCurrentState() {
- return RewriterState(createdOps.size(), replacements.size(),
- blockActions.size(), typeConversions.size());
- }
+ RewriterState getCurrentState();
/// Reset the state of the rewriter to a previously saved point.
- void resetState(RewriterState state) {
- // Undo any type conversions or block actions.
- undoTypeConversions(state.numTypeConversions);
- undoBlockActions(state.numBlockActions);
-
- // Reset any replaced operations and undo any saved mappings.
- for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
- for (auto *result : repl.op->getResults())
- mapping.erase(result);
- replacements.resize(state.numReplacements);
-
- // Pop all of the newly created operations.
- while (createdOps.size() != state.numCreatedOperations)
- createdOps.pop_back_val()->erase();
- }
+ void resetState(RewriterState state);
/// Undo the block actions (motions, splits) one by one in reverse order until
/// "numActionsToKeep" actions remains.
- void undoBlockActions(unsigned numActionsToKeep = 0) {
- for (auto &action :
- llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
- switch (action.kind) {
- // Merge back the block that was split out.
- case BlockActionKind::Split: {
- action.originalBlock->getOperations().splice(
- action.originalBlock->end(), action.block->getOperations());
- action.block->erase();
- break;
- }
- // Move the block back to its original position.
- case BlockActionKind::Move: {
- Region *originalRegion = action.originalPosition.region;
- originalRegion->getBlocks().splice(
- std::next(originalRegion->begin(),
- action.originalPosition.position),
- action.block->getParent()->getBlocks(), action.block);
- break;
- }
- }
- }
- blockActions.resize(numActionsToKeep);
- }
+ void undoBlockActions(unsigned numActionsToKeep = 0);
/// Undo the type conversion actions one by one, until "numActionsToKeep"
/// actions remain.
- void undoTypeConversions(unsigned numActionsToKeep = 0) {
- for (auto &conversion :
- llvm::drop_begin(typeConversions, numActionsToKeep)) {
- if (auto *region = conversion.object.dyn_cast<Region *>())
- region->getContainingOp()->setAttrs(
- conversion.originalParentAttributes);
- else
- argConverter.discardPendingRewrites(conversion.object.get<Block *>());
- }
- typeConversions.resize(numActionsToKeep);
- }
+ void undoTypeConversions(unsigned numActionsToKeep = 0);
/// Cleanup and destroy any generated rewrite operations. This method is
/// invoked when the conversion process fails.
- void discardRewrites() {
- undoTypeConversions();
- undoBlockActions();
-
- // Remove any newly created ops.
- for (auto *op : createdOps) {
- op->dropAllDefinedValueUses();
- op->erase();
- }
- }
+ void discardRewrites();
/// Apply all requested operation rewrites. This method is invoked when the
/// conversion process succeeds.
- void applyRewrites() {
- // Apply all of the rewrites replacements requested during conversion.
- for (auto &repl : replacements) {
- for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
- repl.op->getResult(i)->replaceAllUsesWith(
- mapping.lookupOrDefault(repl.newValues[i]));
-
- // if this operation defines any regions, drop any pending argument
- // rewrites.
- if (argConverter.typeConverter && repl.op->getNumRegions()) {
- for (auto ®ion : repl.op->getRegions())
- for (auto &block : region)
- argConverter.cancelPendingRewrites(&block);
- }
- }
-
- // In a second pass, erase all of the replaced operations in reverse. This
- // allows processing nested operations before their parent region is
- // destroyed.
- for (auto &repl : llvm::reverse(replacements))
- repl.op->erase();
-
- argConverter.applyRewrites();
- }
+ void applyRewrites();
/// Return if the given block has already been converted.
- bool hasSignatureBeenConverted(Block *block) {
- return argConverter.hasBeenConverted(block);
- }
+ bool hasSignatureBeenConverted(Block *block);
/// Convert the signature of the given region.
- LogicalResult convertRegionSignature(Region ®ion) {
- auto parentAttrs = region.getContainingOp()->getAttrList();
- auto result = argConverter.convertSignature(region, mapping);
- if (succeeded(result)) {
- typeConversions.push_back(TypeConversion{®ion, parentAttrs});
- if (!region.empty())
- typeConversions.push_back(
- TypeConversion{®ion.front(), NamedAttributeList()});
- }
- return result;
- }
+ LogicalResult convertRegionSignature(Region ®ion);
/// Convert the signature of the given block.
- LogicalResult convertBlockSignature(Block *block) {
- auto result = argConverter.convertSignature(block, mapping);
- if (succeeded(result))
- typeConversions.push_back(TypeConversion{block, NamedAttributeList()});
- return result;
- }
+ LogicalResult convertBlockSignature(Block *block);
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
- ArrayRef<Value *> valuesToRemoveIfDead) override {
- assert(newValues.size() == op->getNumResults());
-
- // Create mappings for each of the new result values.
- for (unsigned i = 0, e = newValues.size(); i < e; ++i) {
- assert((newValues[i] || op->getResult(i)->use_empty()) &&
- "result value has remaining uses that must be replaced");
- if (newValues[i])
- mapping.map(op->getResult(i), newValues[i]);
- }
+ ArrayRef<Value *> valuesToRemoveIfDead);
- // Record the requested operation replacement.
- replacements.emplace_back(op, newValues);
- }
-
- /// PatternRewriter hook for splitting a block into two parts.
- Block *splitBlock(Block *block, Block::iterator before) override {
- auto *continuation = PatternRewriter::splitBlock(block, before);
- BlockAction action;
- action.kind = BlockActionKind::Split;
- action.block = continuation;
- action.originalBlock = block;
- blockActions.push_back(action);
- return continuation;
- }
-
- /// PatternRewriter hook for moving blocks out of a region.
- void inlineRegionBefore(Region ®ion, Region &parent,
- Region::iterator before) override {
- for (auto &pair : llvm::enumerate(region)) {
- Block &block = pair.value();
- unsigned position = pair.index();
- BlockAction action;
- action.kind = BlockActionKind::Move;
- action.block = █
- action.originalPosition = {®ion, position};
- blockActions.push_back(action);
- }
- PatternRewriter::inlineRegionBefore(region, parent, before);
- }
-
- /// PatternRewriter hook for creating a new operation.
- Operation *createOperation(const OperationState &state) override {
- auto *result = OpBuilder::createOperation(state);
- createdOps.push_back(result);
- return result;
- }
+ /// Notifies that a block was split.
+ void notifySplitBlock(Block *block, Block *continuation);
- /// PatternRewriter hook for updating the root operation in-place.
- void notifyRootUpdated(Operation *op) override {
- // The rewriter caches changes to the IR to allow for operating in-place and
- // backtracking. The rewrite is currently not capable of backtracking
- // in-place modifications.
- llvm_unreachable("in-place operation updates are not supported");
- }
+ /// Notifies that the blocks of a region are about to be moved.
+ void notifyRegionIsBeingInlinedBefore(Region ®ion, Region &parent,
+ Region::iterator before);
/// Remap the given operands to those with potentially different types.
void remapValues(Operation::operand_range operands,
- SmallVectorImpl<Value *> &remapped) {
- remapped.reserve(llvm::size(operands));
- for (Value *operand : operands)
- remapped.push_back(mapping.lookupOrDefault(operand));
- }
+ SmallVectorImpl<Value *> &remapped);
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
/// Ordered list of type conversion actions.
SmallVector<TypeConversion, 4> typeConversions;
};
-} // end anonymous namespace
+} // end namespace detail
+} // end namespace mlir
+
+RewriterState ConversionPatternRewriterImpl::getCurrentState() {
+ return RewriterState(createdOps.size(), replacements.size(),
+ blockActions.size(), typeConversions.size());
+}
+
+void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+ // Undo any type conversions or block actions.
+ undoTypeConversions(state.numTypeConversions);
+ undoBlockActions(state.numBlockActions);
+
+ // Reset any replaced operations and undo any saved mappings.
+ for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
+ for (auto *result : repl.op->getResults())
+ mapping.erase(result);
+ replacements.resize(state.numReplacements);
+
+ // Pop all of the newly created operations.
+ while (createdOps.size() != state.numCreatedOperations)
+ createdOps.pop_back_val()->erase();
+}
+
+void ConversionPatternRewriterImpl::undoBlockActions(
+ unsigned numActionsToKeep) {
+ for (auto &action :
+ llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
+ switch (action.kind) {
+ // Merge back the block that was split out.
+ case BlockActionKind::Split: {
+ action.originalBlock->getOperations().splice(
+ action.originalBlock->end(), action.block->getOperations());
+ action.block->erase();
+ break;
+ }
+ // Move the block back to its original position.
+ case BlockActionKind::Move: {
+ Region *originalRegion = action.originalPosition.region;
+ originalRegion->getBlocks().splice(
+ std::next(originalRegion->begin(), action.originalPosition.position),
+ action.block->getParent()->getBlocks(), action.block);
+ break;
+ }
+ }
+ }
+ blockActions.resize(numActionsToKeep);
+}
+
+void ConversionPatternRewriterImpl::undoTypeConversions(
+ unsigned numActionsToKeep) {
+ for (auto &conversion : llvm::drop_begin(typeConversions, numActionsToKeep)) {
+ if (auto *region = conversion.object.dyn_cast<Region *>())
+ region->getContainingOp()->setAttrs(conversion.originalParentAttributes);
+ else
+ argConverter.discardPendingRewrites(conversion.object.get<Block *>());
+ }
+ typeConversions.resize(numActionsToKeep);
+}
+
+void ConversionPatternRewriterImpl::discardRewrites() {
+ undoTypeConversions();
+ undoBlockActions();
+
+ // Remove any newly created ops.
+ for (auto *op : createdOps) {
+ op->dropAllDefinedValueUses();
+ op->erase();
+ }
+}
+
+void ConversionPatternRewriterImpl::applyRewrites() {
+ // Apply all of the rewrites replacements requested during conversion.
+ for (auto &repl : replacements) {
+ for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
+ repl.op->getResult(i)->replaceAllUsesWith(
+ mapping.lookupOrDefault(repl.newValues[i]));
+
+ // If this operation defines any regions, drop any pending argument
+ // rewrites.
+ if (argConverter.typeConverter && repl.op->getNumRegions()) {
+ for (auto ®ion : repl.op->getRegions())
+ for (auto &block : region)
+ argConverter.cancelPendingRewrites(&block);
+ }
+ }
+
+ // In a second pass, erase all of the replaced operations in reverse. This
+ // allows processing nested operations before their parent region is
+ // destroyed.
+ for (auto &repl : llvm::reverse(replacements))
+ repl.op->erase();
+
+ argConverter.applyRewrites();
+}
+
+bool ConversionPatternRewriterImpl::hasSignatureBeenConverted(Block *block) {
+ return argConverter.hasBeenConverted(block);
+}
+
+LogicalResult
+ConversionPatternRewriterImpl::convertRegionSignature(Region ®ion) {
+ auto parentAttrs = region.getContainingOp()->getAttrList();
+ auto result = argConverter.convertSignature(region, mapping);
+ if (succeeded(result)) {
+ typeConversions.push_back(TypeConversion{®ion, parentAttrs});
+ if (!region.empty())
+ typeConversions.push_back(
+ TypeConversion{®ion.front(), NamedAttributeList()});
+ }
+ return result;
+}
+
+LogicalResult
+ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
+ auto result = argConverter.convertSignature(block, mapping);
+ if (succeeded(result))
+ typeConversions.push_back(TypeConversion{block, NamedAttributeList()});
+ return result;
+}
+
+void ConversionPatternRewriterImpl::replaceOp(
+ Operation *op, ArrayRef<Value *> newValues,
+ ArrayRef<Value *> valuesToRemoveIfDead) {
+ assert(newValues.size() == op->getNumResults());
+
+ // Create mappings for each of the new result values.
+ for (unsigned i = 0, e = newValues.size(); i < e; ++i) {
+ assert((newValues[i] || op->getResult(i)->use_empty()) &&
+ "result value has remaining uses that must be replaced");
+ if (newValues[i])
+ mapping.map(op->getResult(i), newValues[i]);
+ }
+
+ // Record the requested operation replacement.
+ replacements.emplace_back(op, newValues);
+}
+
+void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
+ Block *continuation) {
+ BlockAction action;
+ action.kind = BlockActionKind::Split;
+ action.block = continuation;
+ action.originalBlock = block;
+ blockActions.push_back(action);
+}
+
+void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
+ Region ®ion, Region &parent, Region::iterator before) {
+ for (auto &pair : llvm::enumerate(region)) {
+ Block &block = pair.value();
+ unsigned position = pair.index();
+ BlockAction action;
+ action.kind = BlockActionKind::Move;
+ action.block = █
+ action.originalPosition = {®ion, position};
+ blockActions.push_back(action);
+ }
+}
+
+void ConversionPatternRewriterImpl::remapValues(
+ Operation::operand_range operands, SmallVectorImpl<Value *> &remapped) {
+ remapped.reserve(llvm::size(operands));
+ for (Value *operand : operands)
+ remapped.push_back(mapping.lookupOrDefault(operand));
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionPatternRewriter
+//===----------------------------------------------------------------------===//
+
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
+ TypeConverter *converter)
+ : PatternRewriter(ctx),
+ impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {}
+ConversionPatternRewriter::~ConversionPatternRewriter() {}
+
+/// PatternRewriter hook for replacing the results of an operation.
+void ConversionPatternRewriter::replaceOp(
+ Operation *op, ArrayRef<Value *> newValues,
+ ArrayRef<Value *> valuesToRemoveIfDead) {
+ impl->replaceOp(op, newValues, valuesToRemoveIfDead);
+}
+
+/// PatternRewriter hook for splitting a block into two parts.
+Block *ConversionPatternRewriter::splitBlock(Block *block,
+ Block::iterator before) {
+ auto *continuation = PatternRewriter::splitBlock(block, before);
+ impl->notifySplitBlock(block, continuation);
+ return continuation;
+}
+
+/// PatternRewriter hook for moving blocks out of a region.
+void ConversionPatternRewriter::inlineRegionBefore(Region ®ion,
+ Region &parent,
+ Region::iterator before) {
+ impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
+ PatternRewriter::inlineRegionBefore(region, parent, before);
+}
+
+/// PatternRewriter hook for creating a new operation.
+Operation *
+ConversionPatternRewriter::createOperation(const OperationState &state) {
+ auto *result = OpBuilder::createOperation(state);
+ impl->createdOps.push_back(result);
+ return result;
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::notifyRootUpdated(Operation *op) {
+ // The rewriter caches changes to the IR to allow for operating in-place and
+ // backtracking. The rewriter is currently not capable of backtracking
+ // in-place modifications.
+ llvm_unreachable("in-place operation updates are not supported");
+}
+
+/// Return a reference to the internal implementation.
+detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
+ return *impl;
+}
//===----------------------------------------------------------------------===//
// Conversion Patterns
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
SmallVector<Value *, 4> operands;
- auto &dialectRewriter = static_cast<DialectConversionRewriter &>(rewriter);
- dialectRewriter.remapValues(op->getOperands(), operands);
+ auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
+ dialectRewriter.getImpl().remapValues(op->getOperands(), operands);
// If this operation has no successors, invoke the rewrite directly.
if (op->getNumSuccessors() == 0)
- return matchAndRewrite(op, operands, rewriter);
+ return matchAndRewrite(op, operands, dialectRewriter);
// Otherwise, we need to remap the successors.
SmallVector<Block *, 2> destinations;
op,
llvm::makeArrayRef(operands.data(),
operands.data() + firstSuccessorOperand),
- destinations, operandsPerDestination, rewriter);
+ destinations, operandsPerDestination, dialectRewriter);
}
//===----------------------------------------------------------------------===//
/// Attempt to legalize the given operation. Returns success if the operation
/// was legalized, failure otherwise.
- LogicalResult legalize(Operation *op, DialectConversionRewriter &rewriter);
+ LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
private:
/// Attempt to legalize the given operation by applying the provided pattern.
/// Returns success if the operation was legalized, failure otherwise.
LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
- DialectConversionRewriter &rewriter);
+ ConversionPatternRewriter &rewriter);
/// Build an optimistic legalization graph given the provided patterns. This
/// function populates 'legalizerPatterns' with the operations that are not
LogicalResult
OperationLegalizer::legalize(Operation *op,
- DialectConversionRewriter &rewriter) {
+ ConversionPatternRewriter &rewriter) {
// Make sure that the signature of the parent block of this operation has been
// converted.
- if (rewriter.argConverter.typeConverter) {
+ auto &rewriterImpl = rewriter.getImpl();
+ if (rewriterImpl.argConverter.typeConverter) {
auto *block = op->getBlock();
- if (block && !rewriter.hasSignatureBeenConverted(block)) {
+ if (block && !rewriterImpl.hasSignatureBeenConverted(block)) {
if (failed(block->isEntryBlock()
- ? rewriter.convertRegionSignature(*block->getParent())
- : rewriter.convertBlockSignature(block)))
+ ? rewriterImpl.convertRegionSignature(*block->getParent())
+ : rewriterImpl.convertBlockSignature(block)))
return failure();
}
}
LogicalResult
OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
- DialectConversionRewriter &rewriter) {
+ ConversionPatternRewriter &rewriter) {
LLVM_DEBUG({
llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> (";
interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
return failure();
}
- RewriterState curState = rewriter.getCurrentState();
+ auto &rewriterImpl = rewriter.getImpl();
+ RewriterState curState = rewriterImpl.getCurrentState();
auto cleanupFailure = [&] {
// Reset the rewriter state and pop this pattern.
- rewriter.resetState(curState);
+ rewriterImpl.resetState(curState);
appliedPatterns.erase(pattern);
return failure();
};
// Recursively legalize each of the new operations.
for (unsigned i = curState.numCreatedOperations,
- e = rewriter.createdOps.size();
+ e = rewriterImpl.createdOps.size();
i != e; ++i) {
- if (failed(legalize(rewriter.createdOps[i], rewriter))) {
+ if (failed(legalize(rewriterImpl.createdOps[i], rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
return cleanupFailure();
}
private:
/// Converts an operation with the given rewriter.
- LogicalResult convert(DialectConversionRewriter &rewriter, Operation *op);
+ LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
/// Recursively collect all of the operations, to convert from within
/// 'region'.
}
/// Converts an operation with the given rewriter.
-LogicalResult OperationConverter::convert(DialectConversionRewriter &rewriter,
+LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
Operation *op) {
// Legalize the given operation.
if (failed(opLegalizer.legalize(op, rewriter))) {
// within.
// FIXME(riverriddle) This should be replaced by patterns when the pattern
// rewriter exposes functionality to remap region signatures.
- if (rewriter.argConverter.typeConverter) {
+ auto &rewriterImpl = rewriter.getImpl();
+ if (rewriterImpl.argConverter.typeConverter) {
for (auto ®ion : op->getRegions())
- if (region.empty() && failed(rewriter.convertRegionSignature(region)))
+ if (region.empty() && failed(rewriterImpl.convertRegionSignature(region)))
return failure();
}
return success();
}
-/// Converts the given top-level operation to the conversion target.
+/// Converts the given operations to the conversion target.
LogicalResult
OperationConverter::convertOperations(ArrayRef<Operation *> ops,
TypeConverter *typeConverter) {
}
// Convert each operation and discard rewrites on failure.
- DialectConversionRewriter rewriter(ops.front()->getContext(), typeConverter);
+ ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter);
for (auto *op : toConvert) {
if (failed(convert(rewriter, op))) {
- rewriter.discardRewrites();
+ rewriter.getImpl().discardRewrites();
return failure();
}
}
// Otherwise the body conversion succeeded, so apply all rewrites.
- rewriter.applyRewrites();
+ rewriter.getImpl().applyRewrites();
return success();
}
TestRegionRewriteBlockMovement(MLIRContext *ctx)
: ConversionPattern("test.region", 1, ctx) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const final {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const final {
// Inline this region into the parent region.
auto &parentRegion = *op->getContainingRegion();
rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
/// This pattern simply erases the given operation.
struct TestDropOp : public ConversionPattern {
TestDropOp(MLIRContext *ctx) : ConversionPattern("test.drop_op", 1, ctx) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const final {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const final {
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}
struct TestPassthroughInvalidOp : public ConversionPattern {
TestPassthroughInvalidOp(MLIRContext *ctx)
: ConversionPattern("test.invalid", 1, ctx) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const final {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
llvm::None);
return matchSuccess();
struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType(MLIRContext *ctx)
: ConversionPattern("test.return", 1, ctx) {}
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const final {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const final {
// Check for a return of F32.
if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32())
return matchFailure();
lowering_.getDialect()->getContext(), lowering_) {}
// Convert the kernel arguments to an LLVM type, preserve the rest.
- PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.clone(*op)->setOperands(operands);
return rewriter.replaceOp(op, llvm::None), matchSuccess();
}