/// Return true if the given operation has legal operand and result types.
bool isLegal(Operation *op);
+ /// Return true if the types of block arguments within the region are legal.
+ bool isLegal(Region *region);
+
/// Return true if the inputs and outputs of the given function type are
/// legal.
bool isSignatureLegal(FunctionType ty);
// Conversion Patterns
//===----------------------------------------------------------------------===//
-/// Base class for the conversion patterns that require type changes. Specific
-/// conversions must derive this class and implement least one `rewrite` method.
-/// NOTE: These conversion patterns can only be used with the 'apply*' methods
-/// below.
+/// Base class for the conversion patterns. This pattern class enables type
+/// conversions, and other uses specific to the conversion framework. As such,
+/// patterns of this type can only be used with the 'apply*' methods below.
class ConversionPattern : public RewritePattern {
public:
/// Hook for derived classes to implement rewriting. `op` is the (first)
- /// operation matched by the pattern, `operands` is a list of rewritten values
- /// that are passed to this operation, `rewriter` can be used to emit the new
- /// operations. This function should not fail. If some specific cases of
+ /// operation matched by the pattern, `operands` is a list of the rewritten
+ /// operand values that are passed to `op`, `rewriter` can be used to emit the
+ /// new operations. This function should not fail. If some specific cases of
/// the operation are not supported, these cases should not be matched.
virtual void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final;
+ /// Return the type converter held by this pattern, or nullptr if the pattern
+ /// does not require type conversion.
+ TypeConverter *getTypeConverter() const { return typeConverter; }
+
protected:
+ /// See `RewritePattern::RewritePattern` for information on the other
+ /// available constructors.
using RewritePattern::RewritePattern;
+ /// Construct a conversion pattern that matches an operation with the given
+ /// root name. This constructor allows for providing a type converter to use
+ /// within the pattern.
+ ConversionPattern(StringRef rootName, PatternBenefit benefit,
+ TypeConverter &typeConverter, MLIRContext *ctx)
+ : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
+ /// Construct a conversion pattern that matches any operation type. This
+ /// constructor allows for providing a type converter to use within the
+ /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
+ /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
+ /// always be supplied here.
+ ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
+ MatchAnyOpTypeTag tag)
+ : RewritePattern(benefit, tag), typeConverter(&typeConverter) {}
+
+protected:
+ /// An optional type converter for use by this pattern.
+ TypeConverter *typeConverter;
private:
using RewritePattern::rewrite;
struct OpConversionPattern : public ConversionPattern {
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
+ OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter,
+ context) {}
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
/// hooks.
class ConversionPatternRewriter final : public PatternRewriter {
public:
- ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter);
+ ConversionPatternRewriter(MLIRContext *ctx);
~ConversionPatternRewriter() override;
/// Apply a signature conversion to the entry block of the given region. This
applySignatureConversion(Region *region,
TypeConverter::SignatureConversion &conversion);
+ /// Convert the types of block arguments within the given region. This
+ /// replaces each block with a new block containing the updated signature. The
+ /// entry block may have a special conversion if `entryConversion` is
+ /// provided. On success, the new entry block to the region is returned for
+ /// convenience. Otherwise, failure is returned.
+ FailureOr<Block *> convertRegionTypes(
+ Region *region, TypeConverter &converter,
+ TypeConverter::SignatureConversion *entryConversion = nullptr);
+
/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal. If `converter` is
-/// provided, the signatures of blocks and regions are also converted.
-/// If an `unconvertedOps` set is provided, all operations that are found not
-/// to be legalizable to the given `target` are placed within that set. (Note
-/// that if there is an op explicitly marked as illegal, the conversion
-/// terminates and the `unconvertedOps` set will not necessarily be complete.)
+/// returns failure if there ops explicitly marked as illegal. If an
+/// `unconvertedOps` set is provided, all operations that are found not to be
+/// legalizable to the given `target` are placed within that set. (Note that if
+/// there is an op explicitly marked as illegal, the conversion terminates and
+/// the `unconvertedOps` set will not necessarily be complete.)
LLVM_NODISCARD LogicalResult
applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
const OwningRewritePatternList &patterns,
- TypeConverter *converter = nullptr,
DenseSet<Operation *> *unconvertedOps = nullptr);
LLVM_NODISCARD LogicalResult
applyPartialConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns,
- TypeConverter *converter = nullptr,
DenseSet<Operation *> *unconvertedOps = nullptr);
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
/// fails, or if there are unreachable blocks in any of the regions nested
-/// within 'ops'. If 'converter' is provided, the signatures of blocks and
-/// regions are also converted.
+/// within 'ops'.
LLVM_NODISCARD LogicalResult
applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
- const OwningRewritePatternList &patterns,
- TypeConverter *converter = nullptr);
+ const OwningRewritePatternList &patterns);
LLVM_NODISCARD LogicalResult
applyFullConversion(Operation *op, ConversionTarget &target,
- const OwningRewritePatternList &patterns,
- TypeConverter *converter = nullptr);
+ const OwningRewritePatternList &patterns);
/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
/// operations on success and only pre-existing operations are added to the set.
/// This method only returns failure if there are unreachable blocks in any of
-/// the regions nested within 'ops', or if a type conversion failed. If
-/// 'converter' is provided, the signatures of blocks and regions are also
-/// considered for conversion.
-LLVM_NODISCARD LogicalResult applyAnalysisConversion(
- ArrayRef<Operation *> ops, ConversionTarget &target,
- const OwningRewritePatternList &patterns,
- DenseSet<Operation *> &convertedOps, TypeConverter *converter = nullptr);
-LLVM_NODISCARD LogicalResult applyAnalysisConversion(
- Operation *op, ConversionTarget &target,
- const OwningRewritePatternList &patterns,
- DenseSet<Operation *> &convertedOps, TypeConverter *converter = nullptr);
+/// the regions nested within 'ops'.
+LLVM_NODISCARD LogicalResult
+applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
+ const OwningRewritePatternList &patterns,
+ DenseSet<Operation *> &convertedOps);
+LLVM_NODISCARD LogicalResult
+applyAnalysisConversion(Operation *op, ConversionTarget &target,
+ const OwningRewritePatternList &patterns,
+ DenseSet<Operation *> &convertedOps);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
}
//===----------------------------------------------------------------------===//
-// Multi-Level Value Mapper
+// ConversionValueMapping
//===----------------------------------------------------------------------===//
namespace {
/// types and extracting the block that contains the old illegal types to allow
/// for undoing pending rewrites in the case of failure.
struct ArgConverter {
- ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter)
- : loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
- rewriter(rewriter) {}
+ ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {}
/// This structure contains the information pertaining to an argument that has
/// been converted.
/// This structure contains information pertaining to a block that has had its
/// signature converted.
struct ConvertedBlockInfo {
- ConvertedBlockInfo(Block *origBlock) : origBlock(origBlock) {}
+ ConvertedBlockInfo(Block *origBlock, TypeConverter &converter)
+ : origBlock(origBlock), converter(&converter) {}
/// The original block that was requested to have its signature converted.
Block *origBlock;
/// The conversion information for each of the arguments. The information is
/// None if the argument was dropped during conversion.
SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;
+
+ /// The type converter used to convert the arguments.
+ TypeConverter *converter;
};
/// Return if the signature of the given block has already been converted.
bool hasBeenConverted(Block *block) const {
- return conversionInfo.count(block);
+ return conversionInfo.count(block) || convertedBlocks.count(block);
+ }
+
+ /// Set the type converter to use for the given region.
+ void setConverter(Region *region, TypeConverter *typeConverter) {
+ assert(typeConverter && "expected valid type converter");
+ regionToConverter[region] = typeConverter;
+ }
+
+ /// Return the type converter to use for the given region, or null if there
+ /// isn't one.
+ TypeConverter *getConverter(Region *region) {
+ return regionToConverter.lookup(region);
}
//===--------------------------------------------------------------------===//
//===--------------------------------------------------------------------===//
/// Attempt to convert the signature of the given block, if successful a new
- /// block is returned containing the new arguments. On failure, nullptr is
- /// returned.
- Block *convertSignature(Block *block, ConversionValueMapping &mapping);
+ /// block is returned containing the new arguments. Returns `block` if it did
+ /// not require conversion.
+ FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
+ ConversionValueMapping &mapping);
/// Apply the given signature conversion on the given block. The new block
- /// containing the updated signature is returned.
+ /// containing the updated signature is returned. If no conversions were
+ /// necessary, e.g. if the block has no arguments, `block` is returned.
+ /// `converter` is used to generate any necessary cast operations that
+ /// translate between the origin argument types and those specified in the
+ /// signature conversion.
Block *applySignatureConversion(
- Block *block, TypeConverter::SignatureConversion &signatureConversion,
+ Block *block, TypeConverter &converter,
+ TypeConverter::SignatureConversion &signatureConversion,
ConversionValueMapping &mapping);
/// Insert a new conversion into the cache.
void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
- /// A collection of blocks that have had their arguments converted.
+ /// A collection of blocks that have had their arguments converted. This is a
+ /// map from the new replacement block, back to the original block.
llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
+ /// The set of original blocks that were converted.
+ DenseSet<Block *> convertedBlocks;
+
/// A mapping from valid regions, to those containing the original blocks of a
/// conversion.
DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
- /// An instance of the unknown location that is used when materializing
- /// conversions.
- Location loc;
-
- /// The type converter to use when changing types.
- TypeConverter *typeConverter;
+ /// A mapping of regions to type converters that should be used when
+ /// converting the arguments of blocks within that region.
+ DenseMap<Region *, TypeConverter *> regionToConverter;
/// The pattern rewriter to use when materializing conversions.
PatternRewriter &rewriter;
// Rewrite Application
void ArgConverter::notifyOpRemoved(Operation *op) {
+ if (conversionInfo.empty())
+ return;
+
for (Region ®ion : op->getRegions()) {
for (Block &block : region) {
// Drop any rewrites from within.
origBlock->moveBefore(block);
block->erase();
+ convertedBlocks.erase(origBlock);
conversionInfo.erase(it);
}
// persist in the IR after conversion.
if (!origArg.use_empty()) {
rewriter.setInsertionPointToStart(newBlock);
- Value newArg = typeConverter->materializeConversion(
- rewriter, loc, origArg.getType(), llvm::None);
+ Value newArg = blockInfo.converter->materializeConversion(
+ rewriter, origArg.getLoc(), origArg.getType(), llvm::None);
assert(newArg &&
"Couldn't materialize a block argument after 1->0 conversion");
origArg.replaceAllUsesWith(newArg);
//===----------------------------------------------------------------------===//
// Conversion
-Block *ArgConverter::convertSignature(Block *block,
- ConversionValueMapping &mapping) {
- if (auto conversion = typeConverter->convertBlockSignature(block))
- return applySignatureConversion(block, *conversion, mapping);
- return nullptr;
+FailureOr<Block *>
+ArgConverter::convertSignature(Block *block, TypeConverter &converter,
+ ConversionValueMapping &mapping) {
+ // Check if the block was already converted. If the block is detached,
+ // conservatively assume it is going to be deleted.
+ if (hasBeenConverted(block) || !block->getParent())
+ return block;
+
+ // Try to convert the signature for the block with the provided converter.
+ if (auto conversion = converter.convertBlockSignature(block))
+ return applySignatureConversion(block, converter, *conversion, mapping);
+ return failure();
}
Block *ArgConverter::applySignatureConversion(
- Block *block, TypeConverter::SignatureConversion &signatureConversion,
+ Block *block, TypeConverter &converter,
+ TypeConverter::SignatureConversion &signatureConversion,
ConversionValueMapping &mapping) {
// If no arguments are being changed or added, there is nothing to do.
unsigned origArgCount = block->getNumArguments();
// Remap each of the original arguments as determined by the signature
// conversion.
- ConvertedBlockInfo info(block);
+ ConvertedBlockInfo info(block, converter);
info.argInfo.resize(origArgCount);
OpBuilder::InsertionGuard guard(rewriter);
// to pack the new values. For 1->1 mappings, if there is no materialization
// provided, use the argument directly instead.
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
- Value newArg;
- if (typeConverter)
- newArg = typeConverter->materializeConversion(
- rewriter, loc, origArg.getType(), replArgs);
+ Value newArg = converter.materializeConversion(rewriter, origArg.getLoc(),
+ origArg.getType(), replArgs);
if (!newArg) {
assert(replArgs.size() == 1 &&
"couldn't materialize the result of 1->N conversion");
// Move the original block to the mapped region and emplace the conversion.
mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
info.origBlock->getIterator());
+ convertedBlocks.insert(info.origBlock);
conversionInfo.insert({newBlock, std::move(info)});
}
};
};
- ConversionPatternRewriterImpl(PatternRewriter &rewriter,
- TypeConverter *converter)
- : argConverter(converter, rewriter) {}
+ ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+ : argConverter(rewriter) {}
/// Return the current state of the rewriter.
RewriterState getCurrentState();
void applyRewrites();
/// Convert the signature of the given block.
- LogicalResult convertBlockSignature(Block *block);
+ FailureOr<Block *> convertBlockSignature(
+ Block *block, TypeConverter &converter,
+ TypeConverter::SignatureConversion *conversion = nullptr);
/// Apply a signature conversion on the given region.
Block *
applySignatureConversion(Region *region,
TypeConverter::SignatureConversion &conversion);
+ /// Convert the types of block arguments within the given region.
+ FailureOr<Block *>
+ convertRegionTypes(Region *region, TypeConverter &converter,
+ TypeConverter::SignatureConversion *entryConversion);
+
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ValueRange newValues);
/// A logger used to emit diagnostics during the conversion process.
llvm::ScopedPrinter logger{llvm::dbgs()};
#endif
+
+ /// A default type converter, used when block conversions do not have one
+ /// explicitly provided.
+ TypeConverter defaultTypeConverter;
};
} // end namespace detail
} // end namespace mlir
// If this operation defines any regions, drop any pending argument
// rewrites.
- if (argConverter.typeConverter && repl.op->getNumRegions())
+ if (repl.op->getNumRegions())
argConverter.notifyOpRemoved(repl.op);
}
eraseDanglingBlocks();
}
-LogicalResult
-ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
- // Check to see if this block should not be converted:
- // * There is no type converter.
- // * The block has already been converted.
- // * This is an entry block, these are converted explicitly via patterns.
- if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) ||
- !block->getParent() || block->isEntryBlock())
- return success();
-
- // Otherwise, try to convert the block signature.
- Block *newBlock = argConverter.convertSignature(block, mapping);
- if (newBlock)
- blockActions.push_back(BlockAction::getTypeConversion(newBlock));
- return success(newBlock);
+FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
+ Block *block, TypeConverter &converter,
+ TypeConverter::SignatureConversion *conversion) {
+ FailureOr<Block *> result =
+ conversion ? argConverter.applySignatureConversion(block, converter,
+ *conversion, mapping)
+ : argConverter.convertSignature(block, converter, mapping);
+ if (Block *newBlock = result.getValue()) {
+ if (newBlock != block)
+ blockActions.push_back(BlockAction::getTypeConversion(newBlock));
+ }
+ return result;
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion) {
if (!region->empty()) {
- Block *newEntry = argConverter.applySignatureConversion(
- ®ion->front(), conversion, mapping);
- blockActions.push_back(BlockAction::getTypeConversion(newEntry));
- return newEntry;
+ return *convertBlockSignature(®ion->front(), defaultTypeConverter,
+ &conversion);
}
return nullptr;
}
+FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
+ Region *region, TypeConverter &converter,
+ TypeConverter::SignatureConversion *entryConversion) {
+ argConverter.setConverter(region, &converter);
+ if (region->empty())
+ return nullptr;
+
+ // Convert the arguments of each block within the region.
+ FailureOr<Block *> newEntry =
+ convertBlockSignature(®ion->front(), converter, entryConversion);
+ for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
+ if (failed(convertBlockSignature(&block, converter)))
+ return failure();
+ return newEntry;
+}
+
void ConversionPatternRewriterImpl::replaceOp(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
// ConversionPatternRewriter
//===----------------------------------------------------------------------===//
-ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
- TypeConverter *converter)
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {}
+ impl(new detail::ConversionPatternRewriterImpl(*this)) {}
ConversionPatternRewriter::~ConversionPatternRewriter() {}
/// PatternRewriter hook for replacing the results of an operation.
block->getParent()->getBlocks().remove(block);
}
-/// Apply a signature conversion to the entry block of the given region.
Block *ConversionPatternRewriter::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion) {
return impl->applySignatureConversion(region, conversion);
}
+FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
+ Region *region, TypeConverter &converter,
+ TypeConverter::SignatureConversion *entryConversion) {
+ return impl->convertRegionTypes(region, converter, entryConversion);
+}
+
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
LLVM_DEBUG({
ConversionPatternRewriter &rewriter,
RewriterState &curState);
+ /// Legalizes the actions registered during the execution of a pattern.
+ LogicalResult legalizePatternBlockActions(Operation *op,
+ ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &impl,
+ RewriterState &state,
+ RewriterState &newState);
+ LogicalResult legalizePatternCreatedOperations(
+ ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
+ RewriterState &state, RewriterState &newState);
+ LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &impl,
+ RewriterState &state,
+ RewriterState &newState);
+
/// Build an optimistic legalization graph given the provided patterns. This
/// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
/// patterns for operations that are not directly legal, but may be
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const RewritePattern &pattern,
ConversionPatternRewriter &rewriter, RewriterState &curState) {
- auto &rewriterImpl = rewriter.getImpl();
+ auto &impl = rewriter.getImpl();
#ifndef NDEBUG
- assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+ assert(impl.pendingRootUpdates.empty() && "dangling root updates");
#endif
- // If the pattern moved or created any blocks, try to legalize their types.
- // This ensures that the types of the block arguments are legal for the region
- // they were moved into.
- for (unsigned i = curState.numBlockActions,
- e = rewriterImpl.blockActions.size();
- i != e; ++i) {
- auto &action = rewriterImpl.blockActions[i];
- if (action.kind ==
- ConversionPatternRewriterImpl::BlockActionKind::TypeConversion ||
- action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase)
- continue;
-
- // Convert the block signature.
- if (failed(rewriterImpl.convertBlockSignature(action.block))) {
- LLVM_DEBUG(logFailure(rewriterImpl.logger,
- "failed to convert types of moved block"));
- return failure();
- }
- }
-
// Check all of the replacements to ensure that the pattern actually replaced
// the root operation. We also mark any other replaced ops as 'dead' so that
// we don't try to legalize them later.
bool replacedRoot = false;
- for (unsigned i = curState.numReplacements,
- e = rewriterImpl.replacements.size();
+ for (unsigned i = curState.numReplacements, e = impl.replacements.size();
i != e; ++i) {
- Operation *replacedOp = rewriterImpl.replacements[i].op;
+ Operation *replacedOp = impl.replacements[i].op;
if (replacedOp == op)
replacedRoot = true;
else
- rewriterImpl.ignoredOps.insert(replacedOp);
+ impl.ignoredOps.insert(replacedOp);
}
// Check that the root was either updated or replace.
auto updatedRootInPlace = [&] {
return llvm::any_of(
- llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates),
+ llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
[op](auto &state) { return state.getOperation() == op; });
};
(void)replacedRoot;
assert((replacedRoot || updatedRootInPlace()) &&
"expected pattern to replace the root operation");
- // Recursively legalize each of the operations updated in place.
- for (unsigned i = curState.numRootUpdates,
- e = rewriterImpl.rootUpdates.size();
- i != e; ++i) {
- auto &state = rewriterImpl.rootUpdates[i];
- if (failed(legalize(state.getOperation(), rewriter))) {
- LLVM_DEBUG(logFailure(rewriterImpl.logger,
- "operation updated in-place '{0}' was illegal",
- op->getName()));
+ // Legalize each of the actions registered during application.
+ RewriterState newState = impl.getCurrentState();
+ if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
+ newState)) ||
+ failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
+ failed(legalizePatternCreatedOperations(rewriter, impl, curState,
+ newState))) {
+ return failure();
+ }
+
+ LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
+ return success();
+}
+
+LogicalResult OperationLegalizer::legalizePatternBlockActions(
+ Operation *op, ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &impl, RewriterState &state,
+ RewriterState &newState) {
+ SmallPtrSet<Operation *, 16> operationsToIgnore;
+
+ // If the pattern moved or created any blocks, make sure the types of block
+ // arguments get legalized.
+ for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
+ ++i) {
+ auto &action = impl.blockActions[i];
+ if (action.kind ==
+ ConversionPatternRewriterImpl::BlockActionKind::TypeConversion ||
+ action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase)
+ continue;
+ // Only check blocks outside of the current operation.
+ Operation *parentOp = action.block->getParentOp();
+ if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
+ continue;
+
+ // If the region of the block has a type converter, try to convert the block
+ // directly.
+ if (auto *converter =
+ impl.argConverter.getConverter(action.block->getParent())) {
+ if (failed(impl.convertBlockSignature(action.block, *converter))) {
+ LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
+ "block"));
+ return failure();
+ }
+ continue;
+ }
+
+ // Otherwise, check that this operation isn't one generated by this pattern.
+ // This is because we will attempt to legalize the parent operation, and
+ // blocks in regions created by this pattern will already be legalized later
+ // on. If we haven't built the set yet, build it now.
+ if (operationsToIgnore.empty()) {
+ auto createdOps = ArrayRef<Operation *>(impl.createdOps)
+ .drop_front(state.numCreatedOps);
+ operationsToIgnore.insert(createdOps.begin(), createdOps.end());
+ }
+
+ // If this operation should be considered for re-legalization, try it.
+ if (operationsToIgnore.insert(parentOp).second &&
+ failed(legalize(parentOp, rewriter))) {
+ LLVM_DEBUG(logFailure(
+ impl.logger, "operation '{0}'({1}) became illegal after block action",
+ parentOp->getName(), parentOp));
return failure();
}
}
-
- // Recursively legalize each of the new operations.
- for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
- i != e; ++i) {
- Operation *op = rewriterImpl.createdOps[i];
+ return success();
+}
+LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
+ ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
+ RewriterState &state, RewriterState &newState) {
+ for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
+ Operation *op = impl.createdOps[i];
if (failed(legalize(op, rewriter))) {
- LLVM_DEBUG(logFailure(rewriterImpl.logger,
+ LLVM_DEBUG(logFailure(impl.logger,
"generated operation '{0}'({1}) was illegal",
op->getName(), op));
return failure();
}
}
-
- LLVM_DEBUG(logSuccess(rewriterImpl.logger, "pattern applied successfully"));
+ return success();
+}
+LogicalResult OperationLegalizer::legalizePatternRootUpdates(
+ ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
+ RewriterState &state, RewriterState &newState) {
+ for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
+ Operation *op = impl.rootUpdates[i].getOperation();
+ if (failed(legalize(op, rewriter))) {
+ LLVM_DEBUG(logFailure(impl.logger,
+ "operation updated in-place '{0}' was illegal",
+ op->getName()));
+ return failure();
+ }
+ }
return success();
}
: opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
/// Converts the given operations to the conversion target.
- LogicalResult convertOperations(ArrayRef<Operation *> ops,
- TypeConverter *typeConverter);
+ LogicalResult convertOperations(ArrayRef<Operation *> ops);
private:
/// Converts an operation with the given rewriter.
LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
- /// Converts the type signatures of the blocks nested within 'op'.
- LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter,
- Operation *op);
-
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
};
} // end anonymous namespace
-LogicalResult
-OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
- Operation *op) {
- // Check to see if type signatures need to be converted.
- if (!rewriter.getImpl().argConverter.typeConverter)
- return success();
-
- for (auto ®ion : op->getRegions()) {
- for (auto &block : llvm::make_early_inc_range(region))
- if (failed(rewriter.getImpl().convertBlockSignature(&block)))
- return failure();
- }
- return success();
-}
-
LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
Operation *op) {
// Legalize the given operation.
if (trackedOps)
trackedOps->insert(op);
}
- } else {
+ } else if (mode == OpConversionMode::Analysis) {
// Analysis conversions don't fail if any operations fail to legalize,
// they are only interested in the operations that were successfully
// legalized.
- if (mode == OpConversionMode::Analysis)
- trackedOps->insert(op);
-
- // If legalization succeeded, convert the types any of the blocks within
- // this operation.
- if (failed(convertBlockSignatures(rewriter, op)))
- return failure();
+ trackedOps->insert(op);
}
return success();
}
-LogicalResult
-OperationConverter::convertOperations(ArrayRef<Operation *> ops,
- TypeConverter *typeConverter) {
+LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
ConversionTarget &target = opLegalizer.getTarget();
}
// Convert each operation and discard rewrites on failure.
- ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter);
+ ConversionPatternRewriter rewriter(ops.front()->getContext());
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
return rewriter.getImpl().discardRewrites(), failure();
return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
}
+/// Return true if the types of block arguments within the region are legal.
+bool TypeConverter::isLegal(Region *region) {
+ return llvm::all_of(*region, [this](Block &block) {
+ return isLegal(block.getArgumentTypes());
+ });
+}
+
/// Return true if the inputs and outputs of the given function type are
/// legal.
bool TypeConverter::isSignatureLegal(FunctionType ty) {
namespace {
struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
- : OpConversionPattern(ctx), converter(converter) {}
+ : OpConversionPattern(converter, ctx) {}
/// Hook for derived classes to implement combined matching and rewriting.
LogicalResult
// Convert the original function types.
TypeConverter::SignatureConversion result(type.getNumInputs());
- SmallVector<Type, 1> convertedResults;
- if (failed(converter.convertSignatureArgs(type.getInputs(), result)) ||
- failed(converter.convertTypes(type.getResults(), convertedResults)))
+ SmallVector<Type, 1> newResults;
+ if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
+ failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
+ failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
+ &result)))
return failure();
// Update the function signature in-place.
rewriter.updateRootInPlace(funcOp, [&] {
- funcOp.setType(FunctionType::get(result.getConvertedTypes(),
- convertedResults, funcOp.getContext()));
- rewriter.applySignatureConversion(&funcOp.getBody(), result);
+ funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults,
+ funcOp.getContext()));
});
return success();
}
-
- /// The type converter to use when rewriting the signature.
- TypeConverter &converter;
};
} // end anonymous namespace
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal. If `converter` is
-/// provided, the signatures of blocks and regions are also converted.
+/// returns failure if there ops explicitly marked as illegal.
/// If an `unconvertedOps` set is provided, all operations that are found not
/// to be legalizable to the given `target` are placed within that set. (Note
/// that if there is an op explicitly marked as illegal, the conversion
/// terminates and the `unconvertedOps` set will not necessarily be complete.)
-LogicalResult mlir::applyPartialConversion(
- ArrayRef<Operation *> ops, ConversionTarget &target,
- const OwningRewritePatternList &patterns, TypeConverter *converter,
- DenseSet<Operation *> *unconvertedOps) {
+LogicalResult
+mlir::applyPartialConversion(ArrayRef<Operation *> ops,
+ ConversionTarget &target,
+ const OwningRewritePatternList &patterns,
+ DenseSet<Operation *> *unconvertedOps) {
OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
unconvertedOps);
- return opConverter.convertOperations(ops, converter);
+ return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns,
- TypeConverter *converter,
DenseSet<Operation *> *unconvertedOps) {
return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
- converter, unconvertedOps);
+ unconvertedOps);
}
/// Apply a complete conversion on the given operations, and all nested
/// operation fails.
LogicalResult
mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
- const OwningRewritePatternList &patterns,
- TypeConverter *converter) {
+ const OwningRewritePatternList &patterns) {
OperationConverter opConverter(target, patterns, OpConversionMode::Full);
- return opConverter.convertOperations(ops, converter);
+ return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyFullConversion(Operation *op, ConversionTarget &target,
- const OwningRewritePatternList &patterns,
- TypeConverter *converter) {
- return applyFullConversion(llvm::makeArrayRef(op), target, patterns,
- converter);
+ const OwningRewritePatternList &patterns) {
+ return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
}
/// Apply an analysis conversion on the given operations, and all nested
/// were found to be legalizable to the given 'target' are placed within the
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
/// operations on success and only pre-existing operations are added to the set.
-LogicalResult mlir::applyAnalysisConversion(
- ArrayRef<Operation *> ops, ConversionTarget &target,
- const OwningRewritePatternList &patterns,
- DenseSet<Operation *> &convertedOps, TypeConverter *converter) {
+LogicalResult
+mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
+ ConversionTarget &target,
+ const OwningRewritePatternList &patterns,
+ DenseSet<Operation *> &convertedOps) {
OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
&convertedOps);
- return opConverter.convertOperations(ops, converter);
+ return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns,
- DenseSet<Operation *> &convertedOps,
- TypeConverter *converter) {
+ DenseSet<Operation *> &convertedOps) {
return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
- convertedOps, converter);
+ convertedOps);
}