Python bindings: introduce loop and loop nest contexts
authorAlex Zinenko <zinenko@google.com>
Fri, 8 Mar 2019 15:45:26 +0000 (07:45 -0800)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:06:36 +0000 (17:06 -0700)
Recently, EDSC introduced an eager mode for building IR in different contexts.
Introduce Python bindings support for loop and loop nest contexts of EDSC
builders.  The eager mode is built around the notion of ValueHandle, which is
convenience class for delayed initialization and operator overloads.  Expose
this class and overloads directly.  The model of insertion contexts maps
naturally to Python context manager mechanism, therefore new bindings are
defined bypassing the C APIs.  The bindings now provide three new context
manager classes: FunctionContext, LoopContext and LoopNestContext.  The last
two can be used with the `with`-construct in Python to create loop (nests) and
obtain handles to the loop induction variables seamlessly:

    with LoopContext(lhs, rhs, 1) as i:
      lhs + rhs + i
      with LoopContext(rhs, rhs + rhs, 2) as j:
        x = i + j

Any statement within the Python context will trigger immediate emission of the
corresponding IR constructs into the context owned by the nearest context
manager.

PiperOrigin-RevId: 237447732

mlir/bindings/python/pybind.cpp
mlir/bindings/python/test/test_py2and3.py
mlir/include/mlir/EDSC/Builders.h

index 5270ea0..a2b2959 100644 (file)
@@ -7,6 +7,7 @@
 #include <unordered_map>
 
 #include "third_party/llvm/llvm/projects/google_mlir/include/mlir-c/Core.h"
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Builders.h"
 #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/MLIREmitter.h"
 #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Types.h"
 #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/ExecutionEngine/ExecutionEngine.h"
@@ -177,6 +178,25 @@ struct ContextManager {
   mlir::edsc::ScopedEDSCContext *context;
 };
 
+struct PythonFunctionContext {
+  PythonFunctionContext(PythonFunction f) : function(f) {}
+
+  void enter() {
+    assert(function.function && "function is not set up");
+    assert(context);
+    context = new mlir::edsc::ScopedContext(
+        static_cast<mlir::Function *>(function.function));
+  }
+
+  void exit(py::object, py::object, py::object) {
+    delete context;
+    context = nullptr;
+  }
+
+  PythonFunction function;
+  mlir::edsc::ScopedContext *context;
+};
+
 struct PythonExpr {
   PythonExpr() : expr{nullptr} {}
   PythonExpr(const PythonBindable &bindable);
@@ -189,6 +209,92 @@ struct PythonExpr {
   edsc_expr_t expr;
 };
 
+struct PythonValueHandle {
+  PythonValueHandle(PythonType type)
+      : value(mlir::Type::getFromOpaquePointer(type.type)) {}
+  PythonValueHandle(const PythonValueHandle &other) = default;
+  PythonValueHandle(const mlir::edsc::ValueHandle &other) : value(other) {}
+  operator ValueHandle() const { return value; }
+  operator ValueHandle &() { return value; }
+
+  std::string str() const {
+    return std::to_string(reinterpret_cast<intptr_t>(value.getValue()));
+  }
+
+  mlir::edsc::ValueHandle value;
+};
+
+struct PythonLoopContext {
+  PythonLoopContext(PythonValueHandle lb, PythonValueHandle ub, int64_t step)
+      : lb(lb), ub(ub), step(step) {}
+  PythonLoopContext(const PythonLoopContext &) = delete;
+  PythonLoopContext(PythonLoopContext &&) = default;
+  PythonLoopContext &operator=(const PythonLoopContext &) = delete;
+  PythonLoopContext &operator=(PythonLoopContext &&) = default;
+  ~PythonLoopContext() { assert(!builder && "did not exit from the context"); }
+
+  PythonValueHandle enter() {
+    ValueHandle iv(lb.value.getType());
+    builder = new LoopBuilder(&iv, lb.value, ub.value, step);
+    return iv;
+  }
+
+  void exit(py::object, py::object, py::object) {
+    (*builder)({}); // exit from the builder's scope.
+    delete builder;
+    builder = nullptr;
+  }
+
+  PythonValueHandle lb, ub;
+  int64_t step;
+  LoopBuilder *builder = nullptr;
+};
+
+struct PythonLoopNestContext {
+  PythonLoopNestContext(const std::vector<PythonValueHandle> &lbs,
+                        const std::vector<PythonValueHandle> &ubs,
+                        const std::vector<int64_t> steps)
+      : lbs(lbs), ubs(ubs), steps(steps) {
+    assert(lbs.size() == ubs.size() && lbs.size() == steps.size() &&
+           "expected the same number of lower, upper bounds, and steps");
+  }
+  PythonLoopNestContext(const PythonLoopNestContext &) = delete;
+  PythonLoopNestContext(PythonLoopNestContext &&) = default;
+  PythonLoopNestContext &operator=(const PythonLoopNestContext &) = delete;
+  PythonLoopNestContext &operator=(PythonLoopNestContext &&) = default;
+  ~PythonLoopNestContext() {
+    assert(!builder && "did not exit from the context");
+  }
+
+  std::vector<PythonValueHandle> enter() {
+    if (steps.empty())
+      return {};
+
+    auto type = mlir_type_t(lbs.front().value.getType().getAsOpaquePointer());
+    std::vector<PythonValueHandle> handles(steps.size(),
+                                           PythonValueHandle(type));
+    std::vector<ValueHandle *> handlePtrs;
+    handlePtrs.reserve(steps.size());
+    for (auto &h : handles)
+      handlePtrs.push_back(&h.value);
+    builder = new LoopNestBuilder(
+        handlePtrs, std::vector<ValueHandle>(lbs.begin(), lbs.end()),
+        std::vector<ValueHandle>(ubs.begin(), ubs.end()), steps);
+    return handles;
+  }
+
+  void exit(py::object, py::object, py::object) {
+    (*builder)({}); // exit from the builder's scope.
+    delete builder;
+    builder = nullptr;
+  }
+
+  std::vector<PythonValueHandle> lbs;
+  std::vector<PythonValueHandle> ubs;
+  std::vector<int64_t> steps;
+  LoopNestBuilder *builder = nullptr;
+};
+
 struct PythonBindable : public PythonExpr {
   explicit PythonBindable(const PythonType &type)
       : PythonExpr(edsc_expr_t{makeBindable(type.type)}) {}
@@ -228,6 +334,17 @@ struct PythonBlock {
   edsc_block_t blk;
 };
 
+struct PythonBlockHandle {
+  PythonBlockHandle() : value(nullptr) {}
+  PythonBlockHandle(const PythonBlockHandle &other) = default;
+  PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {}
+  operator mlir::edsc::BlockHandle() const { return value; }
+
+  std::string str() const { return "^block"; }
+
+  mlir::edsc::BlockHandle value;
+};
+
 struct PythonAttribute {
   PythonAttribute() : attr(nullptr) {}
   PythonAttribute(const mlir_attr_t &a) : attr(a) {}
@@ -650,6 +767,26 @@ PYBIND11_MODULE(pybind, m) {
                   makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps),
                   makeCStmts(owningStmts, stmts)));
   });
+
+  py::class_<PythonLoopContext>(
+      m, "LoopContext", "A context for building the body of a 'for' loop")
+      .def(py::init<PythonValueHandle, PythonValueHandle, int64_t>())
+      .def("__enter__", &PythonLoopContext::enter)
+      .def("__exit__", &PythonLoopContext::exit);
+
+  py::class_<PythonLoopNestContext>(m, "LoopNestContext",
+                                    "A context for building the body of a the "
+                                    "innermost loop in a nest of 'for' loops")
+      .def(py::init<const std::vector<PythonValueHandle> &,
+                    const std::vector<PythonValueHandle> &,
+                    const std::vector<int64_t> &>())
+      .def("__enter__", &PythonLoopNestContext::enter)
+      .def("__exit__", &PythonLoopNestContext::exit);
+
+  m.def("IdxCst", [](int64_t val) -> PythonValueHandle {
+    return ValueHandle(index_t(val));
+  });
+
   m.def("Max", [](const py::list &args) {
     SmallVector<edsc_expr_t, 8> owning;
     return PythonMaxExpr(::Max(makeCExprs(owning, args)));
@@ -818,6 +955,38 @@ PYBIND11_MODULE(pybind, m) {
       .def("__enter__", &ContextManager::enter)
       .def("__exit__", &ContextManager::exit);
 
+  py::class_<PythonFunctionContext>(
+      m, "FunctionContext", "A wrapper around mlir::edsc::ScopedContext")
+      .def(py::init<PythonFunction>())
+      .def("__enter__", &PythonFunctionContext::enter)
+      .def("__exit__", &PythonFunctionContext::exit);
+
+  {
+    using namespace mlir::edsc::op;
+    py::class_<PythonValueHandle>(m, "ValueHandle",
+                                  "A wrapper around mlir::edsc::ValueHandle")
+        .def(py::init<PythonType>())
+        .def(py::init<PythonValueHandle>())
+        .def("__add__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value + rhs.value; })
+        .def("__sub__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value - rhs.value; })
+        .def("__mul__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value * rhs.value; })
+        .def("__div__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value / rhs.value; })
+        .def("__truediv__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value / rhs.value; })
+        .def("__mod__",
+             [](PythonValueHandle lhs, PythonValueHandle rhs)
+                 -> PythonValueHandle { return lhs.value % rhs.value; });
+  }
+
   py::class_<MLIRFunctionEmitter>(
       m, "MLIRFunctionEmitter",
       "An MLIRFunctionEmitter is used to fill an empty function body. This is "
index e2b0588..b760cf6 100644 (file)
@@ -17,6 +17,59 @@ class EdscTest(unittest.TestCase):
     self.f32Type = self.module.make_scalar_type("f32")
     self.indexType = self.module.make_index_type()
 
+  def testLoopContext(self):
+    fun = self.module.make_function("foo", [], [])
+    with E.FunctionContext(fun):
+      lhs = E.IdxCst(0)
+      rhs = E.IdxCst(42)
+      with E.LoopContext(lhs, rhs, 1) as i:
+        lhs + rhs + i
+        with E.LoopContext(rhs, rhs + rhs, 2) as j:
+          x = i + j
+    code = str(fun)
+    # TODO(zinenko,ntv): use FileCheck for these tests
+    self.assertIn(
+        '  "for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)} : () -> () {\n',
+        code)
+    self.assertIn("  ^bb1(%i0: index):", code)
+    self.assertIn(
+        '    "for"(%c42, %2) {lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> () {\n',
+        code)
+    self.assertIn("    ^bb2(%i1: index):", code)
+    self.assertIn(
+        '      %3 = "affine.apply"(%i0, %i1) {map: (d0, d1) -> (d0 + d1)} : (index, index) -> index',
+        code)
+
+  def testLoopNestContext(self):
+    fun = self.module.make_function("foo", [], [])
+    with E.FunctionContext(fun):
+      lbs = [E.IdxCst(i) for i in range(4)]
+      ubs = [E.IdxCst(10 * i + 5) for i in range(4)]
+      with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i, j, k, l):
+        i + j + k + l
+
+    code = str(fun)
+    self.assertIn(
+        ' "for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (5)} : () -> () {\n',
+        code)
+    self.assertIn("  ^bb1(%i0: index):", code)
+    self.assertIn(
+        '    "for"() {lower_bound: () -> (1), step: 3 : index, upper_bound: () -> (15)} : () -> () {\n',
+        code)
+    self.assertIn("    ^bb2(%i1: index):", code)
+    self.assertIn(
+        '      "for"() {lower_bound: () -> (2), step: 5 : index, upper_bound: () -> (25)} : () -> () {\n',
+        code)
+    self.assertIn("      ^bb3(%i2: index):", code)
+    self.assertIn(
+        '        "for"() {lower_bound: () -> (3), step: 7 : index, upper_bound: () -> (35)} : () -> () {\n',
+        code)
+    self.assertIn("        ^bb4(%i3: index):", code)
+    self.assertIn(
+        '          %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index',
+        code)
+
+
   def testBindables(self):
     with E.ContextManager():
       i = E.Expr(E.Bindable(self.i32Type))
index d5d8c66..f013270 100644 (file)
@@ -93,6 +93,18 @@ private:
 /// BlockHandle.
 class NestedBuilder {
 protected:
+  NestedBuilder() = default;
+  NestedBuilder(const NestedBuilder &) = delete;
+  NestedBuilder(NestedBuilder &&other) : bodyScope(other.bodyScope) {
+    other.bodyScope = nullptr;
+  }
+
+  NestedBuilder &operator=(const NestedBuilder &) = delete;
+  NestedBuilder &operator=(NestedBuilder &&other) {
+    std::swap(bodyScope, other.bodyScope);
+    return *this;
+  }
+
   /// Enter an mlir::Block and setup a ScopedContext to insert instructions at
   /// the end of it. Since we cannot use c++ language-level scoping to implement
   /// scoping itself, we use enter/exit pairs of instructions.
@@ -138,6 +150,11 @@ public:
   /// *only* way to capture the loop induction variable.
   LoopBuilder(ValueHandle *iv, ArrayRef<ValueHandle> lbHandles,
               ArrayRef<ValueHandle> ubHandles, int64_t step);
+  LoopBuilder(const LoopBuilder &) = delete;
+  LoopBuilder(LoopBuilder &&) = default;
+
+  LoopBuilder &operator=(const LoopBuilder &) = delete;
+  LoopBuilder &operator=(LoopBuilder &&) = default;
 
   /// The only purpose of this operator is to serve as a sequence point so that
   /// the evaluation of `stmts` (which build IR snippets in a scoped fashion) is