/// each operation.
void walk(const std::function<void(Operation *)> &callback);
+ /// Specialization of walk to only visit operations of 'OpTy'.
+ template <typename OpTy> void walk(std::function<void(OpTy)> callback) {
+ walk([&](Operation *opInst) {
+ if (auto op = dyn_cast<OpTy>(opInst))
+ callback(op);
+ });
+ }
+
/// Walk the operations in the specified [begin, end) range of this block in
/// postorder, calling the callback for each operation.
void walk(Block::iterator begin, Block::iterator end,
/// Parse a `:` token.
virtual ParseResult parseColon() = 0;
+ /// Parse a `:` token if present.
+ virtual ParseResult parseOptionalColon() = 0;
+
/// Parse a '(' token.
virtual ParseResult parseLParen() = 0;
+ /// Parse a '(' token if present.
+ virtual ParseResult parseOptionalLParen() = 0;
+
/// Parse a ')' token.
virtual ParseResult parseRParen() = 0;
- /// Parses a ')' if present.
+ /// Parse a ')' token if present.
virtual ParseResult parseOptionalRParen() = 0;
- /// This parses an equal(=) token!
+ /// Parse a '=' token.
virtual ParseResult parseEqual() = 0;
/// Parse a type.
Module &module, ConversionTarget &target, TypeConverter &converter,
OwningRewritePatternList &&patterns);
+/// Convert the given functions with the provided conversion patterns. This will
+/// convert as many of the operations within each function as possible given the
+/// set of patterns. If conversion fails for specific functions, those functions
+/// remains unmodified.
+LLVM_NODISCARD
+LogicalResult applyConversionPatterns(ArrayRef<Function *> fns,
+ ConversionTarget &target,
+ TypeConverter &converter,
+ OwningRewritePatternList &&patterns);
+
/// Convert the given function with the provided conversion patterns. This will
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
*loc = parser.getToken().getLoc();
return success();
}
+
ParseResult parseComma() override {
return parser.parseToken(Token::comma, "expected ','");
}
+
+ ParseResult parseOptionalComma() override {
+ return success(parser.consumeIf(Token::comma));
+ }
+
ParseResult parseColon() override {
return parser.parseToken(Token::colon, "expected ':'");
}
+
+ ParseResult parseOptionalColon() override {
+ return success(parser.consumeIf(Token::colon));
+ }
+
ParseResult parseEqual() override {
return parser.parseToken(Token::equal, "expected '='");
}
return success();
}
- ParseResult parseOptionalComma() override {
- return success(parser.consumeIf(Token::comma));
- }
-
/// Parse an optional keyword.
ParseResult parseOptionalKeyword(const char *keyword) override {
// Check that the current token is a bare identifier or keyword.
return parser.parseToken(Token::l_paren, "expected '('");
}
+ /// Parses a '(' if present.
+ ParseResult parseOptionalLParen() override {
+ return success(parser.consumeIf(Token::l_paren));
+ }
+
ParseResult parseRParen() override {
return parser.parseToken(Token::r_paren, "expected ')'");
}
mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns) {
+ std::vector<Function *> allFunctions;
+ allFunctions.reserve(module.getFunctions().size());
+ for (auto &func : module)
+ allFunctions.push_back(&func);
+ return applyConversionPatterns(allFunctions, target, converter,
+ std::move(patterns));
+}
+
+/// Convert the given functions with the provided conversion patterns. This will
+/// convert as many of the operations within each function as possible given the
+/// set of patterns. If conversion fails for specific functions, those functions
+// remains unmodified.
+LogicalResult mlir::applyConversionPatterns(
+ ArrayRef<Function *> fns, ConversionTarget &target,
+ TypeConverter &converter, OwningRewritePatternList &&patterns) {
+ if (fns.empty())
+ return success();
+
// Build the function converter.
- FunctionConverter funcConverter(module.getContext(), target, patterns,
+ FunctionConverter funcConverter(fns.front()->getContext(), target, patterns,
&converter);
// Try to convert each of the functions within the module. Defer updating the
// public signatures of the functions within the module before they are
// updated.
std::vector<ConvertedFunction> toConvert;
- toConvert.reserve(module.getFunctions().size());
- for (auto &func : module) {
+ toConvert.reserve(fns.size());
+ for (auto *func : fns) {
// Convert the function type using the dialect converter.
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
FunctionType newType = converter.convertFunctionSignatureType(
- func.getType(), func.getAllArgAttrs(), newFunctionArgAttrs);
+ func->getType(), func->getAllArgAttrs(), newFunctionArgAttrs);
if (!newType || !newType.isa<FunctionType>())
- return func.emitError("could not convert function type");
+ return func->emitError("could not convert function type");
// Convert the body of this function.
- if (failed(funcConverter.convertFunction(&func)))
+ if (failed(funcConverter.convertFunction(func)))
return failure();
// Add function signature to be updated.
- toConvert.emplace_back(&func, newType.cast<FunctionType>(),
+ toConvert.emplace_back(func, newType.cast<FunctionType>(),
newFunctionArgAttrs);
}