enum class Kind {
/// This element is a directive.
AttrDictDirective,
+ CustomDirective,
FunctionalTypeDirective,
OperandsDirective,
ResultsDirective,
namespace {
/// This class implements single kind directives.
-template <Element::Kind type>
-class DirectiveElement : public Element {
+template <Element::Kind type> class DirectiveElement : public Element {
public:
DirectiveElement() : Element(type){};
static bool classof(const Element *ele) { return ele->getKind() == type; }
bool withKeyword;
};
+/// This class represents a custom format directive that is implemented by the
+/// user in C++.
+class CustomDirective : public Element {
+public:
+ CustomDirective(StringRef name,
+ std::vector<std::unique_ptr<Element>> &&arguments)
+ : Element{Kind::CustomDirective}, name(name),
+ arguments(std::move(arguments)) {}
+
+ static bool classof(const Element *element) {
+ return element->getKind() == Kind::CustomDirective;
+ }
+
+ /// Return the name of this optional element.
+ StringRef getName() const { return name; }
+
+ /// Return the arguments to the custom directive.
+ auto getArguments() const { return llvm::make_pointee_range(arguments); }
+
+private:
+ /// The user provided name of the directive.
+ StringRef name;
+
+ /// The arguments to the custom directive.
+ std::vector<std::unique_ptr<Element>> arguments;
+};
+
/// This class represents the `functional-type` directive. This directive takes
/// two arguments and formats them, respectively, as the inputs and results of a
/// FunctionType.
/// The code snippet used to generate a parser call for an attribute.
///
-/// {0}: The storage type of the attribute.
-/// {1}: The name of the attribute.
-/// {2}: The type for the attribute.
+/// {0}: The name of the attribute.
+/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
- {0} {1}Attr;
- if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes))
+ if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
return failure();
)";
const char *const optionalAttrParserCode = R"(
- {0} {1}Attr;
{
::mlir::OptionalParseResult parseResult =
- parser.parseOptionalAttribute({1}Attr{2}, "{1}", result.attributes);
+ parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
if (parseResult.hasValue() && failed(*parseResult))
return failure();
}
return parser.emitError(loc, "invalid ")
<< "{0} attribute specification: " << attrVal;
- result.addAttribute("{0}", {3});
+ {0}Attr = {3};
+ result.addAttribute("{0}", {0}Attr);
}
)";
const char *const optionalEnumAttrParserCode = R"(
- Attribute {0}Attr;
{
::mlir::StringAttr attrVal;
::mlir::NamedAttrList attrStorage;
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
+ {0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperandList({0}Operands))
return failure();
)";
const char *const optionalOperandParserCode = R"(
{
+ {0}OperandsLoc = parser.getCurrentLocation();
::mlir::OpAsmParser::OperandType operand;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalOperand(operand);
}
)";
const char *const operandParserCode = R"(
+ {0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperand({0}RawOperands[0]))
return failure();
)";
///
/// {0}: The name for the successor list.
const char *successorListParserCode = R"(
- ::llvm::SmallVector<::mlir::Block *, 2> {0}Successors;
{
::mlir::Block *succ;
auto firstSucc = parser.parseOptionalSuccessor(succ);
///
/// {0}: The name of the successor.
const char *successorParserCode = R"(
- ::mlir::Block *{0}Successor = nullptr;
if (parser.parseSuccessor({0}Successor))
return failure();
)";
/// Generate the storage code required for parsing the given element.
static void genElementParserStorage(Element *element, OpMethodBody &body) {
if (auto *optional = dyn_cast<OptionalElement>(element)) {
- for (auto &childElement : optional->getElements())
- genElementParserStorage(&childElement, body);
+ auto elements = optional->getElements();
+
+ // If the anchor is a unit attribute, it won't be parsed directly so elide
+ // it.
+ auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
+ Element *elidedAnchorElement = nullptr;
+ if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
+ elidedAnchorElement = anchor;
+ for (auto &childElement : elements)
+ if (&childElement != elidedAnchorElement)
+ genElementParserStorage(&childElement, body);
+
+ } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
+ for (auto ¶mElement : custom->getArguments())
+ genElementParserStorage(¶mElement, body);
+
+ } else if (isa<OperandsDirective>(element)) {
+ body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
+ "allOperands;\n";
+
+ } else if (isa<SuccessorsDirective>(element)) {
+ body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
+
+ } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
+ const NamedAttribute *var = attr->getVar();
+ body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
+ var->name);
+
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
StringRef name = operand->getVar()->name;
if (operand->getVar()->isVariableLength()) {
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
<< "Operands(" << name << "RawOperands);";
}
- body << llvm::formatv(
- " ::llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation();\n"
- " (void){0}OperandsLoc;\n",
- name);
+ body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
+ " (void){0}OperandsLoc;\n",
+ name);
+ } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
+ StringRef name = successor->getVar()->name;
+ if (successor->getVar()->isVariadic()) {
+ body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
+ "{0}Successors;\n",
+ name);
+ } else {
+ body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
+ }
+
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef name = getTypeListName(dir->getOperand(), lengthKind);
}
}
+/// Generate the parser for a parameter to a custom directive.
+static void genCustomParameterParser(Element ¶m, OpMethodBody &body) {
+ body << ", ";
+ if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
+ body << attr->getVar()->name << "Attr";
+
+ } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
+ StringRef name = operand->getVar()->name;
+ ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
+ if (lengthKind == ArgumentLengthKind::Variadic)
+ body << llvm::formatv("{0}Operands", name);
+ else if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv("{0}Operand", name);
+ else
+ body << formatv("{0}RawOperands[0]", name);
+
+ } else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
+ StringRef name = successor->getVar()->name;
+ if (successor->getVar()->isVariadic())
+ body << llvm::formatv("{0}Successors", name);
+ else
+ body << llvm::formatv("{0}Successor", name);
+
+ } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
+ ArgumentLengthKind lengthKind;
+ StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind == ArgumentLengthKind::Variadic)
+ body << llvm::formatv("{0}Types", listName);
+ else if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv("{0}Type", listName);
+ else
+ body << formatv("{0}RawTypes[0]", listName);
+ } else {
+ llvm_unreachable("unknown custom directive parameter");
+ }
+}
+
+/// Generate the parser for a custom directive.
+static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
+ body << " {\n";
+
+ // Preprocess the directive variables.
+ // * Add a local variable for optional operands and types. This provides a
+ // better API to the user defined parser methods.
+ // * Set the location of operand variables.
+ for (Element ¶m : dir->getArguments()) {
+ if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
+ body << " " << operand->getVar()->name
+ << "OperandsLoc = parser.getCurrentLocation();\n";
+ if (operand->getVar()->isOptional()) {
+ body << llvm::formatv(
+ " llvm::Optional<::mlir::OpAsmParser::OperandType> "
+ "{0}Operand;\n",
+ operand->getVar()->name);
+ }
+ } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
+ ArgumentLengthKind lengthKind;
+ StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
+ }
+ }
+
+ body << " if (parse" << dir->getName() << "(parser";
+ for (Element ¶m : dir->getArguments())
+ genCustomParameterParser(param, body);
+
+ body << "))\n"
+ << " return failure();\n";
+
+ // After parsing, add handling for any of the optional constructs.
+ for (Element ¶m : dir->getArguments()) {
+ if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
+ const NamedAttribute *var = attr->getVar();
+ if (var->attr.isOptional())
+ body << llvm::formatv(" if ({0}Attr)\n ", var->name);
+
+ body << llvm::formatv(
+ " result.attributes.addAttribute(\"{0}\", {0}Attr);", var->name);
+ } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
+ const NamedTypeConstraint *var = operand->getVar();
+ if (!var->isOptional())
+ continue;
+ body << llvm::formatv(" if ({0}Operand.hasValue())\n"
+ " {0}Operands.push_back(*{0}Operand);\n",
+ var->name);
+ } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
+ ArgumentLengthKind lengthKind;
+ StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind == ArgumentLengthKind::Optional) {
+ body << llvm::formatv(" if ({0}Type)\n"
+ " {0}Types.push_back({0}Type);\n",
+ listName);
+ }
+ }
+ }
+
+ body << " }\n";
+}
+
/// Generate the parser for a single format element.
static void genElementParser(Element *element, OpMethodBody &body,
FmtContext &attrTypeCtx) {
body << formatv(var->attr.isOptional() ? optionalAttrParserCode
: attrParserCode,
- var->attr.getStorageType(), var->name, attrTypeStr);
+ var->name, attrTypeStr);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
<< (attrDict->isWithKeyword() ? "WithKeyword" : "")
<< "(result.attributes))\n"
<< " return failure();\n";
+ } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
+ genCustomDirectiveParser(customDir, body);
+
} else if (isa<OperandsDirective>(element)) {
body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
- << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
- "allOperands;\n"
<< " if (parser.parseOperandList(allOperands))\n"
<< " return failure();\n";
} else if (isa<SuccessorsDirective>(element)) {
llvm::interleaveComma(op.getOperands(), body, interleaveFn);
body << "}));\n";
}
+
+ if (!allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments")) {
+ body << " result.addAttribute(\"result_segment_sizes\", "
+ << "parser.getBuilder().getI32VectorAttr({";
+ auto interleaveFn = [&](const NamedTypeConstraint &result) {
+ // If the result is variadic emit the parsed size.
+ if (result.isVariableLength())
+ body << "static_cast<int32_t>(" << result.name << "Types.size())";
+ else
+ body << "1";
+ };
+ llvm::interleaveComma(op.getResults(), body, interleaveFn);
+ body << "}));\n";
+ }
}
//===----------------------------------------------------------------------===//
// Elide the variadic segment size attributes if necessary.
if (!fmt.allOperands && op.getTrait("OpTrait::AttrSizedOperandSegments"))
body << "\"operand_segment_sizes\", ";
+ if (!fmt.allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments"))
+ body << "\"result_segment_sizes\", ";
llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
body << "\"" << attr->name << "\"";
});
lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
}
+/// Generate the printer for a literal value. `shouldEmitSpace` is true if a
+/// space should be emitted before this element. `lastWasPunctuation` is true if
+/// the previous element was a punctuation literal.
+static void genCustomDirectivePrinter(CustomDirective *customDir,
+ OpMethodBody &body) {
+ body << " print" << customDir->getName() << "(p";
+ for (Element ¶m : customDir->getArguments()) {
+ body << ", ";
+ if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
+ body << attr->getVar()->name << "Attr()";
+
+ } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
+ body << operand->getVar()->name << "()";
+
+ } else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
+ body << successor->getVar()->name << "()";
+
+ } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
+ auto *typeOperand = dir->getOperand();
+ auto *operand = dyn_cast<OperandVariable>(typeOperand);
+ auto *var = operand ? operand->getVar()
+ : cast<ResultVariable>(typeOperand)->getVar();
+ if (var->isVariadic())
+ body << var->name << "().getTypes()";
+ else if (var->isOptional())
+ body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
+ else
+ body << var->name << "().getType()";
+ } else {
+ llvm_unreachable("unknown custom directive parameter");
+ }
+ }
+
+ body << ");\n";
+}
+
/// Generate the C++ for an operand to a (*-)type directive.
static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
if (isa<OperandsDirective>(arg))
body << " ::llvm::interleaveComma(" << var->name << "(), p);\n";
else
body << " p << " << var->name << "();\n";
+ } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
+ genCustomDirectivePrinter(dir, body);
} else if (isa<OperandsDirective>(element)) {
body << " p << getOperation()->getOperands();\n";
} else if (isa<SuccessorsDirective>(element)) {
caret,
comma,
equal,
+ less,
+ greater,
question,
// Keywords.
keyword_start,
kw_attr_dict,
kw_attr_dict_w_keyword,
+ kw_custom,
kw_functional_type,
kw_operands,
kw_results,
return formToken(Token::comma, tokStart);
case '=':
return formToken(Token::equal, tokStart);
+ case '<':
+ return formToken(Token::less, tokStart);
+ case '>':
+ return formToken(Token::greater, tokStart);
case '?':
return formToken(Token::question, tokStart);
case '(':
llvm::StringSwitch<Token::Kind>(str)
.Case("attr-dict", Token::kw_attr_dict)
.Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
+ .Case("custom", Token::kw_custom)
.Case("functional-type", Token::kw_functional_type)
.Case("operands", Token::kw_operands)
.Case("results", Token::kw_results)
/// Function to find an element within the given range that has the same name as
/// 'name'.
-template <typename RangeT>
-static auto findArg(RangeT &&range, StringRef name) {
+template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
return it != range.end() ? &*it : nullptr;
}
LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel,
bool withKeyword);
+ LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
+ llvm::SMLoc loc, bool isTopLevel);
+ LogicalResult parseCustomDirectiveParameter(
+ std::vector<std::unique_ptr<Element>> ¶meters);
LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
Token tok, bool isTopLevel);
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
case Token::kw_attr_dict_w_keyword:
return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
/*withKeyword=*/true);
+ case Token::kw_custom:
+ return parseCustomDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_functional_type:
return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
case Token::kw_operands:
seenVariables.insert(ele->getVar());
return success();
})
- // Literals and type directives may be used, but they can't anchor the
- // group.
- .Case<LiteralElement, TypeDirective, FunctionalTypeDirective>(
- [&](Element *) {
- if (isAnchor)
- return emitError(childLoc, "only variables can be used to anchor "
- "an optional group");
- return success();
- })
+ // Literals, custom directives, and type directives may be used,
+ // but they can't anchor the group.
+ .Case<LiteralElement, CustomDirective, TypeDirective,
+ FunctionalTypeDirective>([&](Element *) {
+ if (isAnchor)
+ return emitError(childLoc, "only variables can be used to anchor "
+ "an optional group");
+ return success();
+ })
.Default([&](Element *) {
return emitError(childLoc, "only literals, types, and variables can be "
"used within an optional group");
}
LogicalResult
+FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
+ llvm::SMLoc loc, bool isTopLevel) {
+ llvm::SMLoc curLoc = curToken.getLoc();
+
+ // Parse the custom directive name.
+ if (failed(
+ parseToken(Token::less, "expected '<' before custom directive name")))
+ return failure();
+
+ Token nameTok = curToken;
+ if (failed(parseToken(Token::identifier,
+ "expected custom directive name identifier")) ||
+ failed(parseToken(Token::greater,
+ "expected '>' after custom directive name")) ||
+ failed(parseToken(Token::l_paren,
+ "expected '(' before custom directive parameters")))
+ return failure();
+
+ // Parse the child elements for this optional group.=
+ std::vector<std::unique_ptr<Element>> elements;
+ do {
+ if (failed(parseCustomDirectiveParameter(elements)))
+ return failure();
+ if (curToken.getKind() != Token::comma)
+ break;
+ consumeToken();
+ } while (true);
+
+ if (failed(parseToken(Token::r_paren,
+ "expected ')' after custom directive parameters")))
+ return failure();
+
+ // After parsing all of the elements, ensure that all type directives refer
+ // only to variables.
+ for (auto &ele : elements) {
+ if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
+ if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
+ return emitError(curLoc, "type directives within a custom directive "
+ "may only refer to variables");
+ }
+ }
+ }
+
+ element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
+ std::move(elements));
+ return success();
+}
+
+LogicalResult FormatParser::parseCustomDirectiveParameter(
+ std::vector<std::unique_ptr<Element>> ¶meters) {
+ llvm::SMLoc childLoc = curToken.getLoc();
+ parameters.push_back({});
+ if (failed(parseElement(parameters.back(), /*isTopLevel=*/true)))
+ return failure();
+
+ // Verify that the element can be placed within a custom directive.
+ if (!isa<TypeDirective, AttributeVariable, OperandVariable,
+ SuccessorVariable>(parameters.back().get())) {
+ return emitError(childLoc, "only variables and types may be used as "
+ "parameters to a custom directive");
+ }
+ return success();
+}
+
+LogicalResult
FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
Token tok, bool isTopLevel) {
llvm::SMLoc loc = tok.getLoc();