EDSC bindings: expose generic Op construction interface
authorAlex Zinenko <zinenko@google.com>
Fri, 1 Mar 2019 14:39:05 +0000 (06:39 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 23:51:32 +0000 (16:51 -0700)
EDSC Expressions can now be used to build arbitrary MLIR operations identified
by their canonical name, i.e. the name obtained from
`OpClass::getOperationName()` for registered operations.  Expose this
functionality to the C API and Python bindings.  This exposes builder-level
interface to Python and avoids the need for experimental Python code to
implement EDSC free function calls for constructing each op type.

This modification required exposing mlir::Attribute to the C API and Python
bindings, which only supports integer attributes for now.

This is step 4/n to making EDSCs more generalizable.

PiperOrigin-RevId: 236306776

mlir/bindings/python/pybind.cpp
mlir/bindings/python/test/test_py2and3.py
mlir/include/mlir-c/Core.h
mlir/include/mlir/IR/Attributes.h
mlir/lib/EDSC/Types.cpp

index eca5d632b48fcb67359dcf1d387a8b433fee8de8..fe80e07146799d953f56aa8911a5f6595c5fde89 100644 (file)
@@ -33,6 +33,7 @@ namespace python {
 
 namespace py = pybind11;
 
+struct PythonAttribute;
 struct PythonBindable;
 struct PythonExpr;
 struct PythonStmt;
@@ -123,6 +124,17 @@ struct PythonMLIRModule {
     return declaration;
   }
 
+  // Create a custom op given its name and arguments.
+  PythonExpr op(const std::string &name, PythonType type,
+                const py::list &arguments, const py::list &successors,
+                py::kwargs attributes);
+
+  // Create an integer attribute.
+  PythonAttribute integerAttr(PythonType type, int64_t value);
+
+  // Create a boolean attribute.
+  PythonAttribute boolAttr(bool value);
+
   void compile() {
     auto created = mlir::ExecutionEngine::create(module.get());
     llvm::handleAllErrors(created.takeError(),
@@ -216,6 +228,26 @@ struct PythonBlock {
   edsc_block_t blk;
 };
 
+struct PythonAttribute {
+  PythonAttribute() : attr(nullptr) {}
+  PythonAttribute(const mlir_attr_t &a) : attr(a) {}
+  PythonAttribute(const PythonAttribute &other) = default;
+  operator mlir_attr_t() { return attr; }
+
+  std::string str() {
+    if (!attr)
+      return "##null attr##";
+
+    std::string res;
+    llvm::raw_string_ostream os(res);
+    Attribute::getFromOpaquePointer(reinterpret_cast<const void *>(attr))
+        .print(os);
+    return res;
+  }
+
+  mlir_attr_t attr;
+};
+
 struct PythonIndexed : public edsc_indexed_t {
   PythonIndexed(PythonExpr e) : edsc_indexed_t{makeIndexed(e)} {}
   PythonIndexed(PythonBindable b) : edsc_indexed_t{makeIndexed(b)} {}
@@ -273,28 +305,33 @@ private:
   edsc_mlir_emitter_t c_emitter;
 };
 
+template <typename ListTy, typename PythonTy, typename Ty>
+ListTy makeCList(SmallVectorImpl<Ty> &owning, const py::list &list) {
+  for (auto &inp : list) {
+    owning.push_back(Ty{inp.cast<PythonTy>()});
+  }
+  return ListTy{owning.data(), owning.size()};
+}
+
 static edsc_stmt_list_t makeCStmts(llvm::SmallVectorImpl<edsc_stmt_t> &owning,
                                    const py::list &stmts) {
-  for (auto &inp : stmts) {
-    owning.push_back(edsc_stmt_t{inp.cast<PythonStmt>()});
-  }
-  return edsc_stmt_list_t{owning.data(), owning.size()};
+  return makeCList<edsc_stmt_list_t, PythonStmt>(owning, stmts);
 }
 
 static edsc_expr_list_t makeCExprs(llvm::SmallVectorImpl<edsc_expr_t> &owning,
                                    const py::list &exprs) {
-  for (auto &inp : exprs) {
-    owning.push_back(edsc_expr_t{inp.cast<PythonExpr>()});
-  }
-  return edsc_expr_list_t{owning.data(), owning.size()};
+  return makeCList<edsc_expr_list_t, PythonExpr>(owning, exprs);
 }
 
 static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl<mlir_type_t> &owning,
                                    const py::list &types) {
-  for (auto &inp : types) {
-    owning.push_back(mlir_type_t{inp.cast<PythonType>()});
-  }
-  return mlir_type_list_t{owning.data(), owning.size()};
+  return makeCList<mlir_type_list_t, PythonType>(owning, types);
+}
+
+static edsc_block_list_t
+makeCBlocks(llvm::SmallVectorImpl<edsc_block_t> &owning,
+            const py::list &blocks) {
+  return makeCList<edsc_block_list_t, PythonBlock>(owning, blocks);
 }
 
 PythonExpr::PythonExpr(const PythonBindable &bindable) : expr{bindable.expr} {}
@@ -390,6 +427,37 @@ void MLIRFunctionEmitter::emitBlockBody(PythonBlock block) {
   emitter.emitStmts(StmtBlock(block).getBody());
 }
 
+PythonExpr PythonMLIRModule::op(const std::string &name, PythonType type,
+                                const py::list &arguments,
+                                const py::list &successors,
+                                py::kwargs attributes) {
+  SmallVector<edsc_expr_t, 8> owningExprs;
+  SmallVector<edsc_block_t, 4> owningBlocks;
+  SmallVector<mlir_named_attr_t, 4> owningAttrs;
+  SmallVector<std::string, 4> owningAttrNames;
+
+  owningAttrs.reserve(attributes.size());
+  owningAttrNames.reserve(attributes.size());
+  for (const auto &kvp : attributes) {
+    owningAttrNames.push_back(kvp.first.str());
+    auto value = kvp.second.cast<PythonAttribute>();
+    owningAttrs.push_back({owningAttrNames.back().c_str(), value});
+  }
+
+  return PythonExpr(::Op(mlir_context_t(&mlirContext), name.c_str(), type,
+                         makeCExprs(owningExprs, arguments),
+                         makeCBlocks(owningBlocks, successors),
+                         {owningAttrs.data(), owningAttrs.size()}));
+}
+
+PythonAttribute PythonMLIRModule::integerAttr(PythonType type, int64_t value) {
+  return PythonAttribute(::makeIntegerAttr(type, value));
+}
+
+PythonAttribute PythonMLIRModule::boolAttr(bool value) {
+  return PythonAttribute(::makeBoolAttr(&mlirContext, value));
+}
+
 PythonBlock PythonBlock::set(const py::list &stmts) {
   SmallVector<edsc_stmt_t, 8> owning;
   ::BlockSetBody(blk, makeCStmts(owning, stmts));
@@ -548,6 +616,11 @@ PYBIND11_MODULE(pybind, m) {
       .def("set", &PythonBlock::set)
       .def("__str__", &PythonBlock::str);
 
+  py::class_<PythonAttribute>(m, "Attribute",
+                              "Wrapping class for mlir::Attribute")
+      .def(py::init<PythonAttribute>())
+      .def("__str__", &PythonAttribute::str);
+
   py::class_<PythonType>(m, "Type", "Wrapping class for mlir::Type.")
       .def(py::init<PythonType>())
       .def("__str__", &PythonType::str);
@@ -565,6 +638,15 @@ PYBIND11_MODULE(pybind, m) {
       "directly require integration with a tensor library (e.g. numpy). This "
       "is left as the prerogative of libraries and frameworks for now.")
       .def(py::init<>())
+      .def("op", &PythonMLIRModule::op, py::arg("name"), py::arg("type"),
+           py::arg("arguments"), py::arg("successors") = py::list(),
+           "Creates a new expression identified by its canonical name.")
+      .def("boolAttr", &PythonMLIRModule::boolAttr,
+           "Creates an mlir::BoolAttr with the given value")
+      .def(
+          "integerAttr", &PythonMLIRModule::integerAttr,
+          "Creates an mlir::IntegerAttr of the given type with the given value "
+          "in the context associated with this MLIR module.")
       .def("declare_function", &PythonMLIRModule::declareFunction,
            "Declares a new mlir::Function in the current mlir::Module.  The "
            "function has no definition and can be linked to an external "
index 4b78402f93009e49a57f4fcd068801ab0d6c7193..7cee2191a362fb1c07448a5f704c4cbb8fb2427c 100644 (file)
@@ -30,6 +30,17 @@ class EdscTest(unittest.TestCase):
       str = expr.__str__()
       self.assertIn("($1 * ($2 + $3))", str)
 
+  def testCustomOp(self):
+    with E.ContextManager():
+      a, b = (E.Expr(E.Bindable(self.i32Type)) for _ in range(2))
+      c1 = self.module.op(
+          "constant",
+          self.i32Type, [],
+          value=self.module.integerAttr(self.i32Type, 42))
+      expr = self.module.op("addi", self.i32Type, [c1, b])
+      str = expr.__str__()
+      self.assertIn("addi(42, $2)", str)
+
   def testOneLoop(self):
     with E.ContextManager():
       i, lb, ub, step = list(
@@ -319,6 +330,39 @@ class EdscTest(unittest.TestCase):
       self.module.compile()
       self.assertNotEqual(self.module.get_engine_address(), 0)
 
+  def testCustomOpEmission(self):
+    f = self.module.make_function("fooer", [self.i32Type, self.i32Type], [])
+    with E.ContextManager():
+      emitter = E.MLIRFunctionEmitter(f)
+      funcArg1, funcArg2 = emitter.bind_function_arguments()
+      boolAttr = self.module.boolAttr(True)
+      expr = self.module.op(
+          "foo", self.i32Type, [funcArg1, funcArg2], attr=boolAttr)
+      block = E.Block([E.Stmt(expr), E.Return()])
+      emitter.emit_inplace(block)
+
+      code = str(f)
+      self.assertIn('%0 = "foo"(%arg0, %arg1) {attr: true} : (i32, i32) -> i32',
+                    code)
+
+  # Create 'addi' using the generic Op interface.  We need an operation known
+  # to the execution engine so that the engine can compile it.
+  def testCustomOpCompilation(self):
+    f = self.module.make_function("adder", [self.i32Type], [])
+    with E.ContextManager():
+      emitter = E.MLIRFunctionEmitter(f)
+      funcArg, = emitter.bind_function_arguments()
+      c1 = self.module.op(
+          "constant",
+          self.i32Type, [],
+          value=self.module.integerAttr(self.i32Type, 42))
+      expr = self.module.op("addi", self.i32Type, [c1, funcArg])
+      block = E.Block([E.Stmt(expr), E.Return()])
+      emitter.emit_inplace(block)
+      self.module.compile()
+      self.assertNotEqual(self.module.get_engine_address(), 0)
+
+
   def testMLIREmission(self):
     shape = [3, 4, 5]
     m = self.module.make_memref_type(self.f32Type, shape)
index 56d89087a4ed51c2104436767b5b6cd87e5146aa..097a6dfc1b46fb5328e0716a2b55ba56875e0340 100644 (file)
@@ -36,6 +36,8 @@ typedef void *mlir_context_t;
 typedef const void *mlir_type_t;
 /// Opaque C type for mlir::Function*.
 typedef void *mlir_func_t;
+/// Opaque C type for mlir::Attribute.
+typedef const void *mlir_attr_t;
 /// Opaque C type for mlir::edsc::MLIREmiter.
 typedef void *edsc_mlir_emitter_t;
 /// Opaque C type for mlir::edsc::Expr.
@@ -85,6 +87,21 @@ typedef struct {
   uint64_t n;
 } edsc_indexed_list_t;
 
+typedef struct {
+  edsc_block_t *list;
+  uint64_t n;
+} edsc_block_list_t;
+
+typedef struct {
+  const char *name;
+  mlir_attr_t value;
+} mlir_named_attr_t;
+
+typedef struct {
+  mlir_named_attr_t *list;
+  uint64_t n;
+} mlir_named_attr_list_t;
+
 /// Minimal C API for exposing EDSCs to Swift, Python and other languages.
 
 /// Returns a simple scalar mlir::Type using the following convention:
@@ -113,6 +130,13 @@ mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
 /// Returns an `mlir::IndexType`.
 mlir_type_t makeIndexType(mlir_context_t context);
 
+/// Returns an `mlir::IntegerAttr` of the specified type that contains the given
+/// value.
+mlir_attr_t makeIntegerAttr(mlir_type_t type, int64_t value);
+
+/// Returns an `mlir::BoolAttr` with the given value.
+mlir_attr_t makeBoolAttr(mlir_context_t context, bool value);
+
 /// Returns the arity of `function`.
 unsigned getFunctionArity(mlir_func_t function);
 
@@ -212,6 +236,12 @@ edsc_indexed_t makeIndexed(edsc_expr_t expr);
 ///   - `indexed` must not have been indexed previously.
 edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices);
 
+/// Returns an opaque expression that will emit an abstract operation identified
+/// by its name.
+edsc_expr_t Op(mlir_context_t context, const char *name, mlir_type_t resultType,
+               edsc_expr_list_t arguments, edsc_block_list_t successors,
+               mlir_named_attr_list_t attrs);
+
 /// Returns an opaque expression that will emit an mlir::LoadOp.
 edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices);
 
index f51af2123136cc3f93accd395f203fc3834b97c0..e0c1a5889b973de480e97303845e7acfb51328a7 100644 (file)
@@ -128,6 +128,14 @@ public:
   void print(raw_ostream &os) const;
   void dump() const;
 
+  /// Get an opaque pointer to the attribute.
+  const void *getAsOpaquePointer() const { return attr; }
+  /// Construct an attribute from the opaque pointer representation.
+  static Attribute getFromOpaquePointer(const void *ptr) {
+    return Attribute(
+        const_cast<ImplType *>(reinterpret_cast<const ImplType *>(ptr)));
+  }
+
   friend ::llvm::hash_code hash_value(Attribute arg);
 
 protected:
index 29abb212f10a7f56e3d7d31b1839b21fa0771164..14544c8b2887fb3cee096c17d5dd0192ef1a0b8a 100644 (file)
@@ -408,6 +408,20 @@ llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n, Type type) {
   return res;
 }
 
+template <typename Target, size_t N, typename Source>
+SmallVector<Target, N> convertCList(Source list) {
+  SmallVector<Target, N> result;
+  result.reserve(list.n);
+  for (unsigned i = 0; i < list.n; ++i) {
+    result.push_back(Target(list.list[i]));
+  }
+  return result;
+}
+
+SmallVector<StmtBlock, 4> makeBlocks(edsc_block_list_t list) {
+  return convertCList<StmtBlock, 4>(list);
+}
+
 static llvm::SmallVector<Expr, 8> makeExprs(edsc_expr_list_t exprList) {
   llvm::SmallVector<Expr, 8> exprs;
   exprs.reserve(exprList.n);
@@ -425,6 +439,28 @@ static void fillStmts(edsc_stmt_list_t enclosedStmts,
   }
 }
 
+edsc_expr_t Op(mlir_context_t context, const char *name, mlir_type_t resultType,
+               edsc_expr_list_t arguments, edsc_block_list_t successors,
+               mlir_named_attr_list_t attrs) {
+  mlir::MLIRContext *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
+
+  auto blocks = makeBlocks(successors);
+
+  SmallVector<NamedAttribute, 4> attributes;
+  attributes.reserve(attrs.n);
+  for (int i = 0; i < attrs.n; ++i) {
+    auto attribute = Attribute::getFromOpaquePointer(
+        reinterpret_cast<const void *>(attrs.list[i].value));
+    auto name = Identifier::get(attrs.list[i].name, ctx);
+    attributes.emplace_back(name, attribute);
+  }
+
+  return VariadicExpr(
+      name, makeExprs(arguments),
+      Type::getFromOpaquePointer(reinterpret_cast<const void *>(resultType)),
+      attributes, blocks);
+}
+
 Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
   return VariadicExpr::make<AllocOp>(sizes, memrefType);
 }
@@ -880,6 +916,7 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
   // Special case for integer constants that are printed as is.  Use
   // sign-extended result for everything but i1 (booleans).
   if (this->is_op<ConstantIndexOp>() || this->is_op<ConstantIntOp>()) {
+    assert(getAttribute("value"));
     APInt value = getAttribute("value").cast<IntegerAttr>().getValue();
     if (value.getBitWidth() == 1)
       os << value.getZExtValue();
@@ -1328,6 +1365,18 @@ mlir_type_t makeIndexType(mlir_context_t context) {
   return mlir_type_t{type.getAsOpaquePointer()};
 }
 
+mlir_attr_t makeIntegerAttr(mlir_type_t type, int64_t value) {
+  auto ty = Type::getFromOpaquePointer(reinterpret_cast<const void *>(type));
+  auto attr = IntegerAttr::get(ty, value);
+  return mlir_attr_t{attr.getAsOpaquePointer()};
+}
+
+mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) {
+  auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
+  auto attr = BoolAttr::get(value, ctx);
+  return mlir_attr_t{attr.getAsOpaquePointer()};
+}
+
 unsigned getFunctionArity(mlir_func_t function) {
   auto *f = reinterpret_cast<mlir::Function *>(function);
   return f->getNumArguments();