-// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include %s | FileCheck %s
+// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include -I %S/../../lib/Bindings/Python %s | FileCheck %s
include "mlir/IR/OpBase.td"
+include "Attributes.td"
// CHECK: @_cext.register_dialect
// CHECK: class _Dialect(_ir.Dialect):
Optional<AnyType>:$variadic2);
}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class AttributedOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.attributed_op"
+def AttributedOp : TestOp<"attributed_op"> {
+ // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, loc=None, ip=None):
+ // CHECK: operands = []
+ // CHECK: results = []
+ // CHECK: attributes = {}
+ // CHECK: attributes["i32attr"] = i32attr
+ // CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
+ // CHECK: if bool(unitAttr): attributes["unitAttr"] = _ir.UnitAttr.get(
+ // CHECK: _ir.Location.current.context if loc is None else loc.context)
+ // CHECK: attributes["in"] = in_
+ // CHECK: super().__init__(_ir.Operation.create(
+ // CHECK: "test.attributed_op", attributes=attributes, operands=operands, results=results,
+ // CHECK: loc=loc, ip=ip))
+
+ // CHECK: @property
+ // CHECK: def i32attr(self):
+ // CHECK: return _ir.IntegerAttr(self.operation.attributes["i32attr"])
+
+ // CHECK: @property
+ // CHECK: def optionalF32Attr(self):
+ // CHECK: if "optionalF32Attr" not in self.operation.attributes:
+ // CHECK: return None
+ // CHECK: return _ir.FloatAttr(self.operation.attributes["optionalF32Attr"])
+
+ // CHECK: @property
+ // CHECK: def unitAttr(self):
+ // CHECK: return "unitAttr" in self.operation.attributes
+
+ // CHECK: @property
+ // CHECK: def in_(self):
+ // CHECK: return _ir.IntegerAttr(self.operation.attributes["in"])
+ let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
+ UnitAttr:$unitAttr, I32Attr:$in);
+}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class AttributedOpWithOperands(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
+def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
+ // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None):
+ // CHECK: operands = []
+ // CHECK: results = []
+ // CHECK: attributes = {}
+ // CHECK: operands.append(_gen_arg_0)
+ // CHECK: operands.append(_gen_arg_2)
+ // CHECK: if bool(in_): attributes["in"] = _ir.UnitAttr.get(
+ // CHECK: _ir.Location.current.context if loc is None else loc.context)
+ // CHECK: if is_ is not None: attributes["is"] = is_
+ // CHECK: super().__init__(_ir.Operation.create(
+ // CHECK: "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results,
+ // CHECK: loc=loc, ip=ip))
+
+ // CHECK: @property
+ // CHECK: def in_(self):
+ // CHECK: return "in" in self.operation.attributes
+
+ // CHECK: @property
+ // CHECK: def is_(self):
+ // CHECK: if "is" not in self.operation.attributes:
+ // CHECK: return None
+ // CHECK: return _ir.FloatAttr(self.operation.attributes["is"])
+ let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
+}
+
+
// CHECK: @_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.empty"
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
R"Py([0] if len({0}_range) > 0 else None)Py";
+/// Template for an operation attribute getter:
+/// {0} is the name of the attribute sanitized for Python;
+/// {1} is the Python type of the attribute;
+/// {2} os the original name of the attribute.
+constexpr const char *attributeGetterTemplate = R"Py(
+ @property
+ def {0}(self):
+ return {1}(self.operation.attributes["{2}"])
+)Py";
+
+/// Template for an optional operation attribute getter:
+/// {0} is the name of the attribute sanitized for Python;
+/// {1} is the Python type of the attribute;
+/// {2} is the original name of the attribute.
+constexpr const char *optionalAttributeGetterTemplate = R"Py(
+ @property
+ def {0}(self):
+ if "{2}" not in self.operation.attributes:
+ return None
+ return {1}(self.operation.attributes["{2}"])
+)Py";
+
+/// Template for a accessing a unit operation attribute, returns True of the
+/// unit attribute is present, False otherwise (unit attributes have meaning
+/// by mere presence):
+/// {0} is the name of the attribute sanitized for Python,
+/// {1} is the original name of the attribute.
+constexpr const char *unitAttributeGetterTemplate = R"Py(
+ @property
+ def {0}(self):
+ return "{1}" in self.operation.attributes
+)Py";
+
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
llvm::cl::desc("The dialect to run the generator for"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
+using AttributeClasses = DenseMap<StringRef, StringRef>;
+
/// Checks whether `str` is a Python keyword.
static bool isPythonKeyword(StringRef str) {
static llvm::StringSet<> keywords(
return op.getResult(i);
}
-/// Emits accessor to Op operands.
+/// Emits accessors to Op operands.
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
auto getNumVariadic = [](const Operator &oper) {
return oper.getNumVariableLengthOperands();
getOperand);
}
-/// Emits access or Op results.
+/// Emits accessors Op results.
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
auto getNumVariadic = [](const Operator &oper) {
return oper.getNumVariableLengthResults();
getResult);
}
+/// Emits accessors to Op attributes.
+static void emitAttributeAccessors(const Operator &op,
+ const AttributeClasses &attributeClasses,
+ raw_ostream &os) {
+ for (const auto &namedAttr : op.getAttributes()) {
+ // Skip "derived" attributes because they are just C++ functions that we
+ // don't currently expose.
+ if (namedAttr.attr.isDerivedAttr())
+ continue;
+
+ if (namedAttr.name.empty())
+ continue;
+
+ // Unit attributes are handled specially.
+ if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
+ os << llvm::formatv(unitAttributeGetterTemplate,
+ sanitizeName(namedAttr.name), namedAttr.name);
+ continue;
+ }
+
+ // Other kinds of attributes need a mapping to a Python type.
+ if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
+ continue;
+
+ os << llvm::formatv(
+ namedAttr.attr.isOptional() ? optionalAttributeGetterTemplate
+ : attributeGetterTemplate,
+ sanitizeName(namedAttr.name),
+ attributeClasses.lookup(namedAttr.attr.getStorageType()),
+ namedAttr.name);
+ }
+}
+
/// Template for the default auto-generated builder.
/// {0} is the operation name;
/// {1} is a comma-separated list of builder arguments, including the trailing
constexpr const char *variadicSegmentTemplate =
"{0}_segment_sizes.append(len({1}))";
-/// Populates `builderArgs` with the list of `__init__` arguments that
-/// correspond to either operands or results of `op`, and `builderLines` with
-/// additional lines that are required in the builder. `kind` must be either
-/// "operand" or "result". `unnamedTemplate` is used to generate names for
-/// operands or results that don't have the name in ODS.
+/// Template for setting an attribute in the operation builder.
+/// {0} is the attribute name;
+/// {1} is the builder argument name.
+constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
+
+/// Template for setting an optional attribute in the operation builder.
+/// {0} is the attribute name;
+/// {1} is the builder argument name.
+constexpr const char *initOptionalAttributeTemplate =
+ R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
+
+constexpr const char *initUnitAttributeTemplate =
+ R"Py(if bool({1}): attributes["{0}"] = _ir.UnitAttr.get(
+ _ir.Location.current.context if loc is None else loc.context))Py";
+
+/// Populates `builderArgs` with the Python-compatible names of builder function
+/// arguments, first the results, then the intermixed attributes and operands in
+/// the same order as they appear in the `arguments` field of the op definition.
+/// Additionally, `operandNames` is populated with names of operands in their
+/// order of appearance.
+static void
+populateBuilderArgs(const Operator &op,
+ llvm::SmallVectorImpl<std::string> &builderArgs,
+ llvm::SmallVectorImpl<std::string> &operandNames) {
+ for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+ std::string name = op.getResultName(i).str();
+ if (name.empty())
+ name = llvm::formatv("_gen_res_{0}", i);
+ name = sanitizeName(name);
+ builderArgs.push_back(name);
+ }
+ for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
+ std::string name = op.getArgName(i).str();
+ if (name.empty())
+ name = llvm::formatv("_gen_arg_{0}", i);
+ name = sanitizeName(name);
+ builderArgs.push_back(name);
+ if (!op.getArg(i).is<NamedAttribute *>())
+ operandNames.push_back(name);
+ }
+}
+
+/// Populates `builderLines` with additional lines that are required in the
+/// builder to set up operation attributes. `argNames` is expected to contain
+/// the names of builder arguments that correspond to op arguments, i.e. to the
+/// operands and attributes in the same order as they appear in the `arguments`
+/// field.
+static void
+populateBuilderLinesAttr(const Operator &op,
+ llvm::ArrayRef<std::string> argNames,
+ llvm::SmallVectorImpl<std::string> &builderLines) {
+ for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
+ Argument arg = op.getArg(i);
+ auto *attribute = arg.dyn_cast<NamedAttribute *>();
+ if (!attribute)
+ continue;
+
+ // Unit attributes are handled specially.
+ if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
+ builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
+ attribute->name, argNames[i]));
+ continue;
+ }
+
+ builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
+ ? initOptionalAttributeTemplate
+ : initAttributeTemplate,
+ attribute->name, argNames[i]));
+ }
+}
+
+/// Populates `builderLines` with additional lines that are required in the
+/// builder. `kind` must be either "operand" or "result". `names` contains the
+/// names of init arguments that correspond to the elements.
static void populateBuilderLines(
- const Operator &op, const char *kind, const char *unnamedTemplate,
- llvm::SmallVectorImpl<std::string> &builderArgs,
+ const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
llvm::SmallVectorImpl<std::string> &builderLines,
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
// For each element, find or generate a name.
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
- std::string name = element.name.str();
- if (name.empty())
- name = llvm::formatv(unnamedTemplate, i).str();
- name = sanitizeName(name);
- builderArgs.push_back(name);
+ std::string name = names[i];
// Choose the formatting string based on the element kind.
llvm::StringRef formatString, segmentFormatString;
/// Emits a default builder constructing an operation from the list of its
/// result types, followed by a list of its operands.
static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
- // TODO: support attribute types.
- if (op.getNumNativeAttributes() != 0)
- return;
-
// If we are asked to skip default builders, comply.
if (op.skipDefaultBuilders())
return;
llvm::SmallVector<std::string, 8> builderArgs;
llvm::SmallVector<std::string, 8> builderLines;
- builderArgs.reserve(op.getNumOperands() + op.getNumResults());
- populateBuilderLines(op, "result", "_gen_res_{0}", builderArgs, builderLines,
- getNumResults, getResult);
- populateBuilderLines(op, "operand", "_gen_arg_{0}", builderArgs, builderLines,
+ llvm::SmallVector<std::string, 4> operandArgNames;
+ builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
+ op.getNumNativeAttributes());
+ populateBuilderArgs(op, builderArgs, operandArgNames);
+ populateBuilderLines(
+ op, "result",
+ llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
+ builderLines, getNumResults, getResult);
+ populateBuilderLines(op, "operand", operandArgNames, builderLines,
getNumOperands, getOperand);
+ populateBuilderLinesAttr(
+ op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
+ builderLines);
builderArgs.push_back("loc=None");
builderArgs.push_back("ip=None");
llvm::join(builderLines, "\n "));
}
+static void constructAttributeMapping(const llvm::RecordKeeper &records,
+ AttributeClasses &attributeClasses) {
+ for (const llvm::Record *rec :
+ records.getAllDerivedDefinitions("PythonAttr")) {
+ attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
+ rec->getValueAsString("pythonType").trim());
+ }
+}
+
/// Emits bindings for a specific Op to the given output stream.
-static void emitOpBindings(const Operator &op, raw_ostream &os) {
+static void emitOpBindings(const Operator &op,
+ const AttributeClasses &attributeClasses,
+ raw_ostream &os) {
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
op.getOperationName());
emitDefaultOpBuilder(op, os);
emitOperandAccessors(op, os);
+ emitAttributeAccessors(op, attributeClasses, os);
emitResultAccessors(op, os);
}
if (clDialectName.empty())
llvm::PrintFatalError("dialect name not provided");
+ AttributeClasses attributeClasses;
+ constructAttributeMapping(records, attributeClasses);
+
os << fileHeader;
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);
if (op.getDialectName() == clDialectName.getValue())
- emitOpBindings(op, os);
+ emitOpBindings(op, attributeClasses, os);
}
return false;
}