//
//===----------------------------------------------------------------------===//
+#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
static const char *const generatedArgName = "_arg";
-// Helper macro that returns indented os.
-#define OUT(X) os.indent((X))
+static const char *const opCommentHeader = R"(
+//===----------------------------------------------------------------------===//
+// {0} {1}
+//===----------------------------------------------------------------------===//
-// TODO(jpienaar): The builder body should probably be separate from the header.
+)";
+
+//===----------------------------------------------------------------------===//
+// Utility structs and functions
+//===----------------------------------------------------------------------===//
// Variation of method in FormatVariadic.h which takes a StringRef as input
// instead.
return isa<CodeInit>(valueInit) || isa<StringInit>(valueInit);
}
+// Returns the given `op`'s qualified C++ class name.
+static std::string getOpQualClassName(const Record &op) {
+ SmallVector<StringRef, 2> splittedName;
+ llvm::SplitString(op.getName(), splittedName, "_");
+ return llvm::join(splittedName, "::");
+}
+
static std::string getArgumentName(const Operator &op, int index) {
const auto &operand = op.getOperand(index);
if (!operand.name.empty())
public:
IfDefScope(StringRef name, raw_ostream &os) : name(name), os(os) {
os << "#ifdef " << name << "\n"
- << "#undef " << name << "\n";
+ << "#undef " << name << "\n\n";
}
~IfDefScope() { os << "\n#endif // " << name << "\n\n"; }
};
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// Classes for C++ code emission
+//===----------------------------------------------------------------------===//
+
+// We emit the op declaration and definition into separate files: *Ops.h.inc
+// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
+// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
+//
+// In order to do this split, we need to track method signature and
+// implementation logic separately. Signature information is used for both
+// declaration and definition, while implementation logic is only for
+// definition. So we have the following classes for C++ code emission.
+
namespace {
-// Helper class to emit a record into the given output stream.
-class OpEmitter {
+// Class for holding the signature of an op's method for C++ code emission
+class OpMethodSignature {
public:
- static void emit(const Record &def, raw_ostream &os);
+ OpMethodSignature(StringRef retType, StringRef name, StringRef params);
- // Emit getters for the attributes of the operation.
- void emitAttrGetters();
+ // Writes the signature as a method declaration to the given `os`.
+ void writeDeclTo(raw_ostream &os) const;
+ // Writes the signature as the start of a method definition to the given `os`.
+ // `namePrefix` is the prefix to be prepended to the method name (typically
+ // namespaces for qualifying the method definition).
+ void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
- // Emit query methods for the named operands.
- void emitNamedOperands();
+private:
+ // Returns true if the given C++ `type` ends with '&' or '*'.
+ static bool endsWithRefOrPtr(StringRef type);
- // Emit query methods for the named results.
- void emitNamedResults();
+ std::string returnType;
+ std::string methodName;
+ std::string parameters;
+};
- // Emit builder method for the operation.
- void emitBuilder();
+// Class for holding the body of an op's method for C++ code emission
+class OpMethodBody {
+public:
+ explicit OpMethodBody(bool declOnly);
- // Emit method declaration for the getCanonicalizationPatterns() interface.
- void emitCanonicalizationPatterns();
+ OpMethodBody &operator<<(Twine content);
+ OpMethodBody &operator<<(int content);
- // Emit the folder methods for the operation.
- void emitFolders();
+ void writeTo(raw_ostream &os) const;
- // Emit the parser for the operation.
- void emitParser();
+private:
+ // Whether this class should record method body.
+ bool isEffective;
+ std::string body;
+};
+
+// Class for holding an op's method for C++ code emission
+class OpMethod {
+public:
+ // Properties (qualifiers) of class methods. Bitfield is used here to help
+ // querying properties.
+ enum Property {
+ MP_None = 0x0,
+ MP_Static = 0x1, // Static method
+ MP_Const = 0x2, // Const method
+ };
- // Emit the printer for the operation.
- void emitPrinter();
+ OpMethod(StringRef retType, StringRef name, StringRef params,
+ Property property, bool declOnly);
- // Emit verify method for the operation.
- void emitVerifier();
+ OpMethodSignature &signature();
+ OpMethodBody &body();
- // Emit the traits used by the object.
- void emitTraits();
+ // Returns true if this is a static method.
+ bool isStatic() const;
+ // Returns true if this is a const method.
+ bool isConst() const;
+
+ // Writes the method as a declaration to the given `os`.
+ void writeDeclTo(raw_ostream &os) const;
+ // Writes the method as a definition to the given `os`. `namePrefix` is the
+ // prefix to be prepended to the method name (typically namespaces for
+ // qualifying the method definition).
+ void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
private:
- OpEmitter(const Record &def, raw_ostream &os);
+ Property properties;
+ // Whether this method only contains a declaration.
+ bool isDeclOnly;
+ OpMethodSignature methodSignature;
+ OpMethodBody methodBody;
+};
- // Invokes the given function over all the namespaces of the class.
- void mapOverClassNamespaces(function_ref<void(StringRef)> fn);
+// Class for holding an op for C++ code emission
+class OpClass {
+public:
+ explicit OpClass(StringRef name);
- // Emits the build() method that takes each result-type/operand/attribute as
- // a stand-alone parameter. Using the first operand's type as all result
- // types if `isAllSameType` is true.
- void emitStandaloneParamBuilder(bool isAllSameType);
+ // Adds an op trait.
+ void addTrait(Twine trait);
- // The record corresponding to the op.
- const Record &def;
+ // Creates a new method in this op's class.
+ OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "",
+ OpMethod::Property = OpMethod::MP_None,
+ bool declOnly = false);
- // The operator being emitted.
- Operator op;
+ // Writes this op's class as a declaration to the given `os`.
+ void writeDeclTo(raw_ostream &os) const;
+ // Writes the method definitions in this op's class to the given `os`.
+ void writeDefTo(raw_ostream &os) const;
- raw_ostream &os;
+private:
+ std::string className;
+ SmallVector<std::string, 4> traits;
+ SmallVector<OpMethod, 8> methods;
};
} // end anonymous namespace
-OpEmitter::OpEmitter(const Record &def, raw_ostream &os)
- : def(def), op(def), os(os) {}
+OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name,
+ StringRef params)
+ : returnType(retType), methodName(name), parameters(params) {}
+
+void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
+ os << returnType << (endsWithRefOrPtr(returnType) ? "" : " ") << methodName
+ << "(" << parameters << ")";
+}
+
+void OpMethodSignature::writeDefTo(raw_ostream &os,
+ StringRef namePrefix) const {
+ // We need to remove the default values for parameters in method definition.
+ // TODO(antiagainst): We are using '=' and ',' as delimiters for parameter
+ // initializers. This is incorrect for initializer list with more than one
+ // element. Change to a more robust approach.
+ auto removeParamDefaultValue = [](StringRef params) {
+ string result;
+ std::pair<StringRef, StringRef> parts;
+ while (!params.empty()) {
+ parts = params.split("=");
+ result.append(result.empty() ? "" : ", ");
+ result.append(parts.first);
+ params = parts.second.split(",").second;
+ }
+ return result;
+ };
+
+ os << returnType << (endsWithRefOrPtr(returnType) ? "" : " ") << namePrefix
+ << (namePrefix.empty() ? "" : "::") << methodName << "("
+ << removeParamDefaultValue(parameters) << ")";
+}
+
+bool OpMethodSignature::endsWithRefOrPtr(StringRef type) {
+ return type.endswith("&") || type.endswith("*");
+};
+
+OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
+
+OpMethodBody &OpMethodBody::operator<<(Twine content) {
+ if (isEffective)
+ body.append(content.str());
+ return *this;
+}
+
+OpMethodBody &OpMethodBody::operator<<(int content) {
+ if (isEffective)
+ body.append(std::to_string(content));
+ return *this;
+}
-void OpEmitter::mapOverClassNamespaces(function_ref<void(StringRef)> fn) {
- // We only care about namespaces, so drop the class name here
- auto splittedDefName = op.getSplitDefName().drop_back();
- for (auto ns : splittedDefName)
- fn(ns);
+void OpMethodBody::writeTo(raw_ostream &os) const {
+ os << body;
+ if (body.back() != '\n')
+ os << "\n";
}
-void OpEmitter::emit(const Record &def, raw_ostream &os) {
- OpEmitter emitter(def, os);
+OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params,
+ OpMethod::Property property, bool declOnly)
+ : properties(property), isDeclOnly(declOnly),
+ methodSignature(retType, name, params), methodBody(declOnly) {}
+
+OpMethodSignature &OpMethod::signature() { return methodSignature; }
+
+OpMethodBody &OpMethod::body() { return methodBody; }
- emitter.mapOverClassNamespaces(
- [&os](StringRef ns) { os << "\nnamespace " << ns << "{\n"; });
- os << formatv("class {0} : public Op<{0}", emitter.op.getCppClassName());
- emitter.emitTraits();
+bool OpMethod::isStatic() const { return properties & MP_Static; }
+bool OpMethod::isConst() const { return properties & MP_Const; }
+
+void OpMethod::writeDeclTo(raw_ostream &os) const {
+ os.indent(2);
+ if (isStatic())
+ os << "static ";
+ methodSignature.writeDeclTo(os);
+ if (isConst())
+ os << " const";
+ os << ";";
+}
+
+void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
+ if (isDeclOnly)
+ return;
+
+ methodSignature.writeDefTo(os, namePrefix);
+ if (isConst())
+ os << " const";
+ os << " {\n";
+ methodBody.writeTo(os);
+ os << "}";
+}
+
+OpClass::OpClass(StringRef name) : className(name) {}
+
+// Adds the given trait to this op. Prefixes "OpTrait::" to `trait` implicitly.
+void OpClass::addTrait(Twine trait) {
+ traits.push_back(("OpTrait::" + trait).str());
+}
+
+OpMethod &OpClass::newMethod(StringRef retType, StringRef name,
+ StringRef params, OpMethod::Property property,
+ bool declOnly) {
+ methods.emplace_back(retType, name, params, property, declOnly);
+ return methods.back();
+}
+
+void OpClass::writeDeclTo(raw_ostream &os) const {
+ os << "class " << className << " : public Op<" << className;
+ for (const auto &trait : traits)
+ os << ", " << trait;
os << "> {\npublic:\n";
+ for (const auto &method : methods) {
+ method.writeDeclTo(os);
+ os << "\n";
+ }
+ os << "\nprivate:\n"
+ << " friend class ::mlir::Instruction;\n";
+ os << " explicit " << className
+ << "(const Instruction *state) : Op(state) {}\n"
+ << "};";
+}
- // Build operation name.
- OUT(2) << "static StringRef getOperationName() { return \""
- << emitter.op.getOperationName() << "\"; };\n";
+void OpClass::writeDefTo(raw_ostream &os) const {
+ for (const auto &method : methods) {
+ method.writeDefTo(os, className);
+ os << "\n\n";
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Op emitter
+//===----------------------------------------------------------------------===//
- emitter.emitNamedOperands();
- emitter.emitNamedResults();
- emitter.emitBuilder();
- emitter.emitParser();
- emitter.emitPrinter();
- emitter.emitVerifier();
- emitter.emitAttrGetters();
- emitter.emitCanonicalizationPatterns();
- emitter.emitFolders();
+namespace {
+// Helper class to emit a record into the given output stream.
+class OpEmitter {
+public:
+ static void emitDecl(const Record &def, raw_ostream &os);
+ static void emitDef(const Record &def, raw_ostream &os);
- os << "private:\n friend class ::mlir::Instruction;\n"
- << " explicit " << emitter.op.getCppClassName()
- << "(const Instruction* state) : Op(state) {}\n};\n";
- emitter.mapOverClassNamespaces(
- [&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; });
+private:
+ OpEmitter(const Record &def);
+
+ void emitDecl(raw_ostream &os);
+ void emitDef(raw_ostream &os);
+
+ // Generates getters for the attributes.
+ void genAttrGetters();
+
+ // Generates getters for named operands.
+ void genNamedOperandGetters();
+
+ // Generates getters for named results.
+ void genNamedResultGetters();
+
+ // Generates builder method for the operation.
+ void genBuilder();
+
+ // Generates canonicalizer declaration for the operation.
+ void genCanonicalizerDecls();
+
+ // Generates the folder declaration for the operation.
+ void genFolderDecls();
+
+ // Generates the parser for the operation.
+ void genParser();
+
+ // Generates the printer for the operation.
+ void genPrinter();
+
+ // Generates verify method for the operation.
+ void genVerifier();
+
+ // Generates the traits used by the object.
+ void genTraits();
+
+ // Generates the build() method that takes each result-type/operand/attribute
+ // as a stand-alone parameter. Using the first operand's type as all result
+ // types if `isAllSameType` is true.
+ void genStandaloneParamBuilder(bool isAllSameType);
+
+ void genOpNameGetter();
+
+ // The TableGen record for this op.
+ const Record &def;
+
+ // The wrapper operator class for querying information from this op.
+ Operator op;
+
+ // The C++ code builder for this op
+ OpClass opClass;
+};
+} // end anonymous namespace
+
+OpEmitter::OpEmitter(const Record &def)
+ : def(def), op(def), opClass(op.getCppClassName()) {
+ genTraits();
+ // Generate C++ code for various op methods. The order here determines the
+ // methods in the generated file.
+ genOpNameGetter();
+ genNamedOperandGetters();
+ genNamedResultGetters();
+ genAttrGetters();
+ genBuilder();
+ genParser();
+ genPrinter();
+ genVerifier();
+ genCanonicalizerDecls();
+ genFolderDecls();
+}
+
+void OpEmitter::emitDecl(const Record &def, raw_ostream &os) {
+ OpEmitter(def).emitDecl(os);
+}
+
+void OpEmitter::emitDef(const Record &def, raw_ostream &os) {
+ OpEmitter(def).emitDef(os);
}
-void OpEmitter::emitAttrGetters() {
+void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
+
+void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
+
+void OpEmitter::genAttrGetters() {
for (auto &namedAttr : op.getAttributes()) {
auto name = namedAttr.getName();
const auto &attr = namedAttr.attr;
if (!it.second.empty())
getter = it.second;
+ auto &method = opClass.newMethod(attr.getReturnType(), getter,
+ /*params=*/"", OpMethod::MP_Const);
+
// Emit the derived attribute body.
if (attr.isDerivedAttr()) {
- OUT(2) << attr.getReturnType() << ' ' << getter << "() const {"
- << attr.getDerivedCodeBody() << " }\n";
+ method.body() << " " << attr.getDerivedCodeBody() << "\n";
continue;
}
// Emit normal emitter.
- OUT(2) << attr.getReturnType() << ' ' << getter << "() const {\n";
// Return the queried attribute with the correct return type.
std::string attrVal =
formatv("this->getAttr(\"{1}\").dyn_cast_or_null<{0}>()",
attr.getStorageType(), name);
- OUT(4) << "auto attr = " << attrVal << ";\n";
+ method.body() << " auto attr = " << attrVal << ";\n";
if (attr.hasDefaultValue()) {
// Returns the default value if not set.
// TODO: this is inefficient, we are recreating the attribute for every
// call. This should be set instead.
- OUT(4) << "if (!attr)\n";
- OUT(6) << "return "
- << formatv(
- attr.getConvertFromStorageCall(),
- formatv(
- attr.getDefaultValueTemplate(),
- "mlir::Builder(this->getInstruction()->getContext())"))
- << ";\n";
+ method.body()
+ << " if (!attr)\n"
+ " return "
+ << formatv(
+ attr.getConvertFromStorageCall(),
+ formatv(attr.getDefaultValueTemplate(),
+ "mlir::Builder(this->getInstruction()->getContext())"))
+ << ";\n";
}
- OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), "attr")
- << ";\n }\n";
+ method.body() << " return "
+ << formatv(attr.getConvertFromStorageCall(), "attr") << ";\n";
}
}
-void OpEmitter::emitNamedOperands() {
- const auto operandMethods = R"( Value *{0}() {
- return this->getInstruction()->getOperand({1});
- }
- const Value *{0}() const {
- return this->getInstruction()->getOperand({1});
- }
-)";
-
- const auto variadicOperandMethods = R"( SmallVector<Value *, 4> {0}() {
- assert(getInstruction()->getNumOperands() >= {1});
- SmallVector<Value *, 4> operands(
- std::next(getInstruction()->operand_begin(), {1}),
- getInstruction()->operand_end());
- return operands;
- }
- SmallVector<const Value *, 4> {0}() const {
- assert(getInstruction()->getNumOperands() >= {1});
- SmallVector<const Value *, 4> operands(
- std::next(getInstruction()->operand_begin(), {1}),
- getInstruction()->operand_end());
- return operands;
- }
-)";
-
+void OpEmitter::genNamedOperandGetters() {
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
const auto &operand = op.getOperand(i);
- if (!operand.name.empty()) {
- if (operand.constraint.isVariadic()) {
- assert(i == e - 1 && "only the last operand can be variadic");
- os << formatv(variadicOperandMethods, operand.name, i);
- } else {
- os << formatv(operandMethods, operand.name, i);
- }
+ if (operand.name.empty())
+ continue;
+
+ if (!operand.constraint.isVariadic()) {
+ auto &m1 = opClass.newMethod("Value *", operand.name);
+ m1.body() << " return this->getInstruction()->getOperand(" << i
+ << ");\n";
+ auto &m2 = opClass.newMethod("const Value *", operand.name, /*params=*/"",
+ OpMethod::MP_Const);
+ m2.body() << " return this->getInstruction()->getOperand(" << i
+ << ");\n";
+ } else {
+ assert(i + 1 == e && "only the last operand can be variadic");
+
+ const char *const code1 =
+ R"( assert(getInstruction()->getNumOperands() >= {0});
+ SmallVector<Value *, 4> operands(
+ std::next(getInstruction()->operand_begin(), {0}),
+ getInstruction()->operand_end());
+ return operands;)";
+ auto &m1 = opClass.newMethod("SmallVector<Value *, 4>", operand.name);
+ m1.body() << formatv(code1, i);
+
+ const char *const code2 =
+ R"( assert(getInstruction()->getNumOperands() >= {0});
+ SmallVector<const Value *, 4> operands(
+ std::next(getInstruction()->operand_begin(), {0}),
+ getInstruction()->operand_end());
+ return operands;)";
+ auto &m2 =
+ opClass.newMethod("const SmallVector<const Value *, 4>", operand.name,
+ /*params=*/"", OpMethod::MP_Const);
+ m2.body() << formatv(code2, i);
}
}
}
-void OpEmitter::emitNamedResults() {
- const auto resultMethods = R"( Value *{0}() {
- return this->getInstruction()->getResult({1});
- }
- const Value *{0}() const {
- return this->getInstruction()->getResult({1});
- }
-)";
+void OpEmitter::genNamedResultGetters() {
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
const auto &result = op.getResult(i);
- if (!result.constraint.isVariadic() && !result.name.empty())
- os << formatv(resultMethods, result.name, i);
+ if (result.constraint.isVariadic() || result.name.empty())
+ continue;
+
+ auto &m1 = opClass.newMethod("Value *", result.name);
+ m1.body() << " return this->getInstruction()->getResult(" << i << ");\n";
+ auto &m2 = opClass.newMethod("const Value *", result.name, /*params=*/"",
+ OpMethod::MP_Const);
+ m2.body() << " return this->getInstruction()->getResult(" << i << ");\n";
}
}
-void OpEmitter::emitStandaloneParamBuilder(bool isAllSameType) {
- OUT(2) << "static void build(Builder *builder, OperationState *result";
-
+void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) {
auto numResults = op.getNumResults();
-
llvm::SmallVector<std::string, 4> resultNames;
resultNames.reserve(numResults);
+ std::string paramList = "Builder *builder, OperationState *result";
+
// Emit parameters for all return types
if (!isAllSameType) {
for (unsigned i = 0; i != numResults; ++i) {
if (resultName.empty())
resultName = formatv("resultType{0}", i);
- os << (op.getResultTypeConstraint(i).isVariadic() ? ", ArrayRef<Type> "
- : ", Type ")
- << resultName;
+ bool isVariadic = op.getResultTypeConstraint(i).isVariadic();
+ paramList.append(isVariadic ? ", ArrayRef<Type> " : ", Type ");
+ paramList.append(resultName);
resultNames.emplace_back(std::move(resultName));
}
auto argument = op.getArg(i);
if (argument.is<tblgen::NamedTypeConstraint *>()) {
auto &operand = op.getOperand(numOperands);
- os << (operand.constraint.isVariadic() ? ", ArrayRef<Value *> "
- : ", Value *")
- << getArgumentName(op, numOperands);
+ paramList.append(operand.constraint.isVariadic() ? ", ArrayRef<Value *> "
+ : ", Value *");
+ paramList.append(getArgumentName(op, numOperands));
++numOperands;
} else {
// TODO(antiagainst): Support default initializer for attributes
const auto &namedAttr = op.getAttribute(numAttrs);
const auto &attr = namedAttr.attr;
- os << ", ";
+ paramList.append(", ");
if (attr.isOptional())
- os << "/*optional*/";
- os << attr.getStorageType() << ' ' << namedAttr.name;
+ paramList.append("/*optional*/");
+ paramList.append(
+ (attr.getStorageType() + Twine(" ") + namedAttr.name).str());
++numAttrs;
}
}
+
if (numOperands + numAttrs != op.getNumArgs())
return PrintFatalError(
"op arguments must be either operands or attributes");
- os << ") {\n";
+ auto &method =
+ opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
// Push all result types to the result
if (numResults > 0) {
numResults - static_cast<int>(hasVariadicResult);
if (numNonVariadicResults > 0) {
- OUT(4) << "result->addTypes({" << resultNames.front();
+ method.body() << " result->addTypes({" << resultNames.front();
for (int i = 1; i < numNonVariadicResults; ++i) {
- os << ", " << resultNames[i];
+ method.body() << ", " << resultNames[i];
}
- os << "});\n";
+ method.body() << "});\n";
}
if (hasVariadicResult) {
- OUT(4) << "result->addTypes(" << resultNames.back() << ");\n";
+ method.body() << " result->addTypes(" << resultNames.back() << ");\n";
}
} else {
- OUT(4) << "result->addTypes({";
auto resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str();
- os << resultType;
+ method.body() << " result->addTypes({" << resultType;
for (unsigned i = 1; i != numResults; ++i)
- os << resultType;
- os << "});\n\n";
+ method.body() << resultType;
+ method.body() << "});\n\n";
}
}
int numNonVariadicOperands =
numOperands - static_cast<int>(hasVariadicOperand);
if (numNonVariadicOperands > 0) {
- OUT(4) << "result->addOperands({" << getArgumentName(op, 0);
+ method.body() << " result->addOperands({" << getArgumentName(op, 0);
for (int i = 1; i < numNonVariadicOperands; ++i) {
- os << ", " << getArgumentName(op, i);
+ method.body() << ", " << getArgumentName(op, i);
}
- os << "});\n";
+ method.body() << "});\n";
}
if (hasVariadicOperand) {
- OUT(4) << "result->addOperands(" << getArgumentName(op, numOperands - 1)
- << ");\n";
+ method.body() << " result->addOperands("
+ << getArgumentName(op, numOperands - 1) << ");\n";
}
// Push all attributes to the result
if (!namedAttr.attr.isDerivedAttr()) {
bool emitNotNullCheck = namedAttr.attr.isOptional();
if (emitNotNullCheck) {
- OUT(4) << formatv("if ({0}) ", namedAttr.name) << "{\n";
+ method.body() << formatv(" if ({0}) ", namedAttr.name) << "{\n";
}
- OUT(4) << formatv("result->addAttribute(\"{0}\", {1});\n",
- namedAttr.getName(), namedAttr.name);
+ method.body() << formatv(" result->addAttribute(\"{0}\", {1});\n",
+ namedAttr.getName(), namedAttr.name);
if (emitNotNullCheck) {
- OUT(4) << "}\n";
+ method.body() << " }\n";
}
}
}
- OUT(2) << "}\n";
}
-void OpEmitter::emitBuilder() {
- if (hasStringAttribute(def, "builder")) {
- // If a custom builder is given then print that out.
- auto builder = def.getValueAsString("builder");
- if (!builder.empty())
- os << builder << '\n';
+void OpEmitter::genBuilder() {
+ // Handle custom builders if provided.
+ // TODO(antiagainst): Create wrapper class for OpBuilder to hide the native
+ // TableGen API calls here.
+ {
+ auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
+ if (listInit) {
+ for (Init *init : listInit->getValues()) {
+ Record *builderDef = cast<DefInit>(init)->getDef();
+ StringRef params = builderDef->getValueAsString("params");
+ StringRef body = builderDef->getValueAsString("body");
+ bool hasBody = !body.empty();
+
+ auto &method =
+ opClass.newMethod("void", "build", params, OpMethod::MP_Static,
+ /*declOnly=*/!hasBody);
+ if (hasBody)
+ method.body() << body;
+ }
+ }
}
auto numResults = op.getNumResults();
// 1. Stand-alone parameters
- emitStandaloneParamBuilder(/*isAllSameType=*/false);
+ genStandaloneParamBuilder(/*isAllSameType=*/false);
// 2. Aggregated parameters
// Signature
- OUT(2) << "static void build(Builder* builder, OperationState* result, "
- << "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
- "ArrayRef<NamedAttribute> attributes) {\n";
+ const char *const params =
+ "Builder *builder, OperationState *result, ArrayRef<Type> resultTypes, "
+ "ArrayRef<Value *> args, ArrayRef<NamedAttribute> attributes";
+ auto &method =
+ opClass.newMethod("void", "build", params, OpMethod::MP_Static);
// Result types
- OUT(4) << "assert(resultTypes.size()" << (hasVariadicResult ? " >= " : " == ")
- << numNonVariadicResults
- << "u && \"mismatched number of return types\");\n"
- << " result->addTypes(resultTypes);\n";
+ method.body() << " assert(resultTypes.size()"
+ << (hasVariadicResult ? " >= " : " == ")
+ << numNonVariadicResults
+ << "u && \"mismatched number of return types\");\n"
+ << " result->addTypes(resultTypes);\n";
// Operands
- OUT(4) << "assert(args.size()" << (hasVariadicOperand ? " >= " : " == ")
- << numNonVariadicOperands
- << "u && \"mismatched number of parameters\");\n"
- << " result->addOperands(args);\n\n";
+ method.body() << " assert(args.size()"
+ << (hasVariadicOperand ? " >= " : " == ")
+ << numNonVariadicOperands
+ << "u && \"mismatched number of parameters\");\n"
+ << " result->addOperands(args);\n\n";
// Attributes
if (op.getNumAttributes() > 0) {
- OUT(4) << "assert(!attributes.size() && \"no attributes expected\");\n"
- << " }\n";
+ method.body()
+ << " assert(!attributes.size() && \"no attributes expected\");\n";
} else {
- OUT(4) << "assert(attributes.size() >= " << op.getNumAttributes()
- << "u && \"not enough attributes\");\n"
- << " for (const auto& pair : attributes)\n"
- << " result->addAttribute(pair.first, pair.second);\n"
- << " }\n";
+ method.body() << " assert(attributes.size() >= " << op.getNumAttributes()
+ << "u && \"not enough attributes\");\n"
+ << " for (const auto& pair : attributes)\n"
+ << " result->addAttribute(pair.first, pair.second);\n";
}
// 3. Deduced result types
if (!op.hasVariadicResult() && op.hasTrait("SameOperandsAndResultType"))
- emitStandaloneParamBuilder(/*isAllSameType=*/true);
+ genStandaloneParamBuilder(/*isAllSameType=*/true);
}
-void OpEmitter::emitCanonicalizationPatterns() {
+void OpEmitter::genCanonicalizerDecls() {
if (!def.getValueAsBit("hasCanonicalizer"))
return;
- OUT(2) << "static void getCanonicalizationPatterns("
- << "OwningRewritePatternList &results, MLIRContext* context);\n";
+
+ const char *const params =
+ "OwningRewritePatternList &results, MLIRContext *context";
+ opClass.newMethod("void", "getCanonicalizationPatterns", params,
+ OpMethod::MP_Static, /*declOnly=*/true);
}
-void OpEmitter::emitFolders() {
+void OpEmitter::genFolderDecls() {
bool hasSingleResult = op.getNumResults() == 1;
+
if (def.getValueAsBit("hasConstantFolder")) {
if (hasSingleResult) {
- os << " Attribute constantFold(ArrayRef<Attribute> operands,\n"
- " MLIRContext *context) const;\n";
+ const char *const params =
+ "ArrayRef<Attribute> operands, MLIRContext *context";
+ opClass.newMethod("Attribute", "constantFold", params, OpMethod::MP_Const,
+ /*declOnly=*/true);
} else {
- os << " LogicalResult constantFold(ArrayRef<Attribute> operands,\n"
- << " SmallVectorImpl<Attribute> &results,"
- << "\n MLIRContext *context) const;\n";
+ const char *const params =
+ "ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, "
+ "MLIRContext *context";
+ opClass.newMethod("LogicalResult", "constantFold", params,
+ OpMethod::MP_Const, /*declOnly=*/true);
}
}
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
- os << " Value *fold();\n";
+ opClass.newMethod("Value *", "fold", /*params=*/"", OpMethod::MP_None,
+ /*declOnly=*/true);
} else {
- os << " bool fold(SmallVectorImpl<Value *> &results);\n";
+ opClass.newMethod("bool", "fold", "SmallVectorImpl<Value *> &results",
+ OpMethod::MP_None,
+ /*declOnly=*/true);
}
}
}
-void OpEmitter::emitParser() {
+void OpEmitter::genParser() {
if (!hasStringAttribute(def, "parser"))
return;
- os << " static bool parse(OpAsmParser *parser, OperationState *result) {"
- << "\n " << def.getValueAsString("parser") << "\n }\n";
+
+ auto &method = opClass.newMethod(
+ "bool", "parse", "OpAsmParser *parser, OperationState *result",
+ OpMethod::MP_Static);
+ auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
+ method.body() << " " << parser;
}
-void OpEmitter::emitPrinter() {
+void OpEmitter::genPrinter() {
auto valueInit = def.getValueInit("printer");
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
if (!codeInit)
return;
- auto printer = codeInit->getValue();
- os << " void print(OpAsmPrinter *p) const {\n"
- << " " << printer << "\n }\n";
+ auto &method =
+ opClass.newMethod("void", "print", "OpAsmPrinter *p", OpMethod::MP_Const);
+ auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r");
+ method.body() << " " << printer;
}
-void OpEmitter::emitVerifier() {
+void OpEmitter::genVerifier() {
auto valueInit = def.getValueInit("verifier");
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
if (!hasCustomVerify && op.getNumArgs() == 0 && op.getNumResults() == 0)
return;
- OUT(2) << "bool verify() const {\n";
+ auto &method =
+ opClass.newMethod("bool", "verify", /*params=*/"", OpMethod::MP_Const);
+ auto &body = method.body();
+
// Verify the attributes have the correct type.
for (const auto &namedAttr : op.getAttributes()) {
const auto &attr = namedAttr.attr;
auto name = namedAttr.getName();
if (!attr.hasStorageType() && !attr.hasDefaultValue()) {
// TODO: Some verification can be done even without storage type.
- OUT(4) << "if (!this->getAttr(\"" << name
- << "\")) return emitOpError(\"requires attribute '" << name
- << "'\");\n";
+ body << " if (!this->getAttr(\"" << name
+ << "\")) return emitOpError(\"requires attribute '" << name
+ << "'\");\n";
continue;
}
// If the attribute has a default value, then only verify the predicate if
// set. This does effectively assume that the default value is valid.
// TODO: verify the debug value is valid (perhaps in debug mode only).
- OUT(4) << "if (this->getAttr(\"" << name << "\")) {\n";
+ body << " if (this->getAttr(\"" << name << "\")) {\n";
}
- OUT(6) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
- << attr.getStorageType() << ">()) return emitOpError(\"requires "
- << attr.getDescription() << " attribute '" << name << "'\");\n";
+ body << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
+ << attr.getStorageType() << ">()) return emitOpError(\"requires "
+ << attr.getDescription() << " attribute '" << name << "'\");\n";
auto attrPred = attr.getPredicate();
if (!attrPred.isNull()) {
- OUT(6) << formatv("if (!({0})) return emitOpError(\"attribute '{1}' "
- "failed to satisfy {2} attribute constraints\");\n",
- formatv(attrPred.getCondition(),
- formatv("this->getAttr(\"{0}\")", name)),
- name, attr.getDescription());
+ body << formatv(" if (!({0})) return emitOpError(\"attribute '{1}' "
+ "failed to satisfy {2} attribute constraints\");\n",
+ formatv(attrPred.getCondition(),
+ formatv("this->getAttr(\"{0}\")", name)),
+ name, attr.getDescription());
}
if (allowMissingAttr)
- OUT(4) << "}\n";
+ body << " }\n";
}
// Emits verification code for an operand or result.
- auto verifyValue = [this](const tblgen::NamedTypeConstraint &value, int index,
- bool isOperand) -> void {
+ auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index,
+ bool isOperand) -> void {
// TODO: Handle variadic operand/result verification.
if (value.constraint.isVariadic())
return;
// concise code.
if (value.hasPredicate()) {
auto description = value.constraint.getDescription();
- OUT(4) << "if (!("
- << formatv(value.constraint.getConditionTemplate(),
- "this->getInstruction()->get" +
- Twine(isOperand ? "Operand" : "Result") + "(" +
- Twine(index) + ")->getType()")
- << ")) {\n";
- OUT(6) << "return emitOpError(\"" << (isOperand ? "operand" : "result")
- << " #" << index
- << (description.empty() ? " type precondition failed"
- : " must be " + Twine(description))
- << "\");";
- OUT(4) << "}\n";
+ body << " if (!("
+ << formatv(value.constraint.getConditionTemplate(),
+ "this->getInstruction()->get" +
+ Twine(isOperand ? "Operand" : "Result") + "(" +
+ Twine(index) + ")->getType()")
+ << "))\n";
+ body << " return emitOpError(\"" << (isOperand ? "operand" : "result")
+ << " #" << index
+ << (description.empty() ? " type precondition failed"
+ : " must be " + Twine(description))
+ << "\");\n";
}
};
for (auto &trait : op.getTraits()) {
if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
- OUT(4) << "if (!"
- << formatv(t->getPredTemplate().c_str(),
- "(*this->getInstruction())")
- << ")\n";
- OUT(6) << "return emitOpError(\"failed to verify that "
- << t->getDescription() << "\");\n";
+ body << " if (!"
+ << formatv(t->getPredTemplate().c_str(), "(*this->getInstruction())")
+ << ")\n";
+ body << " return emitOpError(\"failed to verify that "
+ << t->getDescription() << "\");\n";
}
}
if (hasCustomVerify)
- OUT(4) << codeInit->getValue() << "\n";
+ body << codeInit->getValue() << "\n";
else
- OUT(4) << "return false;\n";
- OUT(2) << "}\n";
+ body << " return false;\n";
}
-void OpEmitter::emitTraits() {
+void OpEmitter::genTraits() {
auto numResults = op.getNumResults();
bool hasVariadicResult = op.hasVariadicResult();
// Add return size trait.
- os << ", OpTrait::";
if (hasVariadicResult) {
if (numResults == 1)
- os << "VariadicResults";
+ opClass.addTrait("VariadicResults");
else
- os << "AtLeastNResults<" << (numResults - 1) << ">::Impl";
+ opClass.addTrait("AtLeastNResults<" + Twine(numResults - 1) + ">::Impl");
} else {
switch (numResults) {
case 0:
- os << "ZeroResult";
+ opClass.addTrait("ZeroResult");
break;
case 1:
- os << "OneResult";
+ opClass.addTrait("OneResult");
break;
default:
- os << "NResults<" << numResults << ">::Impl";
+ opClass.addTrait("NResults<" + Twine(numResults) + ">::Impl");
break;
}
}
for (const auto &trait : op.getTraits()) {
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
- os << ", OpTrait::" << opTrait->getTrait();
+ opClass.addTrait(opTrait->getTrait());
}
// Add variadic size trait and normal op traits.
bool hasVariadicOperand = op.hasVariadicOperand();
// Add operand size trait.
- os << ", OpTrait::";
if (hasVariadicOperand) {
if (numOperands == 1)
- os << "VariadicOperands";
+ opClass.addTrait("VariadicOperands");
else
- os << "AtLeastNOperands<" << (numOperands - 1) << ">::Impl";
+ opClass.addTrait("AtLeastNOperands<" + Twine(numOperands - 1) +
+ ">::Impl");
} else {
switch (numOperands) {
case 0:
- os << "ZeroOperands";
+ opClass.addTrait("ZeroOperands");
break;
case 1:
- os << "OneOperand";
+ opClass.addTrait("OneOperand");
break;
default:
- os << "NOperands<" << numOperands << ">::Impl";
+ opClass.addTrait("NOperands<" + Twine(numOperands) + ">::Impl");
break;
}
}
}
+void OpEmitter::genOpNameGetter() {
+ auto &method = opClass.newMethod("StringRef", "getOperationName",
+ /*params=*/"", OpMethod::MP_Static);
+ method.body() << " return \"" << op.getOperationName() << "\";\n";
+}
+
// Emits the opcode enum and op classes.
-static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os) {
+static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
+ bool emitDecl) {
IfDefScope scope("GET_OP_CLASSES", os);
- for (auto *def : defs)
- OpEmitter::emit(*def, os);
+ for (auto *def : defs) {
+ if (emitDecl) {
+ os << formatv(opCommentHeader, getOpQualClassName(*def), "declarations");
+ OpEmitter::emitDecl(*def, os);
+ } else {
+ os << formatv(opCommentHeader, getOpQualClassName(*def), "definitions");
+ OpEmitter::emitDef(*def, os);
+ }
+ }
}
// Emits a comma-separated list of the ops.
static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
IfDefScope scope("GET_OP_LIST", os);
- bool first = true;
- for (auto &def : defs) {
- if (!first)
- os << ",";
- os << Operator(def).getQualCppClassName();
- first = false;
- }
+ interleave(
+ defs, [&os](Record *def) { os << getOpQualClassName(*def); },
+ [&os]() { os << ",\n"; });
+}
+
+static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+ emitSourceFileHeader("Op Declarations", os);
+
+ const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+ emitOpClasses(defs, os, /*emitDecl=*/true);
+
+ return false;
}
-static void emitOpDefinitions(const RecordKeeper &recordKeeper,
- raw_ostream &os) {
- emitSourceFileHeader("List of ops", os);
+static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
+ emitSourceFileHeader("Op Definitions", os);
const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
emitOpList(defs, os);
- emitOpClasses(defs, os);
+ emitOpClasses(defs, os, /*emitDecl=*/false);
+
+ return false;
}
static mlir::GenRegistration
- genOpDefinitions("gen-op-definitions", "Generate op definitions",
- [](const RecordKeeper &records, raw_ostream &os) {
- emitOpDefinitions(records, os);
- return false;
- });
+ genOpDecls("gen-op-decls", "Generate op declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitOpDecls(records, os);
+ });
+
+static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
+ [](const RecordKeeper &records,
+ raw_ostream &os) {
+ return emitOpDefs(records, os);
+ });