Python bindings: provide context managers for the Blocks
authorAlex Zinenko <zinenko@google.com>
Tue, 12 Mar 2019 13:55:03 +0000 (06:55 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:13:57 +0000 (17:13 -0700)
Expose EDSC block builders as Python context managers, similarly to loop
builders.  Note that blocks, unlike loops, are addressable and may need to be
"declared" without necessarily filling their bodies with instructions.  This is
the case, for example, when branching to a new block from the existing block.
Therefore, creating the block context manager immediately creates the block
(unless the manager captures an existing block) by creating and destroying the
block builder.  With this approach, one can either fill in the block and refer
to it later leveraging Python's dynamic variable lookup

    with BlockContext([indexType]) as b:
      op(...)  # operation inside the block
      ret()
    op(...)  # operation outside the block (in the function entry block)
    br(b, [...])    # branching to the block created above

or declare the block contexts upfront and enter them on demand

    bb1 = BlockContext()  # empty block created in the surrounding function
    bb2 = BlockContext()  # context
    cond_br(bb1.handle, [], bb2.handle, [])  # branch to blocks from here
    with bb1:
      op(...)  # operation inside the first block
    with bb2:
      op(...)  # operation inside the second block
    with bb1:
      op(...)  # append operation to the first block

Additionally, one can create multiple throw-away contexts that append to the
same block

    with BlockContext() as b:
      op(...)  # operation inside the block
    with BlockContext(appendTo(b)):
      op(...)  # new context appends to the block

which has a potential of being extended to control the insertion point of the
block at a finer level of granularity.

PiperOrigin-RevId: 238005298

mlir/bindings/python/pybind.cpp
mlir/bindings/python/test/test_py2and3.py

index 946f98a..8f4ec6f 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "mlir-c/Core.h"
 #include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
 #include "mlir/EDSC/MLIREmitter.h"
 #include "mlir/EDSC/Types.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
@@ -223,6 +224,25 @@ struct PythonValueHandle {
   mlir::edsc::ValueHandle value;
 };
 
+struct PythonBlockHandle {
+  PythonBlockHandle() : value(nullptr) {}
+  PythonBlockHandle(const PythonBlockHandle &other) = default;
+  PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {}
+  operator mlir::edsc::BlockHandle() const { return value; }
+
+  PythonValueHandle arg(int index) { return arguments[index]; }
+
+  std::string str() {
+    std::string s;
+    llvm::raw_string_ostream os(s);
+    value.getBlock()->print(os);
+    return os.str();
+  }
+
+  mlir::edsc::BlockHandle value;
+  std::vector<mlir::edsc::ValueHandle> arguments;
+};
+
 struct PythonLoopContext {
   PythonLoopContext(PythonValueHandle lb, PythonValueHandle ub, int64_t step)
       : lb(lb), ub(ub), step(step) {}
@@ -294,6 +314,76 @@ struct PythonLoopNestContext {
   LoopNestBuilder *builder = nullptr;
 };
 
+struct PythonBlockAppender {
+  PythonBlockAppender(const PythonBlockHandle &handle) : handle(handle) {}
+  PythonBlockHandle handle;
+};
+
+struct PythonBlockContext {
+public:
+  PythonBlockContext() {
+    createBlockBuilder();
+    clearBuilder();
+  }
+  PythonBlockContext(const std::vector<PythonType> &argTypes) {
+    handle.arguments.reserve(argTypes.size());
+    for (const auto &t : argTypes) {
+      auto type =
+          Type::getFromOpaquePointer(reinterpret_cast<const void *>(t.type));
+      handle.arguments.emplace_back(type);
+    }
+    createBlockBuilder();
+    clearBuilder();
+  }
+  PythonBlockContext(const PythonBlockAppender &a) : handle(a.handle) {}
+  PythonBlockContext(const PythonBlockContext &) = delete;
+  PythonBlockContext(PythonBlockContext &&) = default;
+  PythonBlockContext &operator=(const PythonBlockContext &) = delete;
+  PythonBlockContext &operator=(PythonBlockContext &&) = default;
+  ~PythonBlockContext() {
+    assert(!builder && "did not exit from the block context");
+  }
+
+  // EDSC maintain an implicit stack of builders (mostly for keeping track of
+  // insretion points); every operation gets inserted using the top-of-the-stack
+  // builder.  Creating a new EDSC Builder automatically puts it on the stack,
+  // effectively entering the block for it.
+  void createBlockBuilder() {
+    if (handle.value.getBlock()) {
+      builder = new BlockBuilder(handle.value, mlir::edsc::Append());
+    } else {
+      std::vector<ValueHandle *> args;
+      args.reserve(handle.arguments.size());
+      for (auto &a : handle.arguments)
+        args.push_back(&a);
+      builder = new BlockBuilder(&handle.value, args);
+    }
+  }
+
+  PythonBlockHandle enter() {
+    createBlockBuilder();
+    return handle;
+  }
+
+  void exit(py::object, py::object, py::object) { clearBuilder(); }
+
+  PythonBlockHandle getHandle() { return handle; }
+
+  // EDSC maintain an implicit stack of builders (mostly for keeping track of
+  // insretion points); every operation gets inserted using the top-of-the-stack
+  // builder.  Calling operator() on a builder pops the builder from the stack,
+  // effectively resetting the insertion point to its position before we entered
+  // the block.
+  void clearBuilder() {
+    (*builder)({}); // exit from the builder's scope.
+    delete builder;
+    builder = nullptr;
+  }
+
+  PythonBlockHandle handle;
+  BlockBuilder *builder = nullptr;
+};
+
 struct PythonBindable : public PythonExpr {
   explicit PythonBindable(const PythonType &type)
       : PythonExpr(edsc_expr_t{makeBindable(type.type)}) {}
@@ -333,17 +423,6 @@ struct PythonBlock {
   edsc_block_t blk;
 };
 
-struct PythonBlockHandle {
-  PythonBlockHandle() : value(nullptr) {}
-  PythonBlockHandle(const PythonBlockHandle &other) = default;
-  PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {}
-  operator mlir::edsc::BlockHandle() const { return value; }
-
-  std::string str() const { return "^block"; }
-
-  mlir::edsc::BlockHandle value;
-};
-
 struct PythonAttribute {
   PythonAttribute() : attr(nullptr) {}
   PythonAttribute(const mlir_attr_t &a) : attr(a) {}
@@ -785,6 +864,26 @@ PYBIND11_MODULE(pybind, m) {
   m.def("IdxCst", [](int64_t val) -> PythonValueHandle {
     return ValueHandle(index_t(val));
   });
+  m.def("appendTo", [](const PythonBlockHandle &handle) {
+    return PythonBlockAppender(handle);
+  });
+  m.def(
+      "ret",
+      [](const std::vector<PythonValueHandle> &args) {
+        std::vector<ValueHandle> values(args.begin(), args.end());
+        intrinsics::RETURN(values);
+        return PythonValueHandle(nullptr);
+      },
+      py::arg("args") = std::vector<PythonValueHandle>());
+  m.def(
+      "br",
+      [](const PythonBlockHandle &dest,
+         const std::vector<PythonValueHandle> &args) {
+        std::vector<ValueHandle> values(args.begin(), args.end());
+        intrinsics::BR(dest, values);
+        return PythonValueHandle(nullptr);
+      },
+      py::arg("dest"), py::arg("args") = std::vector<PythonValueHandle>());
 
   m.def("Max", [](const py::list &args) {
     SmallVector<edsc_expr_t, 8> owning;
@@ -986,6 +1085,25 @@ PYBIND11_MODULE(pybind, m) {
                  -> PythonValueHandle { return lhs.value % rhs.value; });
   }
 
+  py::class_<PythonBlockAppender>(
+      m, "BlockAppender",
+      "A dummy class signaling BlockContext to append IR to the given block "
+      "instead of creating a new block")
+      .def(py::init<const PythonBlockHandle &>());
+  py::class_<PythonBlockHandle>(m, "BlockHandle",
+                                "A wrapper around mlir::edsc::BlockHandle")
+      .def(py::init<PythonBlockHandle>())
+      .def("arg", &PythonBlockHandle::arg);
+
+  py::class_<PythonBlockContext>(m, "BlockContext",
+                                 "A wrapper around mlir::edsc::BlockBuilder")
+      .def(py::init<>())
+      .def(py::init<const std::vector<PythonType> &>())
+      .def(py::init<const PythonBlockAppender &>())
+      .def("__enter__", &PythonBlockContext::enter)
+      .def("__exit__", &PythonBlockContext::exit)
+      .def("handle", &PythonBlockContext::getHandle);
+
   py::class_<MLIRFunctionEmitter>(
       m, "MLIRFunctionEmitter",
       "An MLIRFunctionEmitter is used to fill an empty function body. This is "
index b760cf6..4c7f729 100644 (file)
@@ -69,6 +69,157 @@ class EdscTest(unittest.TestCase):
         '          %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index',
         code)
 
+  def testBlockContext(self):
+    fun = self.module.make_function("foo", [], [])
+    with E.FunctionContext(fun):
+      cst = E.IdxCst(42)
+      with E.BlockContext():
+        cst + cst
+    code = str(fun)
+    # Find positions of instructions and make sure they are in the block we
+    # put them by comparing those positions.
+    # TODO(zinenko,ntv): this (and tests below) should use FileCheck instead.
+    c42pos = code.find("%c42 = constant 42 : index")
+    bb1pos = code.find("^bb1:")
+    c84pos = code.find('%0 = "affine.apply"() {map: () -> (84)} : () -> index')
+    self.assertNotEqual(c42pos, -1)
+    self.assertNotEqual(bb1pos, -1)
+    self.assertNotEqual(c84pos, -1)
+    self.assertGreater(bb1pos, c42pos)
+    self.assertLess(bb1pos, c84pos)
+
+  def testBlockContextAppend(self):
+    fun = self.module.make_function("foo", [], [])
+    with E.FunctionContext(fun):
+      E.IdxCst(41)
+      with E.BlockContext() as b:
+        blk = b  # save block handle for later
+        E.IdxCst(0)
+      E.IdxCst(42)
+      with E.BlockContext(E.appendTo(blk)):
+        E.IdxCst(1)
+    code = str(fun)
+    # Find positions of instructions and make sure they are in the block we put
+    # them by comparing those positions.
+    c41pos = code.find("%c41 = constant 41 : index")
+    c42pos = code.find("%c42 = constant 42 : index")
+    bb1pos = code.find("^bb1:")
+    c0pos = code.find("%c0 = constant 0 : index")
+    c1pos = code.find("%c1 = constant 1 : index")
+    self.assertNotEqual(c41pos, -1)
+    self.assertNotEqual(c42pos, -1)
+    self.assertNotEqual(bb1pos, -1)
+    self.assertNotEqual(c0pos, -1)
+    self.assertNotEqual(c1pos, -1)
+    self.assertGreater(bb1pos, c41pos)
+    self.assertGreater(bb1pos, c42pos)
+    self.assertLess(bb1pos, c0pos)
+    self.assertLess(bb1pos, c1pos)
+
+  def testBlockContextStandalone(self):
+    fun = self.module.make_function("foo", [], [])
+    with E.FunctionContext(fun):
+      blk1 = E.BlockContext()
+      blk2 = E.BlockContext()
+      with blk1:
+        E.IdxCst(0)
+      with blk2:
+        E.IdxCst(56)
+        E.IdxCst(57)
+      E.IdxCst(41)
+      with blk1:
+        E.IdxCst(1)
+      E.IdxCst(42)
+    code = str(fun)
+    # Find positions of instructions and make sure they are in the block we put
+    # them by comparing those positions.
+    c41pos = code.find("  %c41 = constant 41 : index")
+    c42pos = code.find("  %c42 = constant 42 : index")
+    bb1pos = code.find("^bb1:")
+    c0pos = code.find("  %c0 = constant 0 : index")
+    c1pos = code.find("  %c1 = constant 1 : index")
+    bb2pos = code.find("^bb2:")
+    c56pos = code.find("  %c56 = constant 56 : index")
+    c57pos = code.find("  %c57 = constant 57 : index")
+    self.assertNotEqual(c41pos, -1)
+    self.assertNotEqual(c42pos, -1)
+    self.assertNotEqual(bb1pos, -1)
+    self.assertNotEqual(c0pos, -1)
+    self.assertNotEqual(c1pos, -1)
+    self.assertNotEqual(bb2pos, -1)
+    self.assertNotEqual(c56pos, -1)
+    self.assertNotEqual(c57pos, -1)
+    self.assertGreater(bb1pos, c41pos)
+    self.assertGreater(bb1pos, c42pos)
+    self.assertLess(bb1pos, c0pos)
+    self.assertLess(bb1pos, c1pos)
+    self.assertGreater(bb2pos, c0pos)
+    self.assertGreater(bb2pos, c1pos)
+    self.assertGreater(bb2pos, bb1pos)
+    self.assertLess(bb2pos, c56pos)
+    self.assertLess(bb2pos, c57pos)
+
+
+  def testBlockArguments(self):
+    fun = self.module.make_function("foo", [], [])
+    with E.FunctionContext(fun):
+      E.IdxCst(42)
+      with E.BlockContext([self.f32Type, self.f32Type]) as b:
+        b.arg(0) + b.arg(1)
+    code = str(fun)
+    self.assertIn("%c42 = constant 42 : index", code)
+    self.assertIn("^bb1(%0: f32, %1: f32):", code)
+    self.assertIn("  %2 = addf %0, %1 : f32", code)
+
+  def testBr(self):
+    fun = self.module.make_function("foo", [], [])
+    with E.FunctionContext(fun):
+      with E.BlockContext() as b:
+        blk = b
+        E.ret()
+      E.br(blk)
+    code = str(fun)
+    self.assertIn("  br ^bb1", code)
+    self.assertIn("^bb1:", code)
+    self.assertIn("  return", code)
+
+  def testBrDeclaration(self):
+    fun = self.module.make_function("foo", [], [])
+    with E.FunctionContext(fun):
+      blk = E.BlockContext()
+      E.br(blk.handle())
+      with blk:
+        E.ret()
+    code = str(fun)
+    self.assertIn("  br ^bb1", code)
+    self.assertIn("^bb1:", code)
+    self.assertIn("  return", code)
+
+  def testBrArgs(self):
+    fun = self.module.make_function("foo", [], [])
+    with E.FunctionContext(fun):
+      # Create an infinite loop.
+      with E.BlockContext([self.indexType, self.indexType]) as b:
+        E.br(b, [b.arg(1), b.arg(0)])
+      E.br(b, [E.IdxCst(0), E.IdxCst(1)])
+    code = str(fun)
+    self.assertIn("  %c0 = constant 0 : index", code)
+    self.assertIn("  %c1 = constant 1 : index", code)
+    self.assertIn("  br ^bb1(%c0, %c1 : index, index)", code)
+    self.assertIn("^bb1(%0: index, %1: index):", code)
+    self.assertIn("  br ^bb1(%1, %0 : index, index)", code)
+
+  def testRet(self):
+    fun = self.module.make_function("foo", [], [self.indexType, self.indexType])
+    with E.FunctionContext(fun):
+      c42 = E.IdxCst(42)
+      c0 = E.IdxCst(0)
+      E.ret([c42, c0])
+    code = str(fun)
+    self.assertIn("  %c42 = constant 42 : index", code)
+    self.assertIn("  %c0 = constant 0 : index", code)
+    self.assertIn("  return %c42, %c0 : index, index", code)
+
 
   def testBindables(self):
     with E.ContextManager():