#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
// Printing.
//===----------------------------------------------------------------------===//
-static void printTypeImpl(llvm::raw_ostream &os, Type type,
- llvm::SetVector<StringRef> &stack);
+/// If the given type is compatible with the LLVM dialect, prints it using
+/// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
+/// prints it as usual.
+static void dispatchPrint(DialectAsmPrinter &printer, Type type) {
+ if (isCompatibleType(type))
+ return mlir::LLVM::detail::printType(type, printer);
+ printer.printType(type);
+}
/// Returns the keyword to use for the given type.
static StringRef getTypeKeyword(Type type) {
});
}
-/// Prints the body of a structure type. Uses `stack` to avoid printing
-/// recursive structs indefinitely.
-static void printStructTypeBody(llvm::raw_ostream &os, LLVMStructType type,
- llvm::SetVector<StringRef> &stack) {
- if (type.isIdentified() && type.isOpaque()) {
- os << "opaque";
- return;
- }
-
- if (type.isPacked())
- os << "packed ";
-
- // Put the current type on stack to avoid infinite recursion.
- os << '(';
- if (type.isIdentified())
- stack.insert(type.getName());
- llvm::interleaveComma(type.getBody(), os, [&](Type subtype) {
- printTypeImpl(os, subtype, stack);
+/// Prints a structure type. Keeps track of known struct names to handle self-
+/// or mutually-referring structs without falling into infinite recursion.
+static void printStructType(DialectAsmPrinter &printer, LLVMStructType type) {
+ // This keeps track of the names of identified structure types that are
+ // currently being printed. Since such types can refer themselves, this
+ // tracking is necessary to stop the recursion: the current function may be
+ // called recursively from DialectAsmPrinter::printType after the appropriate
+ // dispatch. We maintain the invariant of this storage being modified
+ // exclusively in this function, and at most one name being added per call.
+ // TODO: consider having such functionality inside DialectAsmPrinter.
+ thread_local llvm::SetVector<StringRef> knownStructNames;
+ unsigned stackSize = knownStructNames.size();
+ (void)stackSize;
+ auto guard = llvm::make_scope_exit([&]() {
+ assert(knownStructNames.size() == stackSize &&
+ "malformed identified stack when printing recursive structs");
});
- if (type.isIdentified())
- stack.pop_back();
- os << ')';
-}
-/// Prints a structure type. Uses `stack` to keep track of the identifiers of
-/// the structs being printed. Checks if the identifier of a struct is contained
-/// in `stack`, i.e. whether a self-reference to a recursive stack is being
-/// printed, and only prints the name to avoid infinite recursion.
-static void printStructType(llvm::raw_ostream &os, LLVMStructType type,
- llvm::SetVector<StringRef> &stack) {
- os << "<";
+ printer << "<";
if (type.isIdentified()) {
- os << '"' << type.getName() << '"';
+ printer << '"' << type.getName() << '"';
// If we are printing a reference to one of the enclosing structs, just
// print the name and stop to avoid infinitely long output.
- if (stack.count(type.getName())) {
- os << '>';
+ if (knownStructNames.count(type.getName())) {
+ printer << '>';
return;
}
- os << ", ";
+ printer << ", ";
+ }
+
+ if (type.isIdentified() && type.isOpaque()) {
+ printer << "opaque>";
+ return;
}
- printStructTypeBody(os, type, stack);
- os << '>';
+ if (type.isPacked())
+ printer << "packed ";
+
+ // Put the current type on stack to avoid infinite recursion.
+ printer << '(';
+ if (type.isIdentified())
+ knownStructNames.insert(type.getName());
+ llvm::interleaveComma(type.getBody(), printer.getStream(),
+ [&](Type subtype) { dispatchPrint(printer, subtype); });
+ if (type.isIdentified())
+ knownStructNames.pop_back();
+ printer << ')';
+ printer << '>';
}
/// Prints a type containing a fixed number of elements.
template <typename TypeTy>
-static void printArrayOrVectorType(llvm::raw_ostream &os, TypeTy type,
- llvm::SetVector<StringRef> &stack) {
- os << '<' << type.getNumElements() << " x ";
- printTypeImpl(os, type.getElementType(), stack);
- os << '>';
+static void printArrayOrVectorType(DialectAsmPrinter &printer, TypeTy type) {
+ printer << '<' << type.getNumElements() << " x ";
+ dispatchPrint(printer, type.getElementType());
+ printer << '>';
}
/// Prints a function type.
-static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType,
- llvm::SetVector<StringRef> &stack) {
- os << '<';
- printTypeImpl(os, funcType.getReturnType(), stack);
- os << " (";
- llvm::interleaveComma(funcType.getParams(), os, [&os, &stack](Type subtype) {
- printTypeImpl(os, subtype, stack);
- });
+static void printFunctionType(DialectAsmPrinter &printer,
+ LLVMFunctionType funcType) {
+ printer << '<';
+ dispatchPrint(printer, funcType.getReturnType());
+ printer << " (";
+ llvm::interleaveComma(
+ funcType.getParams(), printer.getStream(),
+ [&printer](Type subtype) { dispatchPrint(printer, subtype); });
if (funcType.isVarArg()) {
if (funcType.getNumParams() != 0)
- os << ", ";
- os << "...";
+ printer << ", ";
+ printer << "...";
}
- os << ")>";
+ printer << ")>";
}
/// Prints the given LLVM dialect type recursively. This leverages closedness of
/// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
/// ptr<struct<"b", (ptr<struct<"c">>)>>)>
/// note that "b" is printed twice.
-static void printTypeImpl(llvm::raw_ostream &os, Type type,
- llvm::SetVector<StringRef> &stack) {
+void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
if (!type) {
- os << "<<NULL-TYPE>>";
+ printer << "<<NULL-TYPE>>";
return;
}
- os << getTypeKeyword(type);
+ printer << getTypeKeyword(type);
if (auto intType = type.dyn_cast<LLVMIntegerType>()) {
- os << intType.getBitWidth();
+ printer << intType.getBitWidth();
return;
}
if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
- os << '<';
- printTypeImpl(os, ptrType.getElementType(), stack);
+ printer << '<';
+ dispatchPrint(printer, ptrType.getElementType());
if (ptrType.getAddressSpace() != 0)
- os << ", " << ptrType.getAddressSpace();
- os << '>';
+ printer << ", " << ptrType.getAddressSpace();
+ printer << '>';
return;
}
if (auto arrayType = type.dyn_cast<LLVMArrayType>())
- return printArrayOrVectorType(os, arrayType, stack);
+ return printArrayOrVectorType(printer, arrayType);
if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
- return printArrayOrVectorType(os, vectorType, stack);
+ return printArrayOrVectorType(printer, vectorType);
if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
- os << "<? x " << vectorType.getMinNumElements() << " x ";
- printTypeImpl(os, vectorType.getElementType(), stack);
- os << '>';
+ printer << "<? x " << vectorType.getMinNumElements() << " x ";
+ dispatchPrint(printer, vectorType.getElementType());
+ printer << '>';
return;
}
if (auto structType = type.dyn_cast<LLVMStructType>())
- return printStructType(os, structType, stack);
+ return printStructType(printer, structType);
if (auto funcType = type.dyn_cast<LLVMFunctionType>())
- return printFunctionType(os, funcType, stack);
-}
-
-void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
- llvm::SetVector<StringRef> stack;
- return printTypeImpl(printer.getStream(), type, stack);
+ return printFunctionType(printer, funcType);
}
//===----------------------------------------------------------------------===//
// Parsing.
//===----------------------------------------------------------------------===//
-static Type parseTypeImpl(DialectAsmParser &parser,
- llvm::SetVector<StringRef> &stack);
-
-/// Helper to be chained with other parsing functions.
-static ParseResult parseTypeImpl(DialectAsmParser &parser,
- llvm::SetVector<StringRef> &stack,
- Type &result) {
- result = parseTypeImpl(parser, stack);
- return success(result != nullptr);
-}
+static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
/// Parses an LLVM dialect function type.
/// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
-static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
- llvm::SetVector<StringRef> &stack) {
+static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
Type returnType;
- if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) ||
+ if (parser.parseLess() || dispatchParse(parser, returnType) ||
parser.parseLParen())
return LLVMFunctionType();
/*isVarArg=*/true);
}
- argTypes.push_back(parseTypeImpl(parser, stack));
- if (!argTypes.back())
+ Type arg;
+ if (dispatchParse(parser, arg))
return LLVMFunctionType();
+ argTypes.push_back(arg);
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
/// Parses an LLVM dialect pointer type.
/// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
-static LLVMPointerType parsePointerType(DialectAsmParser &parser,
- llvm::SetVector<StringRef> &stack) {
+static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
Type elementType;
- if (parser.parseLess() || parseTypeImpl(parser, stack, elementType))
+ if (parser.parseLess() || dispatchParse(parser, elementType))
return LLVMPointerType();
unsigned addressSpace = 0;
/// Parses an LLVM dialect vector type.
/// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
/// Supports both fixed and scalable vectors.
-static LLVMVectorType parseVectorType(DialectAsmParser &parser,
- llvm::SetVector<StringRef> &stack) {
+static LLVMVectorType parseVectorType(DialectAsmParser &parser) {
SmallVector<int64_t, 2> dims;
llvm::SMLoc dimPos;
Type elementType;
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
- parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
+ dispatchParse(parser, elementType) || parser.parseGreater())
return LLVMVectorType();
// We parsed a generic dimension list, but vectors only support two forms:
/// Parses an LLVM dialect array type.
/// llvm-type ::= `array<` integer `x` llvm-type `>`
-static LLVMArrayType parseArrayType(DialectAsmParser &parser,
- llvm::SetVector<StringRef> &stack) {
+static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
SmallVector<int64_t, 1> dims;
llvm::SMLoc sizePos;
Type elementType;
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
- parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
+ dispatchParse(parser, elementType) || parser.parseGreater())
return LLVMArrayType();
if (dims.size() != 1) {
}
/// Attempts to set the body of an identified structure type. Reports a parsing
-/// error at `subtypesLoc` in case of failure, uses `stack` to make sure the
-/// types printed in the error message look like they did when parsed.
+/// error at `subtypesLoc` in case of failure.
static LLVMStructType trySetStructBody(LLVMStructType type,
ArrayRef<Type> subtypes, bool isPacked,
DialectAsmParser &parser,
- llvm::SMLoc subtypesLoc,
- llvm::SetVector<StringRef> &stack) {
+ llvm::SMLoc subtypesLoc) {
for (Type t : subtypes) {
if (!LLVMStructType::isValidElementType(t)) {
parser.emitError(subtypesLoc)
if (succeeded(type.setBody(subtypes, isPacked)))
return type;
- std::string currentBody;
- llvm::raw_string_ostream currentBodyStream(currentBody);
- printStructTypeBody(currentBodyStream, type, stack);
- auto diag = parser.emitError(subtypesLoc)
- << "identified type already used with a different body";
- diag.attachNote() << "existing body: " << currentBodyStream.str();
+ parser.emitError(subtypesLoc)
+ << "identified type already used with a different body";
return LLVMStructType();
}
/// `(` llvm-type-list `)` `>`
/// | `struct<` string-literal `>`
/// | `struct<` string-literal `, opaque>`
-static LLVMStructType parseStructType(DialectAsmParser &parser,
- llvm::SetVector<StringRef> &stack) {
+static LLVMStructType parseStructType(DialectAsmParser &parser) {
+ // This keeps track of the names of identified structure types that are
+ // currently being parsed. Since such types can refer themselves, this
+ // tracking is necessary to stop the recursion: the current function may be
+ // called recursively from DialectAsmParser::parseType after the appropriate
+ // dispatch. We maintain the invariant of this storage being modified
+ // exclusively in this function, and at most one name being added per call.
+ // TODO: consider having such functionality inside DialectAsmParser.
+ thread_local llvm::SetVector<StringRef> knownStructNames;
+ unsigned stackSize = knownStructNames.size();
+ (void)stackSize;
+ auto guard = llvm::make_scope_exit([&]() {
+ assert(knownStructNames.size() == stackSize &&
+ "malformed identified stack when parsing recursive structs");
+ });
+
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
if (failed(parser.parseLess()))
StringRef name;
bool isIdentified = succeeded(parser.parseOptionalString(&name));
if (isIdentified) {
- if (stack.count(name)) {
+ if (knownStructNames.count(name)) {
if (failed(parser.parseGreater()))
return LLVMStructType();
return LLVMStructType::getIdentifiedChecked(loc, name);
if (!isIdentified)
return LLVMStructType::getLiteralChecked(loc, {}, isPacked);
auto type = LLVMStructType::getIdentifiedChecked(loc, name);
- return trySetStructBody(type, {}, isPacked, parser, kwLoc, stack);
+ return trySetStructBody(type, {}, isPacked, parser, kwLoc);
}
// Parse subtypes. For identified structs, put the identifier of the struct on
llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
do {
if (isIdentified)
- stack.insert(name);
- Type type = parseTypeImpl(parser, stack);
- if (!type)
+ knownStructNames.insert(name);
+ Type type;
+ if (dispatchParse(parser, type))
return LLVMStructType();
subtypes.push_back(type);
if (isIdentified)
- stack.pop_back();
+ knownStructNames.pop_back();
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseRParen() || parser.parseGreater())
if (!isIdentified)
return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked);
auto type = LLVMStructType::getIdentifiedChecked(loc, name);
- return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc, stack);
+ return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
}
-/// Parses one of the LLVM dialect types.
-static Type parseTypeImpl(DialectAsmParser &parser,
- llvm::SetVector<StringRef> &stack) {
- // Special case for integers (i[1-9][0-9]*) that are literals rather than
- // keywords for the parser, so they are not caught by the main dispatch below.
- // Try parsing it a built-in integer type instead.
- Type maybeIntegerType;
- MLIRContext *ctx = parser.getBuilder().getContext();
+/// Parses a type appearing inside another LLVM dialect-compatible type. This
+/// will try to parse any type in full form (including types with the `!llvm`
+/// prefix), and on failure fall back to parsing the short-hand version of the
+/// LLVM dialect types without the `!llvm` prefix.
+static Type dispatchParse(DialectAsmParser &parser) {
+ Type type;
llvm::SMLoc keyLoc = parser.getCurrentLocation();
Location loc = parser.getEncodedSourceLoc(keyLoc);
- OptionalParseResult result = parser.parseOptionalType(maybeIntegerType);
- if (result.hasValue()) {
- if (failed(*result))
+ OptionalParseResult parseResult = parser.parseOptionalType(type);
+ if (parseResult.hasValue()) {
+ if (failed(*parseResult))
return Type();
- if (!maybeIntegerType.isSignlessInteger()) {
- parser.emitError(keyLoc) << "unexpected type, expected i* or keyword";
- return Type();
- }
- return LLVMIntegerType::getChecked(
- loc, maybeIntegerType.getIntOrFloatBitWidth());
+ // Special case for integers (i[1-9][0-9]*) that are literals rather than
+ // keywords for the parser, so they are not caught by the main dispatch
+ // below. Try parsing it a built-in integer type instead.
+ auto intType = type.dyn_cast<IntegerType>();
+ if (!intType || !intType.isSignless())
+ return type;
+
+ return LLVMIntegerType::getChecked(loc, intType.getWidth());
}
// Dispatch to concrete functions.
if (failed(parser.parseKeyword(&key)))
return Type();
+ MLIRContext *ctx = parser.getBuilder().getContext();
return StringSwitch<function_ref<Type()>>(key)
.Case("void", [&] { return LLVMVoidType::get(ctx); })
.Case("half", [&] { return LLVMHalfType::get(ctx); })
.Case("token", [&] { return LLVMTokenType::get(ctx); })
.Case("label", [&] { return LLVMLabelType::get(ctx); })
.Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
- .Case("func", [&] { return parseFunctionType(parser, stack); })
- .Case("ptr", [&] { return parsePointerType(parser, stack); })
- .Case("vec", [&] { return parseVectorType(parser, stack); })
- .Case("array", [&] { return parseArrayType(parser, stack); })
- .Case("struct", [&] { return parseStructType(parser, stack); })
+ .Case("func", [&] { return parseFunctionType(parser); })
+ .Case("ptr", [&] { return parsePointerType(parser); })
+ .Case("vec", [&] { return parseVectorType(parser); })
+ .Case("array", [&] { return parseArrayType(parser); })
+ .Case("struct", [&] { return parseStructType(parser); })
.Default([&] {
parser.emitError(keyLoc) << "unknown LLVM type: " << key;
return Type();
})();
}
+/// Helper to use in parse lists.
+static ParseResult dispatchParse(DialectAsmParser &parser, Type &type) {
+ type = dispatchParse(parser);
+ return success(type != nullptr);
+}
+
+/// Parses one of the LLVM dialect types.
Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
- llvm::SetVector<StringRef> stack;
- return parseTypeImpl(parser, stack);
+ llvm::SMLoc loc = parser.getCurrentLocation();
+ Type type = dispatchParse(parser);
+ if (!type)
+ return type;
+ if (!isCompatibleType(type)) {
+ parser.emitError(loc) << "unexpected type, expected i* or keyword";
+ return nullptr;
+ }
+ return type;
}