static const char *const generatedArgName = "_arg";
+// Helper macro that returns indented os.
+#define OUT(X) os.indent((X))
+
// TODO(jpienaar): The builder body should probably be separate from the header.
// Variation of method in FormatVariadic.h which takes a StringRef as input
os << "> {\npublic:\n";
// Build operation name.
- os << " static StringRef getOperationName() { return \""
- << emitter.op.getOperationName() << "\"; };\n";
+ OUT(2) << "static StringRef getOperationName() { return \""
+ << emitter.op.getOperationName() << "\"; };\n";
emitter.emitNamedOperands();
emitter.emitBuilder();
emitter.emitCanonicalizationPatterns();
emitter.emitConstantFolder();
- os << "private:\n friend class ::mlir::OperationInst;\n";
- os << " explicit " << emitter.op.cppClassName()
+ os << "private:\n friend class ::mlir::OperationInst;\n"
+ << " explicit " << emitter.op.cppClassName()
<< "(const OperationInst* state) : Op(state) {}\n};\n";
emitter.mapOverClassNamespaces(
[&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; });
// Emit the derived attribute body.
if (attr.isDerived) {
- os << " " << def->getValueAsString("returnType").trim() << ' ' << name
- << "() const {" << def->getValueAsString("body") << " }\n";
+ OUT(2) << attr.getReturnType() << ' ' << name << "() const {"
+ << def->getValueAsString("body") << " }\n";
continue;
}
// Emit normal emitter.
- os << " " << def->getValueAsString("returnType").trim() << ' ' << name
- << "() const {\n";
+ OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n";
// Return the queried attribute with the correct return type.
- std::string attrVal =
- formatv("this->getAttrOfType<{0}>(\"{1}\")",
- def->getValueAsString("storageType").trim(), name);
- os << " return "
- << formatv(def->getValueAsString("convertFromStorage"), attrVal)
- << ";\n }\n";
+ std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")",
+ attr.getStorageType(), name);
+ OUT(4) << "return "
+ << formatv(def->getValueAsString("convertFromStorage"), attrVal)
+ << ";\n }\n";
}
}
// 1. Stand-alone parameters
std::vector<Record *> returnTypes = def.getValueAsListOfDefs("returnTypes");
- os << " static void build(Builder* builder, OperationState* result";
+ OUT(2) << "static void build(Builder* builder, OperationState* result";
// Emit parameters for all return types
for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
// Push all result types to the result
if (!returnTypes.empty()) {
- os << " result->addTypes({returnType0";
+ OUT(4) << "result->addTypes({returnType0";
for (unsigned i = 1, e = returnTypes.size(); i != e; ++i)
os << ", returnType" << i;
os << "});\n\n";
// Push all operands to the result
if (op.getNumOperands() > 0) {
- os << " result->addOperands({" << getArgumentName(op, 0);
+ OUT(4) << "result->addOperands({" << getArgumentName(op, 0);
for (int i = 1, e = op.getNumOperands(); i != e; ++i)
os << ", " << getArgumentName(op, i);
os << "});\n";
// Push all attributes to the result
for (const auto &attr : op.getAttributes())
if (!attr.isDerived)
- os.indent(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
- getAttributeName(attr));
- os << " }\n";
+ OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
+ getAttributeName(attr));
+ OUT(2) << "}\n";
// 2. Aggregated parameters
// Signature
- os << " static void build(Builder* builder, OperationState* result, "
- << "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
- "ArrayRef<NamedAttribute> attributes) {\n";
+ OUT(2) << "static void build(Builder* builder, OperationState* result, "
+ << "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
+ "ArrayRef<NamedAttribute> attributes) {\n";
// Result types
- os << " assert(resultTypes.size() == " << returnTypes.size()
- << "u && \"mismatched number of return types\");\n"
- << " result->addTypes(resultTypes);\n";
+ OUT(4) << "assert(resultTypes.size() == " << returnTypes.size()
+ << "u && \"mismatched number of return types\");\n"
+ << " result->addTypes(resultTypes);\n";
// Operands
- os << " assert(args.size() == " << op.getNumOperands()
- << "u && \"mismatched number of parameters\");\n"
- << " result->addOperands(args);\n\n";
+ OUT(4) << "assert(args.size() == " << op.getNumOperands()
+ << "u && \"mismatched number of parameters\");\n"
+ << " result->addOperands(args);\n\n";
// Attributes
if (op.getNumAttributes() > 0) {
- os << " assert(!attributes.size() && \"no attributes expected\");\n"
- << " }\n";
+ OUT(4) << "assert(!attributes.size() && \"no attributes expected\");\n"
+ << " }\n";
} else {
- os << " assert(attributes.size() >= " << op.getNumAttributes()
- << "u && \"not enough attributes\");\n"
- << " for (const auto& pair : attributes)\n"
- << " result->addAttribute(pair.first, pair.second);\n"
- << " }\n";
+ OUT(4) << "assert(attributes.size() >= " << op.getNumAttributes()
+ << "u && \"not enough attributes\");\n"
+ << " for (const auto& pair : attributes)\n"
+ << " result->addAttribute(pair.first, pair.second);\n"
+ << " }\n";
}
}
void OpEmitter::emitCanonicalizationPatterns() {
if (!def.getValueAsBit("hasCanonicalizationPatterns"))
return;
- os << " static void getCanonicalizationPatterns("
- << "OwningRewritePatternList &results, MLIRContext* context);\n";
+ OUT(2) << "static void getCanonicalizationPatterns("
+ << "OwningRewritePatternList &results, MLIRContext* context);\n";
}
void OpEmitter::emitConstantFolder() {
if (!hasCustomVerify && op.getNumArgs() == 0)
return;
- os << " bool verify() const {\n";
+ OUT(2) << "bool verify() const {\n";
// Verify the attributes have the correct type.
for (const auto &attr : op.getAttributes()) {
if (attr.isDerived)
auto name = getAttributeName(attr);
if (!hasStringAttribute(*attr.record, "storageType")) {
- os << " if (!this->getAttr(\"" << name
- << "\")) return emitOpError(\"requires attribute '" << name
- << "'\");\n";
+ OUT(4) << "if (!this->getAttr(\"" << name
+ << "\")) return emitOpError(\"requires attribute '" << name
+ << "'\");\n";
continue;
}
- os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
- << attr.record->getValueAsString("storageType").trim()
- << ">()) return emitOpError(\"requires "
- << attr.record->getValueAsString("returnType").trim() << " attribute '"
- << name << "'\");\n";
+ OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
+ << attr.getStorageType() << ">()) return emitOpError(\"requires "
+ << attr.getReturnType() << " attribute '" << name << "'\");\n";
}
// TODO: Handle variadic.
if (operand.hasMatcher()) {
auto pred =
"if (!(" + operand.createTypeMatcherTemplate() + ")) return false;\n";
- os.indent(4) << formatv(pred, "this->getInstruction()->getOperand(" +
- Twine(opIndex) + ")->getType()");
+ OUT(4) << formatv(pred, "this->getInstruction()->getOperand(" +
+ Twine(opIndex) + ")->getType()");
}
++opIndex;
}
if (hasCustomVerify)
- os << " " << codeInit->getValue() << "\n";
+ OUT(4) << codeInit->getValue() << "\n";
else
- os << " return false;\n";
- os << " }\n";
+ OUT(4) << "return false;\n";
+ OUT(2) << "}\n";
}
void OpEmitter::emitTraits() {
"' in pattern and op's definition");
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto arg = tree->getArg(i);
+ auto opArg = op.getArg(i);
+
if (auto argTree = dyn_cast<DagInit>(arg)) {
os.indent(indent) << "{\n";
os.indent(indent + 2) << formatv(
// Verify arguments.
if (auto defInit = dyn_cast<DefInit>(arg)) {
- auto opArg = op.getArg(i);
// Verify operands.
if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
// Skip verification where not needed due to definition of op.
if (operand->defInit == defInit)
- goto SkipOperandVerification;
+ goto StateCapture;
if (!defInit->getDef()->isSubClassOf("Type"))
PrintFatalError(pattern->getLoc(),
formatv("op{0}->getOperand({1})->getType()", depth, i))
<< ")) return matchFailure();\n";
}
+
+ // TODO(jpienaar): Verify attributes.
+ if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) {
+ }
}
- SkipOperandVerification:
- // TODO(jpienaar): Verify attributes.
+ StateCapture:
auto name = tree->getArgNameStr(i);
if (name.empty())
continue;
- os.indent(indent) << "state->" << name << " = op" << depth
- << "->getOperand(" << i << ");\n";
+ if (opArg.is<Operator::Operand *>())
+ os.indent(indent) << "state->" << name << " = op" << depth
+ << "->getOperand(" << i << ");\n";
+ if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) {
+ os.indent(indent) << "state->" << name << " = op" << depth
+ << "->getAttrOfType<" << attr->getStorageType()
+ << ">(\"" << attr->getName() << "\");\n";
+ }
}
}
(os << ",\n").indent(6);
// The argument in the result DAG pattern.
- auto name = resultOp.getArgName(i);
+ auto name = resultTree->getArgNameStr(i);
+ auto opName = resultOp.getArgName(i);
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
- if (!value)
- PrintFatalError(pattern->getLoc(),
- Twine("attribute '") + name +
- "' needs to be constant initialized");
+ if (!value) {
+ if (boundArguments.find(name) == boundArguments.end())
+ PrintFatalError(pattern->getLoc(),
+ Twine("referencing unbound variable '") + name + "'");
+ os << "/*" << opName << "=*/"
+ << "s." << name;
+ continue;
+ }
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);