#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/MapVector.h"
namespace mlir {
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs);
};
+/// This class describes a specific conversion target.
+class ConversionTarget {
+public:
+ /// This enumeration corresponds to the specific action to take when
+ /// considering an operation legal for this conversion target.
+ enum class LegalizationAction {
+ /// The target supports this operation.
+ Legal,
+
+ /// This operation has dynamic legalization constraints that must be checked
+ /// by the target.
+ Dynamic
+ };
+
+ /// The type used to store operation legality information.
+ using LegalityMapTy = llvm::MapVector<OperationName, LegalizationAction>;
+
+ ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
+ virtual ~ConversionTarget() = default;
+
+ /// Runs a custom legalization query for the given operation. This should
+ /// return true if the given operation is legal, otherwise false.
+ virtual bool isLegal(Operation *op) const {
+ llvm_unreachable(
+ "targets with custom legalization must override 'isLegal'");
+ }
+
+ /// Register a legality action for the given operation.
+ void setOpAction(OperationName op, LegalizationAction action) {
+ legalOperations[op] = action;
+ }
+ template <typename OpT> void setOpAction(LegalizationAction action) {
+ setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
+ }
+
+ /// Register the given operations as legal.
+ template <typename OpT> void addLegalOp() {
+ setOpAction<OpT>(LegalizationAction::Legal);
+ }
+ template <typename OpT, typename OpT2, typename... OpTs> void addLegalOp() {
+ addLegalOp<OpT>();
+ addLegalOp<OpT2, OpTs...>();
+ }
+
+ /// Register the operations of the given dialects as legal.
+ void addLegalDialects(ArrayRef<StringRef> dialectNames);
+ template <typename... Names>
+ void addLegalDialects(StringRef name, Names... names) {
+ SmallVector<StringRef, 2> dialectNames({name, names...});
+ addLegalDialects(dialectNames);
+ }
+ template <typename... Args> void addLegalDialects() {
+ SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
+ addLegalDialects(dialectNames);
+ }
+
+ /// Register the given operation as dynamically legal, i.e. requiring custom
+ /// handling by the target via 'isLegal'.
+ template <typename OpT> void addDynamicallyLegalOp() {
+ setOpAction<OpT>(LegalizationAction::Dynamic);
+ }
+ template <typename OpT, typename OpT2, typename... OpTs>
+ void addDynamicallyLegalOp() {
+ addDynamicallyLegalOp<OpT>();
+ addDynamicallyLegalOp<OpT2, OpTs...>();
+ }
+
+ /// Get the legality action for the given operation.
+ llvm::Optional<LegalizationAction> getOpAction(OperationName op) const {
+ auto it = legalOperations.find(op);
+ if (it != legalOperations.end())
+ return it->second;
+ return llvm::None;
+ }
+
+ /// Returns a range of operations that this target has defined to be legal in
+ /// some capacity.
+ llvm::iterator_range<LegalityMapTy::const_iterator> getLegalOps() const {
+ return llvm::make_range(legalOperations.begin(), legalOperations.end());
+ }
+
+private:
+ /// A deterministic mapping of operation name to the specific legality action
+ /// to take.
+ LegalityMapTy legalOperations;
+
+ /// The current context this target applies to.
+ MLIRContext &ctx;
+};
+
/// Convert the given module with the provided conversion patterns and type
/// conversion object. If conversion fails for specific functions, those
/// functions remains unmodified.
-LLVM_NODISCARD
-LogicalResult applyConversionPatterns(Module &module, TypeConverter &converter,
- OwningRewritePatternList &&patterns);
+LLVM_NODISCARD LogicalResult applyConversionPatterns(
+ Module &module, ConversionTarget &target, TypeConverter &converter,
+ OwningRewritePatternList &&patterns);
/// Convert the given function with the provided conversion patterns. This will
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LLVM_NODISCARD
-LogicalResult applyConversionPatterns(Function &fn,
+LogicalResult applyConversionPatterns(Function &fn, ConversionTarget &target,
OwningRewritePatternList &&patterns);
} // end namespace mlir
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
-using namespace mlir::impl;
+
+#define DEBUG_TYPE "dialect-conversion"
//===----------------------------------------------------------------------===//
// ArgConverter
assert(newValues.size() == op->getNumResults());
// Create mappings for any type changes.
for (unsigned i = 0, e = newValues.size(); i < e; ++i)
- if (op->getResult(i)->getType() != newValues[i]->getType())
+ if (newValues[i] &&
+ op->getResult(i)->getType() != newValues[i]->getType())
mapping.map(op->getResult(i), newValues[i]);
// Record the requested operation replacement.
}
//===----------------------------------------------------------------------===//
+// ConversionTarget
+//===----------------------------------------------------------------------===//
+
+/// Register the operations of the given dialects as legal.
+void ConversionTarget::addLegalDialects(ArrayRef<StringRef> dialectNames) {
+ SmallPtrSet<Dialect *, 2> dialects;
+ for (auto dialectName : dialectNames)
+ if (auto *dialect = ctx.getRegisteredDialect(dialectName))
+ dialects.insert(dialect);
+
+ // Set all dialect operations as legal.
+ for (auto op : ctx.getRegisteredOperations())
+ if (dialects.count(&op->dialect))
+ setOpAction(op, LegalizationAction::Legal);
+}
+
+//===----------------------------------------------------------------------===//
+// OperationLegalizer
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class represents the information necessary for legalizing an operation
+/// kind.
+struct OpLegality {
+ /// This is the legalization action specified by the target, if it provided
+ /// one.
+ llvm::Optional<ConversionTarget::LegalizationAction> targetAction;
+
+ /// The set of patterns to apply to an instance of this operation to legalize
+ /// it.
+ SmallVector<RewritePattern *, 1> patterns;
+};
+
+/// This class defines a recursive operation legalizer.
+class OperationLegalizer {
+public:
+ OperationLegalizer(ConversionTarget &targetInfo,
+ OwningRewritePatternList &patterns)
+ : target(targetInfo) {
+ buildLegalizationGraph(patterns);
+ }
+
+ /// Attempt to legalize the given operation. Returns success if the operation
+ /// was legalized, failure otherwise.
+ LogicalResult legalize(Operation *op, DialectConversionRewriter &rewriter);
+
+private:
+ /// Attempt to legalize the given operation by applying the provided pattern.
+ /// Returns success if the operation was legalized, failure otherwise.
+ LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
+ DialectConversionRewriter &rewriter);
+
+ /// Build an optimistic legalization graph given the provided patterns. This
+ /// function populates 'legalOps' with the operations that are either legal,
+ /// or transitively legal for the current target given the provided patterns.
+ void buildLegalizationGraph(OwningRewritePatternList &patterns);
+
+ /// The current set of patterns that have been applied.
+ llvm::SmallPtrSet<RewritePattern *, 8> appliedPatterns;
+
+ /// The set of legality information for operations transitively supported by
+ /// the target.
+ DenseMap<OperationName, OpLegality> legalOps;
+
+ /// The legalization information provided by the target.
+ ConversionTarget ⌖
+};
+} // namespace
+
+LogicalResult
+OperationLegalizer::legalize(Operation *op,
+ DialectConversionRewriter &rewriter) {
+ LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
+ << "\n");
+
+ auto it = legalOps.find(op->getName());
+ if (it == legalOps.end()) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n");
+ return failure();
+ }
+
+ // Check if this was marked legal by the target.
+ auto &opInfo = it->second;
+ if (auto action = opInfo.targetAction) {
+ // Check if this operation is always legal.
+ if (*action == ConversionTarget::LegalizationAction::Legal)
+ return success();
+
+ // Otherwise, handle custom legalization.
+ LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n");
+ if (target.isLegal(op))
+ return success();
+
+ // Fallthough to see if a pattern can convert this into a legal operation.
+ }
+
+ // Otherwise, we need to apply a legalization pattern to this operation.
+ // TODO(riverriddle) This currently has no cost model and doesn't prioritize
+ // specific patterns in any way.
+ for (auto *pattern : opInfo.patterns)
+ if (succeeded(legalizePattern(op, pattern, rewriter)))
+ return success();
+
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n");
+ return failure();
+}
+
+LogicalResult
+OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
+ DialectConversionRewriter &rewriter) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> (";
+ interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
+ llvm::dbgs() << ")'.\n";
+ });
+
+ // Ensure that we don't cycle by not allowing the same pattern to be
+ // applied twice in the same recursion stack.
+ // TODO(riverriddle) We could eventually converge, but that requires more
+ // complicated analysis.
+ if (!appliedPatterns.insert(pattern).second) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n");
+ return failure();
+ }
+
+ auto curOpCount = rewriter.createdOps.size();
+ auto curReplCount = rewriter.replacements.size();
+ auto cleanupFailure = [&] {
+ // Pop all of the newly created operations and replacements.
+ while (rewriter.createdOps.size() != curOpCount)
+ rewriter.createdOps.pop_back_val()->erase();
+ rewriter.replacements.resize(curReplCount);
+ appliedPatterns.erase(pattern);
+ return failure();
+ };
+
+ // Try to rewrite with the given pattern.
+ rewriter.setInsertionPoint(op);
+ if (!pattern->matchAndRewrite(op, rewriter)) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
+ return cleanupFailure();
+ }
+
+ // Recursively legalize each of the new operations.
+ for (unsigned i = curOpCount, e = rewriter.createdOps.size(); i != e; ++i) {
+ if (succeeded(legalize(rewriter.createdOps[i], rewriter)))
+ continue;
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
+ return cleanupFailure();
+ }
+
+ appliedPatterns.erase(pattern);
+ return success();
+}
+
+void OperationLegalizer::buildLegalizationGraph(
+ OwningRewritePatternList &patterns) {
+ // A mapping between an operation and a set of operations that can be used to
+ // generate it.
+ DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
+ // A mapping between an operation and any currently invalid patterns it has.
+ DenseMap<OperationName, SmallPtrSet<RewritePattern *, 2>> invalidPatterns;
+ // A worklist of patterns to consider for legality.
+ llvm::SetVector<RewritePattern *> patternWorklist;
+
+ // Collect the initial set of valid target ops.
+ for (auto &opInfoPair : target.getLegalOps())
+ legalOps[opInfoPair.first].targetAction = opInfoPair.second;
+
+ // Build the mapping from operations to the parent ops that may generate them.
+ for (auto &pattern : patterns) {
+ auto root = pattern->getRootKind();
+
+ // Skip operations that are known to always be legal.
+ auto it = legalOps.find(root);
+ if (it != legalOps.end() &&
+ it->second.targetAction == ConversionTarget::LegalizationAction::Legal)
+ continue;
+
+ // Add this pattern to the invalid set for the root op and record this root
+ // as a parent for any generated operations.
+ invalidPatterns[root].insert(pattern.get());
+ for (auto op : pattern->getGeneratedOps())
+ parentOps[op].insert(root);
+
+ // If this pattern doesn't generate any operations, optimistically add it to
+ // the worklist.
+ if (pattern->getGeneratedOps().empty())
+ patternWorklist.insert(pattern.get());
+ }
+
+ // Build the initial worklist with the patterns that generate operations that
+ // are known to be legal.
+ for (auto &opInfoPair : target.getLegalOps())
+ for (auto &parentOp : parentOps[opInfoPair.first])
+ patternWorklist.set_union(invalidPatterns[parentOp]);
+
+ while (!patternWorklist.empty()) {
+ auto *pattern = patternWorklist.pop_back_val();
+
+ // Check to see if any of the generated operations are invalid.
+ if (llvm::any_of(pattern->getGeneratedOps(),
+ [&](OperationName op) { return !legalOps.count(op); }))
+ continue;
+
+ // Otherwise, if all of the generated operation are valid, this op is now
+ // legal so add all of the child patterns to the worklist.
+ legalOps[pattern->getRootKind()].patterns.push_back(pattern);
+ invalidPatterns[pattern->getRootKind()].erase(pattern);
+
+ // Add any invalid patterns of the parent operations to see if they have now
+ // become legal.
+ for (auto op : parentOps[pattern->getRootKind()])
+ patternWorklist.set_union(invalidPatterns[op]);
+ }
+}
+
+//===----------------------------------------------------------------------===//
// FunctionConverter
//===----------------------------------------------------------------------===//
namespace {
// This class converts a single function using the given pattern matcher. If a
// TypeConverter object is provided, then the types of block arguments will be
// converted using the appropriate 'convertType' calls.
-class FunctionConverter {
-public:
- explicit FunctionConverter(MLIRContext *ctx, RewritePatternMatcher &matcher,
+struct FunctionConverter {
+ explicit FunctionConverter(MLIRContext *ctx, ConversionTarget &target,
+ OwningRewritePatternList &patterns,
TypeConverter *conversion = nullptr)
- : typeConverter(conversion), matcher(matcher) {}
+ : typeConverter(conversion), opLegalizer(target, patterns) {}
/// Converts the given function to the dialect using hooks defined in
/// `typeConverter`. Returns failure on error, success otherwise.
/// Pointer to a specific dialect conversion info.
TypeConverter *typeConverter;
- /// The matcher to use when converting operations.
- RewritePatternMatcher &matcher;
+ /// The legalizer to use when converting operations.
+ OperationLegalizer opLegalizer;
};
} // end anonymous namespace
// Iterate over ops and convert them.
for (Operation &op : llvm::make_early_inc_range(*block)) {
- rewriter.setInsertionPoint(&op);
- if (matcher.matchAndRewrite(&op, rewriter))
- continue;
-
// Traverse any held regions.
for (auto ®ion : op.getRegions())
if (!region.empty() &&
failed(convertRegion(rewriter, region, op.getLoc())))
return failure();
+
+ // Legalize the current operation.
+ (void)opLegalizer.legalize(&op, rewriter);
}
// Recurse to children that haven't been visited.
/// conversion object. If conversion fails for specific functions, those
/// functions remains unmodified.
LogicalResult
-mlir::applyConversionPatterns(Module &module, TypeConverter &converter,
+mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
+ TypeConverter &converter,
OwningRewritePatternList &&patterns) {
- // Grab the conversion patterns from the converter and create the pattern
- // matcher.
- MLIRContext *context = module.getContext();
- RewritePatternMatcher matcher(std::move(patterns));
+ // Build the function converter.
+ FunctionConverter funcConverter(module.getContext(), target, patterns,
+ &converter);
// Try to convert each of the functions within the module. Defer updating the
// signatures of the functions until after all of the bodies have been
return func.emitError("could not convert function type");
// Convert the body of this function.
- FunctionConverter funcConverter(context, matcher, &converter);
if (failed(funcConverter.convertFunction(&func)))
return failure();
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LogicalResult
-mlir::applyConversionPatterns(Function &fn,
+mlir::applyConversionPatterns(Function &fn, ConversionTarget &target,
OwningRewritePatternList &&patterns) {
// Convert the body of this function.
- RewritePatternMatcher matcher(std::move(patterns));
- FunctionConverter converter(fn.getContext(), matcher);
+ FunctionConverter converter(fn.getContext(), target, patterns);
return converter.convertFunction(&fn);
}