/// Base class for the dialect conversion patterns that require type changes.
/// Specific conversions must derive this class and implement least one
/// `rewrite` method.
-/// NOTE: These conversion patterns can only be used with the DialectConversion
-/// class.
+/// NOTE: These conversion patterns can only be used with the 'apply*' methods
+/// below.
class DialectConversionPattern : public RewritePattern {
public:
/// Construct an DialectConversionPattern. `rootName` must correspond to the
// match against the list of conversions. On the first match, call
// `rewrite` for the operations, and advance to the next iteration. If no
// match is found, replicate the operation as is.
-/// 3. Update all attributes of function type to point to the new functions.
-/// 4. Replace old functions with new functions in the module.
-/// If any error happened during the conversion, the pass fails as soon as
-/// possible.
-///
-/// If conversion fails for a specific function, that functions remains
-/// unmodified. Otherwise, successfully converted functions will remain
-/// converted.
class DialectConversion {
public:
virtual ~DialectConversion() = default;
- /// Run the converter on the provided module.
- LLVM_NODISCARD
- LogicalResult convert(Module *m);
-
/// Derived classes must implement this hook to produce a set of conversion
/// patterns to apply. They may use `mlirContext` to obtain registered
/// dialects or operations. This will be called in the beginning of the
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs);
};
+/// Convert the given module with the provided dialect conversion object.
+/// If conversion fails for a specific function, those functions remains
+/// unmodified.
+LLVM_NODISCARD
+LogicalResult applyConverter(Module &module, DialectConversion &converter);
+
+/// Convert the given function with the provided conversion patterns. This will
+/// convert as many of the operations within 'fn' as possible given the set of
+/// patterns.
+LLVM_NODISCARD
+LogicalResult applyConversionPatterns(Function &fn,
+ OwningRewritePatternList &&patterns);
+
} // end namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
// FunctionConverter
//===----------------------------------------------------------------------===//
namespace {
-// This class converts a single function using a given DialectConversion
-// structure.
+// This class converts a single function using the given pattern matcher. If a
+// DialectConversion object is also provided, then the types of block arguments
+// will be converted using the appropriate 'convertType' calls.
class FunctionConverter {
public:
- // Constructs a FunctionConverter.
- explicit FunctionConverter(MLIRContext *ctx, DialectConversion *conversion,
- RewritePatternMatcher &matcher)
+ explicit FunctionConverter(MLIRContext *ctx, RewritePatternMatcher &matcher,
+ DialectConversion *conversion = nullptr)
: dialectConversion(conversion), matcher(matcher) {}
/// Converts the given function to the dialect using hooks defined in
Region ®ion, RegionParent *parent) {
assert(!region.empty() && "expected non-empty region");
- // Create the arguments of each of the blocks in the region.
- for (Block &block : region)
- for (auto *arg : block.getArguments())
- if (failed(convertArgument(rewriter, arg, parent->getLoc())))
- return failure();
+ // Create the arguments of each of the blocks in the region. If a type
+ // converter was not provided, then we don't need to change any of the block
+ // types.
+ if (dialectConversion) {
+ for (Block &block : region)
+ for (auto *arg : block.getArguments())
+ if (failed(convertArgument(rewriter, arg, parent->getLoc())))
+ return failure();
+ }
// Start a DFS-order traversal of the CFG to make sure defs are converted
// before uses in dominated blocks.
// Rewrite the function body.
DialectConversionRewriter rewriter(f);
if (failed(convertRegion(rewriter, f->getBody(), f))) {
- // Reset any of the converted arguments.
- rewriter.argConverter.discardRewrites();
+ // Reset any of the generated rewrites.
+ rewriter.discardRewrites();
return failure();
}
// DialectConversion
//===----------------------------------------------------------------------===//
-namespace {
-/// This class represents a function to be converted. It allows for converting
-/// the body of functions and the signature in two phases.
-struct ConvertedFunction {
- ConvertedFunction(Function *fn, FunctionType newType,
- ArrayRef<NamedAttributeList> newFunctionArgAttrs)
- : fn(fn), newType(newType),
- newFunctionArgAttrs(newFunctionArgAttrs.begin(),
- newFunctionArgAttrs.end()) {}
-
- /// The function to convert.
- Function *fn;
- /// The new type and argument attributes for the function.
- FunctionType newType;
- SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
-};
-} // end anonymous namespace
-
// Create a function type with arguments and results converted, and argument
// attributes passed through.
FunctionType DialectConversion::convertFunctionSignatureType(
return FunctionType::get(arguments, results, type.getContext());
}
-// Converts the module as follows.
-// 1. Call `convertFunction` on each function of the module and collect the
-// mapping between old and new functions.
-// 2. Remap all function attributes in the new functions to point to the new
-// functions instead of the old ones.
-// 3. Replace old functions with the new in the module.
-LogicalResult DialectConversion::convert(Module *module) {
- if (!module)
- return failure();
+//===----------------------------------------------------------------------===//
+// applyConversionPatterns
+//===----------------------------------------------------------------------===//
+namespace {
+/// This class represents a function to be converted. It allows for converting
+/// the body of functions and the signature in two phases.
+struct ConvertedFunction {
+ ConvertedFunction(Function *fn, FunctionType newType,
+ ArrayRef<NamedAttributeList> newFunctionArgAttrs)
+ : fn(fn), newType(newType),
+ newFunctionArgAttrs(newFunctionArgAttrs.begin(),
+ newFunctionArgAttrs.end()) {}
+
+ /// The function to convert.
+ Function *fn;
+ /// The new type and argument attributes for the function.
+ FunctionType newType;
+ SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
+};
+} // end anonymous namespace
+
+/// Convert the given module with the provided dialect conversion object.
+/// If conversion fails for a specific function, those functions remains
+/// unmodified.
+LogicalResult mlir::applyConverter(Module &module,
+ DialectConversion &converter) {
// Grab the conversion patterns from the converter and create the pattern
// matcher.
- MLIRContext *context = module->getContext();
+ MLIRContext *context = module.getContext();
OwningRewritePatternList patterns;
- initConverters(patterns, context);
+ converter.initConverters(patterns, context);
RewritePatternMatcher matcher(std::move(patterns));
// Try to convert each of the functions within the module. Defer updating the
// public signatures of the functions within the module before they are
// updated.
std::vector<ConvertedFunction> toConvert;
- toConvert.reserve(module->getFunctions().size());
- for (auto &func : *module) {
+ toConvert.reserve(module.getFunctions().size());
+ for (auto &func : module) {
// Convert the function type using the dialect converter.
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
- FunctionType newType = convertFunctionSignatureType(
+ FunctionType newType = converter.convertFunctionSignatureType(
func.getType(), func.getAllArgAttrs(), newFunctionArgAttrs);
if (!newType || !newType.isa<FunctionType>())
return func.emitError("could not convert function type");
// Convert the body of this function.
- FunctionConverter converter(context, this, matcher);
- if (failed(converter.convertFunction(&func)))
+ FunctionConverter funcConverter(context, matcher, &converter);
+ if (failed(funcConverter.convertFunction(&func)))
return failure();
// Add function signature to be updated.
return success();
}
+
+/// Convert the given function with the provided conversion patterns. This will
+/// convert as many of the operations within 'fn' as possible given the set of
+/// patterns.
+LogicalResult
+mlir::applyConversionPatterns(Function &fn,
+ OwningRewritePatternList &&patterns) {
+ // Convert the body of this function.
+ RewritePatternMatcher matcher(std::move(patterns));
+ FunctionConverter converter(fn.getContext(), matcher);
+ return converter.convertFunction(&fn);
+}