[mlir][py] Enable building ops with raw inputs
authorJacques Pienaar <jpienaar@google.com>
Wed, 21 Dec 2022 18:10:31 +0000 (10:10 -0800)
committerJacques Pienaar <jpienaar@google.com>
Wed, 21 Dec 2022 18:10:31 +0000 (10:10 -0800)
For cases where we can automatically construct the Attribute allow for more
user-friendly input. This is consistent with C++ builder generation as well
choice of which single builder to generate here (most
specialized/user-friendly).

Registration of attribute builders from more pythonic input is all Python side.
The downside is that
  * extra checking to see if user provided a custom builder in op builders,
  * the ODS attribute name is load bearing
upside is that
  * easily change these/register dialect specific ones in downstream projects,
  * adding support/changing to different convenience builders are all along with
    the rest of the convenience functions in Python (and no additional changes
    to tablegen file or recompilation needed);

Allow for both building with Attributes as well as raw inputs. This change
should therefore be backwards compatible as well as allow for avoiding
recreating Attribute where already available.

Differential Revision: https://reviews.llvm.org/D139568

mlir/docs/Bindings/Python.md
mlir/lib/Bindings/Python/Globals.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.cpp
mlir/python/mlir/ir.py
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/test/python/dialects/shape.py
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

index cdb00dc..a7b2b31 100644 (file)
@@ -743,6 +743,34 @@ with Context():
   dictionary = DictAttr.get({"array": array, "unit": UnitAttr.get()})
 ```
 
+Custom builders for Attributes to be used during Operation creation can be
+registered by way of the `register_attribute_builder`. In particular the
+following is how a custom builder is registered for `I32Attr`:
+
+```python
+@register_attribute_builder("I32Attr")
+def _i32Attr(x: int, context: Context):
+  return IntegerAttr.get(
+        IntegerType.get_signless(32, context=context), x)
+```
+
+This allows to invoke op creation of an op with a `I32Attr` with
+
+```python
+foo.Op(30)
+```
+
+The registration is based on the ODS name but registry is via pure python
+method. Only single custom builder is allowed to be registered per ODS attribute
+type (e.g., I32Attr can have only one, which can correspond to multiple of the
+underlying IntegerAttr type).
+
+instead of
+
+```python
+foo.Op(IntegerAttr.get(IndexType.get_signless(32, context=context), 30))
+```
+
 ## Style
 
 In general, for the core parts of MLIR, the Python bindings should be largely
index 6613d2b..ba6cfb5 100644 (file)
@@ -58,6 +58,12 @@ public:
   /// have a DIALECT_NAMESPACE attribute.
   pybind11::object registerDialectDecorator(pybind11::object pyClass);
 
+  /// Adds a user-friendly Attribute builder.
+  /// Raises an exception if the mapping already exists.
+  /// This is intended to be called by implementation code.
+  void registerAttributeBuilder(const std::string &attributeKind,
+                                pybind11::function pyFunc);
+
   /// Adds a concrete implementation dialect class.
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
@@ -71,6 +77,10 @@ public:
                              pybind11::object pyClass,
                              pybind11::object rawOpViewClass);
 
+  /// Returns the custom Attribute builder for Attribute kind.
+  std::optional<pybind11::function>
+  lookupAttributeBuilder(const std::string &attributeKind);
+
   /// Looks up a registered dialect class by namespace. Note that this may
   /// trigger loading of the defining module and can arbitrarily re-enter.
   llvm::Optional<pybind11::object>
@@ -92,6 +102,8 @@ private:
   /// Map of operation name to custom subclass that directly initializes
   /// the OpView base class (bypassing the user class constructor).
   llvm::StringMap<pybind11::object> rawOpViewClassMap;
+  /// Map of attribute ODS name to custom builder.
+  llvm::StringMap<pybind11::function> attributeBuilderMap;
 
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
index 794be97..f2aa8da 100644 (file)
@@ -194,6 +194,29 @@ struct PyGlobalDebugFlag {
   }
 };
 
+struct PyAttrBuilderMap {
+  static bool dunderContains(const std::string &attributeKind) {
+    return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
+  }
+  static py::function dundeGetItemNamed(const std::string &attributeKind) {
+    auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+    if (!builder)
+      throw py::key_error();
+    return *builder;
+  }
+  static void dundeSetItemNamed(const std::string &attributeKind,
+                                py::function func) {
+    PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func));
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
+        .def_static("contains", &PyAttrBuilderMap::dunderContains)
+        .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
+        .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed);
+  }
+};
+
 //------------------------------------------------------------------------------
 // Collections.
 //------------------------------------------------------------------------------
@@ -3283,4 +3306,7 @@ void mlir::python::populateIRCore(py::module &m) {
 
   // Debug bindings.
   PyGlobalDebugFlag::bind(m);
+
+  // Attribute builder getter.
+  PyAttrBuilderMap::bind(m);
 }
index b6d1df5..be6de5f 100644 (file)
@@ -60,6 +60,17 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
   loadedDialectModulesCache.insert(dialectNamespace);
 }
 
+void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
+                                         py::function pyFunc) {
+  py::function &found = attributeBuilderMap[attributeKind];
+  if (found) {
+    throw std::runtime_error((llvm::Twine("Attribute builder for '") +
+                              attributeKind + "' is already registered")
+                                 .str());
+  }
+  found = std::move(pyFunc);
+}
+
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     py::object pyClass) {
   py::object &found = dialectClassMap[dialectNamespace];
@@ -84,6 +95,22 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
 }
 
+std::optional<py::function>
+PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
+  // Fast match against the class map first (common case).
+  const auto foundIt = attributeBuilderMap.find(attributeKind);
+  if (foundIt != attributeBuilderMap.end()) {
+    if (foundIt->second.is_none())
+      return std::nullopt;
+    assert(foundIt->second && "py::function is defined");
+    return foundIt->second;
+  }
+
+  // Not found and loading did not yield a registration. Negative cache.
+  attributeBuilderMap[attributeKind] = py::none();
+  return std::nullopt;
+}
+
 llvm::Optional<py::object>
 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
   loadDialectModule(dialectNamespace);
index 99e88ff..1998691 100644 (file)
@@ -4,3 +4,44 @@
 
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
+
+
+# Convenience decorator for registering user-friendly Attribute builders.
+def register_attribute_builder(kind):
+  def decorator_builder(func):
+    AttrBuilder.insert(kind, func)
+    return func
+  return decorator_builder
+
+
+@register_attribute_builder("BoolAttr")
+def _boolAttr(x: bool, context: Context):
+  return BoolAttr.get(x, context=context)
+
+@register_attribute_builder("IndexAttr")
+def _indexAttr(x: int, context: Context):
+  return IntegerAttr.get(IndexType.get(context=context), x)
+
+@register_attribute_builder("I32Attr")
+def _i32Attr(x: int, context: Context):
+  return IntegerAttr.get(
+      IntegerType.get_signless(32, context=context), x)
+
+@register_attribute_builder("I64Attr")
+def _i64Attr(x: int, context: Context):
+  return IntegerAttr.get(
+      IntegerType.get_signless(64, context=context), x)
+
+@register_attribute_builder("SymbolNameAttr")
+def _symbolNameAttr(x: str, context: Context):
+  return StringAttr.get(x, context=context)
+
+try:
+  import numpy as np
+  @register_attribute_builder("IndexElementsAttr")
+  def _indexElementsAttr(x: list[int], context: Context):
+    return DenseElementsAttr.get(
+        np.array(x, dtype=np.int64), type=IndexType.get(context=context),
+        context=context)
+except ImportError:
+  pass
index 2dda3db..97fe306 100644 (file)
@@ -115,11 +115,14 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   attributes["i32attr"] = i32attr
-  // CHECK:   if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
+  // CHECK:   attributes["i32attr"] = (i32attr if (
+  // CHECK-NEXT:   issubclass(type(i32attr), _ods_ir.Attribute) or
+  // CHECK-NEXT:   not _ods_ir.AttrBuilder.contains('I32Attr')
+  // CHECK-NEXT:   _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
+  // CHECK:   if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
   // CHECK:   if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
-  // CHECK:   attributes["in"] = in_
+  // CHECK:   attributes["in"] = (in_
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -161,7 +164,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_2))
   // CHECK:   if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
-  // CHECK:   if is_ is not None: attributes["is"] = is_
+  // CHECK:   if is_ is not None: attributes["is"] = (is_
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -188,8 +191,8 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   if arr is not None: attributes["arr"] = arr
-  // CHECK:   if unsupported is not None: attributes["unsupported"] = unsupported
+  // CHECK:   if arr is not None: attributes["arr"] = (arr
+  // CHECK:   if unsupported is not None: attributes["unsupported"] = (unsupported
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -202,7 +205,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
-  // CHECK: def __init__(self, type, *, loc=None, ip=None):
+  // CHECK: def __init__(self, type_, *, loc=None, ip=None):
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   _ods_result_type_source_attr = attributes["type"]
@@ -217,7 +220,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
-  // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None):
+  // CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
   let arguments = (ins TypeAttr:$type);
   let results = (outs AnyType:$res, Variadic<AnyType>);
 }
index 2ebad0d..2d2a203 100644 (file)
@@ -22,9 +22,18 @@ def testConstShape():
       @func.FuncOp.from_py_func(
           RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
       def const_shape_tensor(arg):
+        shape.ConstWitnessOp(False)
+        shape.ConstSizeOp(30)
+        shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
+        shape.ConstShapeOp([1, 2])
         return shape.ConstShapeOp(
-          DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get()))
+            DenseElementsAttr.get(
+                np.array([3, 4], dtype=np.int64), type=IndexType.get()))
 
     # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
-    # CHECK: shape.const_shape [10, 20] : tensor<2xindex>
+    # CHECK-DAG: shape.const_witness false
+    # CHECK-DAG: shape.const_size 30
+    # CHECK-DAG: shape.const_size 40
+    # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
+    # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
     print(module)
index a5ffcc4..1bd98ee 100644 (file)
@@ -280,15 +280,16 @@ static llvm::cl::opt<std::string> clDialectExtensionName(
 
 using AttributeClasses = DenseMap<StringRef, StringRef>;
 
-/// Checks whether `str` is a Python keyword.
-static bool isPythonKeyword(StringRef str) {
-  static llvm::StringSet<> keywords(
-      {"and",   "as",     "assert",   "break", "class",  "continue",
-       "def",   "del",    "elif",     "else",  "except", "finally",
-       "for",   "from",   "global",   "if",    "import", "in",
-       "is",    "lambda", "nonlocal", "not",   "or",     "pass",
-       "raise", "return", "try",      "while", "with",   "yield"});
-  return keywords.contains(str);
+/// Checks whether `str` is a Python keyword or would shadow builtin function.
+static bool isPythonReserved(StringRef str) {
+  static llvm::StringSet<> reserved(
+      {"and",      "as",    "assert", "break",      "callable", "class",
+       "continue", "def",   "del",    "elif",       "else",     "except",
+       "finally",  "for",   "from",   "global",     "if",       "import",
+       "in",       "is",    "lambda", "nonlocal",   "not",      "or",
+       "pass",     "raise", "return", "issubclass", "try",      "type",
+       "while",    "with",  "yield"});
+  return reserved.contains(str);
 }
 
 /// Checks whether `str` would shadow a generated variable or attribute
@@ -306,7 +307,7 @@ static bool isODSReserved(StringRef str) {
 /// (does not change the `name` if it already is suitable) and returns the
 /// modified version.
 static std::string sanitizeName(StringRef name) {
-  if (isPythonKeyword(name) || isODSReserved(name))
+  if (isPythonReserved(name) || isODSReserved(name))
     return (name + "_").str();
   return name.str();
 }
@@ -531,16 +532,30 @@ constexpr const char *multiOperandAppendPackTemplate =
     "operands.append(_get_op_results_or_values({0}))";
 constexpr const char *multiResultAppendTemplate = "results.extend({0})";
 
-/// Template for setting an attribute in the operation builder.
-///   {0} is the attribute name;
-///   {1} is the builder argument name.
-constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
-
-/// Template for setting an optional attribute in the operation builder.
-///   {0} is the attribute name;
-///   {1} is the builder argument name.
-constexpr const char *initOptionalAttributeTemplate =
-    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
+/// Template for attribute builder from raw input in the operation builder.
+///   {0} is the builder argument name;
+///   {1} is the attribute builder from raw;
+///   {2} is the attribute builder from raw.
+/// Use the value the user passed in if either it is already an Attribute or
+/// there is no method registered to make it an Attribute.
+constexpr const char *initAttributeWithBuilderTemplate =
+    R"Py(attributes["{1}"] = ({0} if (
+    issubclass(type({0}), _ods_ir.Attribute) or
+    not _ods_ir.AttrBuilder.contains('{2}')) else
+      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+
+/// Template for attribute builder from raw input for optional attribute in the
+/// operation builder.
+///   {0} is the builder argument name;
+///   {1} is the attribute builder from raw;
+///   {2} is the attribute builder from raw.
+/// Use the value the user passed in if either it is already an Attribute or
+/// there is no method registered to make it an Attribute.
+constexpr const char *initOptionalAttributeWithBuilderTemplate =
+    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
+        issubclass(type({0}), _ods_ir.Attribute) or
+        not _ods_ir.AttrBuilder.contains('{2}')) else
+          _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
 
 constexpr const char *initUnitAttributeTemplate =
     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -656,6 +671,7 @@ static void
 populateBuilderLinesAttr(const Operator &op,
                          llvm::ArrayRef<std::string> argNames,
                          llvm::SmallVectorImpl<std::string> &builderLines) {
+  builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
     Argument arg = op.getArg(i);
     auto *attribute = arg.dyn_cast<NamedAttribute *>();
@@ -670,10 +686,10 @@ populateBuilderLinesAttr(const Operator &op,
     }
 
     builderLines.push_back(llvm::formatv(
-        (attribute->attr.isOptional() || attribute->attr.hasDefaultValue())
-            ? initOptionalAttributeTemplate
-            : initAttributeTemplate,
-        attribute->name, argNames[i]));
+        attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
+            ? initOptionalAttributeWithBuilderTemplate
+            : initAttributeWithBuilderTemplate,
+        argNames[i], attribute->name, attribute->attr.getAttrDefName()));
   }
 }
 
@@ -753,8 +769,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
 /// corresponding interface:
 ///   - {0} is the name of the class for which the types are inferred.
 constexpr const char *inferTypeInterfaceTemplate =
-    R"PY(_ods_context = _ods_get_default_loc_context(loc)
-results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
+    R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
     operands=operands,
     attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
     context=_ods_context,