static_cast<unsigned>(rhs));
}
+inline constexpr mlir::tblgen::Method::Properties &
+operator|=(mlir::tblgen::Method::Properties &lhs,
+ mlir::tblgen::Method::Properties rhs) {
+ return lhs = mlir::tblgen::Method::Properties(static_cast<unsigned>(lhs) |
+ static_cast<unsigned>(rhs));
+}
+
namespace mlir {
namespace tblgen {
/// Write the using declaration.
void writeDeclTo(raw_indented_ostream &os) const override;
+ /// Add a template parameter.
+ template <typename ParamT>
+ void addTemplateParam(ParamT param) {
+ templateParams.insert(stringify(param));
+ }
+
+ /// Add a list of template parameters.
+ template <typename ContainerT>
+ void addTemplateParams(ContainerT &&container) {
+ templateParams.insert(std::begin(container), std::end(container));
+ }
+
private:
/// The name of the declaration, or a resolved name to an inherited function.
std::string name;
/// The type that is being aliased. Leave empty for inheriting functions.
std::string value;
+ /// An optional list of class template parameters.
+ /// This is simply a ordered list of parameter names that are then added as
+ /// template type parameters when the using declaration is emitted.
+ SetVector<std::string, SmallVector<std::string>, StringSet<>> templateParams;
};
/// This class describes a class field.
/// returns a pointer to the new constructor.
template <Method::Properties Properties = Method::None, typename... Args>
Constructor *addConstructor(Args &&...args) {
+ Method::Properties defaultProperties = Method::Constructor;
+ // If the class has template parameters, the constructor has to be defined
+ // inline.
+ if (!templateParams.empty())
+ defaultProperties |= Method::Inline;
return addConstructorAndPrune(Constructor(getClassName(),
- Properties | Method::Constructor,
+ Properties | defaultProperties,
std::forward<Args>(args)...));
}
/// Returns null if the method was not added (because an existing method would
/// make it redundant). Else, returns a pointer to the new method.
template <Method::Properties Properties = Method::None, typename RetTypeT,
+ typename NameT>
+ Method *addMethod(RetTypeT &&retType, NameT &&name,
+ Method::Properties properties,
+ ArrayRef<MethodParameter> parameters) {
+ // If the class has template parameters, the has to defined inline.
+ if (!templateParams.empty())
+ properties |= Method::Inline;
+ return addMethodAndPrune(Method(std::forward<RetTypeT>(retType),
+ std::forward<NameT>(name),
+ Properties | properties, parameters));
+ }
+
+ /// Add a method with statically-known properties.
+ template <Method::Properties Properties = Method::None, typename RetTypeT,
+ typename NameT>
+ Method *addMethod(RetTypeT &&retType, NameT &&name,
+ ArrayRef<MethodParameter> parameters) {
+ return addMethod(std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+ Properties, parameters);
+ }
+
+ template <Method::Properties Properties = Method::None, typename RetTypeT,
typename NameT, typename... Args>
Method *addMethod(RetTypeT &&retType, NameT &&name,
Method::Properties properties, Args &&...args) {
- return addMethodAndPrune(
- Method(std::forward<RetTypeT>(retType), std::forward<NameT>(name),
- Properties | properties, std::forward<Args>(args)...));
+ return addMethod(std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+ properties | Properties, {std::forward<Args>(args)...});
}
/// Add a method with statically-known properties.
/// Add a parent class.
ParentClass &addParent(ParentClass parent);
+ /// Add a template parameter.
+ template <typename ParamT>
+ void addTemplateParam(ParamT param) {
+ templateParams.insert(stringify(param));
+ }
+
+ /// Add a list of template parameters.
+ template <typename ContainerT>
+ void addTemplateParams(ContainerT &&container) {
+ templateParams.insert(std::begin(container), std::end(container));
+ }
+
/// Return the C++ name of the class.
StringRef getClassName() const { return className; }
/// A list of declarations in the class, emitted in order.
std::vector<std::unique_ptr<ClassDeclaration>> declarations;
+
+ /// An optional list of class template parameters.
+ SetVector<std::string, SmallVector<std::string>, StringSet<>> templateParams;
};
} // namespace tblgen
///
/// {0}: The name of the segment attribute.
/// {1}: The index of the main operand.
+/// {2}: The range type of adaptor.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
auto tblgenTmpOperands = getODSOperands({1});
auto sizes = {0}();
- ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups;
+ ::llvm::SmallVector<{2}> tblgenTmpOperandGroups;
for (int i = 0, e = sizes.size(); i < e; ++i) {{
tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i]));
tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]);
// Generates the code to compute the start and end index of an operand or result
// range.
template <typename RangeT>
-static void
-generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
- int numVariadic, int numNonVariadic,
- StringRef rangeSizeCall, bool hasAttrSegmentSize,
- StringRef sizeAttrInit, RangeT &&odsValues) {
+static void generateValueRangeStartAndEnd(
+ Class &opClass, bool isGenericAdaptorBase, StringRef methodName,
+ int numVariadic, int numNonVariadic, StringRef rangeSizeCall,
+ bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) {
+
+ SmallVector<MethodParameter> parameters{MethodParameter("unsigned", "index")};
+ if (isGenericAdaptorBase) {
+ parameters.emplace_back("unsigned", "odsOperandsSize");
+ // The range size is passed per parameter for generic adaptor bases as
+ // using the rangeSizeCall would require the operands, which are not
+ // accessible in the base class.
+ rangeSizeCall = "odsOperandsSize";
+ }
+
auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
- MethodParameter("unsigned", "index"));
+ parameters);
if (!method)
return;
auto &body = method->body();
}
}
-static std::string generateTypeForGetter(bool isAdaptor,
- const NamedTypeConstraint &value) {
+static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
std::string str = "::mlir::Value";
/// If the CPPClassName is not a fully qualified type. Uses of types
/// across Dialect fail because they are not in the correct namespace. So we
/// https://github.com/llvm/llvm-project/issues/57279.
/// Adaptor will have values that are not from the type of their operation and
/// this is expected, so we dont generate TypedValue for Adaptor
- if (!isAdaptor && value.constraint.getCPPClassName() != "::mlir::Type" &&
+ if (value.constraint.getCPPClassName() != "::mlir::Type" &&
StringRef(value.constraint.getCPPClassName()).startswith("::"))
str = llvm::formatv("::mlir::TypedValue<{0}>",
value.constraint.getCPPClassName())
// "{0}" marker in the pattern. Note that the pattern should work for any kind
// of ops, in particular for one-operand ops that may not have the
// `getOperand(unsigned)` method.
-static void generateNamedOperandGetters(const Operator &op, Class &opClass,
- bool isAdaptor, StringRef sizeAttrInit,
- StringRef rangeType,
- StringRef rangeBeginCall,
- StringRef rangeSizeCall,
- StringRef getOperandCallPattern) {
+static void
+generateNamedOperandGetters(const Operator &op, Class &opClass,
+ Class *genericAdaptorBase, StringRef sizeAttrInit,
+ StringRef rangeType, StringRef rangeElementType,
+ StringRef rangeBeginCall, StringRef rangeSizeCall,
+ StringRef getOperandCallPattern) {
const int numOperands = op.getNumOperands();
const int numVariadicOperands = op.getNumVariableLengthOperands();
const int numNormalOperands = numOperands - numVariadicOperands;
// First emit a few "sink" getter methods upon which we layer all nicer named
// getter methods.
- generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
- numVariadicOperands, numNormalOperands,
- rangeSizeCall, attrSizedOperands, sizeAttrInit,
- const_cast<Operator &>(op).getOperands());
+ // If generating for an adaptor, the method is put into the non-templated
+ // generic base class, to not require being defined in the header.
+ // Since the operand size can't be determined from the base class however,
+ // it has to be passed as an additional argument. The trampoline below
+ // generates the function with the same signature as the Op in the generic
+ // adaptor.
+ bool isGenericAdaptorBase = genericAdaptorBase != nullptr;
+ generateValueRangeStartAndEnd(
+ /*opClass=*/isGenericAdaptorBase ? *genericAdaptorBase : opClass,
+ isGenericAdaptorBase,
+ /*methodName=*/"getODSOperandIndexAndLength", numVariadicOperands,
+ numNormalOperands, rangeSizeCall, attrSizedOperands, sizeAttrInit,
+ const_cast<Operator &>(op).getOperands());
+ if (isGenericAdaptorBase) {
+ // Generate trampoline for calling 'getODSOperandIndexAndLength' with just
+ // the index. This just calls the implementation in the base class but
+ // passes the operand size as parameter.
+ Method *method = opClass.addMethod("std::pair<unsigned, unsigned>",
+ "getODSOperandIndexAndLength",
+ MethodParameter("unsigned", "index"));
+ ERROR_IF_PRUNED(method, "getODSOperandIndexAndLength", op);
+ MethodBody &body = method->body();
+ body.indent() << formatv(
+ "return Base::getODSOperandIndexAndLength(index, {0});", rangeSizeCall);
+ }
auto *m = opClass.addMethod(rangeType, "getODSOperands",
MethodParameter("unsigned", "index"));
continue;
std::string name = op.getGetterName(operand.name);
if (operand.isOptional()) {
- m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
+ m = opClass.addMethod(isGenericAdaptorBase
+ ? rangeElementType
+ : generateTypeForGetter(operand),
+ name);
ERROR_IF_PRUNED(m, name, op);
- m->body() << " auto operands = getODSOperands(" << i << ");\n"
- << " return operands.empty() ? ::mlir::Value() : "
- "*operands.begin();";
+ m->body().indent() << formatv(
+ "auto operands = getODSOperands({0});\n"
+ "return operands.empty() ? {1}{{} : *operands.begin();",
+ i, rangeElementType);
} else if (operand.isVariadicOfVariadic()) {
std::string segmentAttr = op.getGetterName(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
- if (isAdaptor) {
- m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>", name);
+ if (genericAdaptorBase) {
+ m = opClass.addMethod("::llvm::SmallVector<" + rangeType + ">", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
- segmentAttr, i);
+ segmentAttr, i, rangeType);
continue;
}
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSOperands(" << i << ");";
} else {
- m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
+ m = opClass.addMethod(isGenericAdaptorBase
+ ? rangeElementType
+ : generateTypeForGetter(operand),
+ name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSOperands(" << i << ").begin();";
}
generateNamedOperandGetters(
op, opClass,
- /*isAdaptor=*/false,
+ /*genericAdaptorBase=*/nullptr,
/*sizeAttrInit=*/attrSizeInitCode,
/*rangeType=*/"::mlir::Operation::operand_range",
+ /*rangeElementType=*/"::mlir::Value",
/*rangeBeginCall=*/"getOperation()->operand_begin()",
/*rangeSizeCall=*/"getOperation()->getNumOperands()",
/*getOperandCallPattern=*/"getOperation()->getOperand({0})");
}
generateValueRangeStartAndEnd(
- opClass, "getODSResultIndexAndLength", numVariadicResults,
- numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
- attrSizeInitCode, op.getResults());
+ opClass, /*isGenericAdaptorBase=*/false, "getODSResultIndexAndLength",
+ numVariadicResults, numNormalResults, "getOperation()->getNumResults()",
+ attrSizedResults, attrSizeInitCode, op.getResults());
auto *m =
opClass.addMethod("::mlir::Operation::result_range", "getODSResults",
continue;
std::string name = op.getGetterName(result.name);
if (result.isOptional()) {
- m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result),
- name);
+ m = opClass.addMethod(generateTypeForGetter(result), name);
ERROR_IF_PRUNED(m, name, op);
m->body()
<< " auto results = getODSResults(" << i << ");\n"
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSResults(" << i << ");";
} else {
- m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result),
- name);
+ m = opClass.addMethod(generateTypeForGetter(result), name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSResults(" << i << ").begin();";
}
namespace {
// Helper class to emit Op operand adaptors to an output stream. Operand
-// adaptors are wrappers around ArrayRef<Value> that provide named operand
+// adaptors are wrappers around random access ranges that provide named operand
// getters identical to those defined in the Op.
+// This currently generates 3 classes per Op:
+// * A Base class within the 'detail' namespace, which contains all logic and
+// members independent of the random access range that is indexed into.
+// In other words, it contains all the attribute and region getters.
+// * A templated class named '{OpName}GenericAdaptor' with a template parameter
+// 'RangeT' that is indexed into by the getters to access the operands.
+// It contains all getters to access operands and inherits from the previous
+// class.
+// * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor'
+// with 'mlir::ValueRange' as template parameter. It adds a constructor from
+// an instance of the op type and a verify function.
class OpOperandAdaptorEmitter {
public:
static void
// The operation for which to emit an adaptor.
const Operator &op;
- // The generated adaptor class.
+ // The generated adaptor classes.
+ Class genericAdaptorBase;
+ Class genericAdaptor;
Class adaptor;
// The emitter containing all of the locally emitted verification functions.
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
const Operator &op,
const StaticVerifierFunctionEmitter &staticVerifierEmitter)
- : op(op), adaptor(op.getAdaptorName()),
+ : op(op), genericAdaptorBase(op.getGenericAdaptorName() + "Base"),
+ genericAdaptor(op.getGenericAdaptorName()), adaptor(op.getAdaptorName()),
staticVerifierEmitter(staticVerifierEmitter),
emitHelper(op, /*emitForOp=*/false) {
- adaptor.addField("::mlir::ValueRange", "odsOperands");
- adaptor.addField("::mlir::DictionaryAttr", "odsAttrs");
- adaptor.addField("::mlir::RegionRange", "odsRegions");
- adaptor.addField("::std::optional<::mlir::OperationName>", "odsOpName");
+
+ genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected);
+ genericAdaptorBase.declare<Field>("::mlir::DictionaryAttr", "odsAttrs");
+ genericAdaptorBase.declare<Field>("::mlir::RegionRange", "odsRegions");
+ genericAdaptorBase.declare<Field>("::std::optional<::mlir::OperationName>",
+ "odsOpName");
+
+ genericAdaptor.addTemplateParam("RangeT");
+ genericAdaptor.addField("RangeT", "odsOperands");
+ genericAdaptor.addParent(
+ ParentClass("detail::" + genericAdaptorBase.getClassName()));
+ genericAdaptor.declare<UsingDeclaration>(
+ "ValueT", "::llvm::detail::ValueOfRange<RangeT>");
+ genericAdaptor.declare<UsingDeclaration>(
+ "Base", "detail::" + genericAdaptorBase.getClassName());
const auto *attrSizedOperands =
- op.getTrait("::m::OpTrait::AttrSizedOperandSegments");
+ op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
{
SmallVector<MethodParameter> paramList;
- paramList.emplace_back("::mlir::ValueRange", "values");
paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
attrSizedOperands ? "" : "nullptr");
paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
- auto *constructor = adaptor.addConstructor(std::move(paramList));
-
- constructor->addMemberInitializer("odsOperands", "values");
- constructor->addMemberInitializer("odsAttrs", "attrs");
- constructor->addMemberInitializer("odsRegions", "regions");
+ auto *baseConstructor = genericAdaptorBase.addConstructor(paramList);
+ baseConstructor->addMemberInitializer("odsAttrs", "attrs");
+ baseConstructor->addMemberInitializer("odsRegions", "regions");
- MethodBody &body = constructor->body();
+ MethodBody &body = baseConstructor->body();
body.indent() << "if (odsAttrs)\n";
body.indent() << formatv(
"odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
op.getOperationName());
- }
- {
- auto *constructor =
- adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
- constructor->addMemberInitializer("odsOperands", "op->getOperands()");
- constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
- constructor->addMemberInitializer("odsRegions", "op->getRegions()");
- constructor->addMemberInitializer("odsOpName", "op->getName()");
+ paramList.insert(paramList.begin(), MethodParameter("RangeT", "values"));
+ auto *constructor = genericAdaptor.addConstructor(std::move(paramList));
+ constructor->addMemberInitializer("Base", "attrs, regions");
+ constructor->addMemberInitializer("odsOperands", "values");
}
std::string sizeAttrInit;
sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
emitHelper.getAttr(operandSegmentAttrName));
}
- generateNamedOperandGetters(op, adaptor,
- /*isAdaptor=*/true, sizeAttrInit,
- /*rangeType=*/"::mlir::ValueRange",
+ generateNamedOperandGetters(op, genericAdaptor,
+ /*genericAdaptorBase=*/&genericAdaptorBase,
+ /*sizeAttrInit=*/sizeAttrInit,
+ /*rangeType=*/"RangeT",
+ /*rangeElementType=*/"ValueT",
/*rangeBeginCall=*/"odsOperands.begin()",
/*rangeSizeCall=*/"odsOperands.size()",
/*getOperandCallPattern=*/"odsOperands[{0}]");
// Any invalid overlap for `getOperands` will have been diagnosed before here
// already.
- if (auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands"))
+ if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands"))
m->body() << " return odsOperands;";
FmtContext fctx;
// Generate named accessor with Attribute return type.
auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
Attribute attr) {
- auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr");
+ auto *method =
+ genericAdaptorBase.addMethod(attr.getStorageType(), emitName + "Attr");
ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
auto &body = method->body().indent();
body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n"
};
{
- auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes");
+ auto *m =
+ genericAdaptorBase.addMethod("::mlir::DictionaryAttr", "getAttributes");
ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
m->body() << " return odsAttrs;";
}
continue;
std::string emitName = op.getGetterName(name);
emitAttrWithStorageType(name, emitName, attr);
- emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr);
+ emitAttrGetterWithReturnType(fctx, genericAdaptorBase, op, emitName, attr);
}
unsigned numRegions = op.getNumRegions();
// Generate the accessors for a variadic region.
std::string name = op.getGetterName(region.name);
if (region.isVariadic()) {
- auto *m = adaptor.addMethod("::mlir::RegionRange", name);
+ auto *m = genericAdaptorBase.addMethod("::mlir::RegionRange", name);
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
m->body() << formatv(" return odsRegions.drop_front({0});", i);
continue;
}
- auto *m = adaptor.addMethod("::mlir::Region &", name);
+ auto *m = genericAdaptorBase.addMethod("::mlir::Region &", name);
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
m->body() << formatv(" return *odsRegions[{0}];", i);
}
if (numRegions > 0) {
// Any invalid overlap for `getRegions` will have been diagnosed before here
// already.
- if (auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions"))
+ if (auto *m =
+ genericAdaptorBase.addMethod("::mlir::RegionRange", "getRegions"))
m->body() << " return odsRegions;";
}
+ StringRef genericAdaptorClassName = genericAdaptor.getClassName();
+ adaptor.addParent(ParentClass(genericAdaptorClassName))
+ .addTemplateParam("::mlir::ValueRange");
+ adaptor.declare<VisibilityDeclaration>(Visibility::Public);
+ adaptor.declare<UsingDeclaration>(genericAdaptorClassName +
+ "::" + genericAdaptorClassName);
+ {
+ // Constructor taking the Op as single parameter.
+ auto *constructor =
+ adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
+ constructor->addMemberInitializer(
+ adaptor.getClassName(),
+ "op->getOperands(), op->getAttrDictionary(), op->getRegions()");
+ }
+
// Add verification function.
addVerification();
+
+ genericAdaptorBase.finalize();
+ genericAdaptor.finalize();
adaptor.finalize();
}
const Operator &op,
const StaticVerifierFunctionEmitter &staticVerifierEmitter,
raw_ostream &os) {
- OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os);
+ OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
+ {
+ NamespaceEmitter ns(os, "detail");
+ emitter.genericAdaptorBase.writeDeclTo(os);
+ }
+ emitter.genericAdaptor.writeDeclTo(os);
+ emitter.adaptor.writeDeclTo(os);
}
void OpOperandAdaptorEmitter::emitDef(
const Operator &op,
const StaticVerifierFunctionEmitter &staticVerifierEmitter,
raw_ostream &os) {
- OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os);
+ OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
+ {
+ NamespaceEmitter ns(os, "detail");
+ emitter.genericAdaptorBase.writeDefTo(os);
+ }
+ emitter.genericAdaptor.writeDefTo(os);
+ emitter.adaptor.writeDefTo(os);
}
// Emits the opcode enum and op classes.