From 8b4b9b31f19846d9721b5cf9ec20d1e97a38707c Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 8 Mar 2019 07:45:26 -0800 Subject: [PATCH] Python bindings: introduce loop and loop nest contexts Recently, EDSC introduced an eager mode for building IR in different contexts. Introduce Python bindings support for loop and loop nest contexts of EDSC builders. The eager mode is built around the notion of ValueHandle, which is convenience class for delayed initialization and operator overloads. Expose this class and overloads directly. The model of insertion contexts maps naturally to Python context manager mechanism, therefore new bindings are defined bypassing the C APIs. The bindings now provide three new context manager classes: FunctionContext, LoopContext and LoopNestContext. The last two can be used with the `with`-construct in Python to create loop (nests) and obtain handles to the loop induction variables seamlessly: with LoopContext(lhs, rhs, 1) as i: lhs + rhs + i with LoopContext(rhs, rhs + rhs, 2) as j: x = i + j Any statement within the Python context will trigger immediate emission of the corresponding IR constructs into the context owned by the nearest context manager. PiperOrigin-RevId: 237447732 --- mlir/bindings/python/pybind.cpp | 169 ++++++++++++++++++++++++++++++ mlir/bindings/python/test/test_py2and3.py | 53 ++++++++++ mlir/include/mlir/EDSC/Builders.h | 17 +++ 3 files changed, 239 insertions(+) diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 5270ea0..a2b2959 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -7,6 +7,7 @@ #include #include "third_party/llvm/llvm/projects/google_mlir/include/mlir-c/Core.h" +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Builders.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/MLIREmitter.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Types.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" @@ -177,6 +178,25 @@ struct ContextManager { mlir::edsc::ScopedEDSCContext *context; }; +struct PythonFunctionContext { + PythonFunctionContext(PythonFunction f) : function(f) {} + + void enter() { + assert(function.function && "function is not set up"); + assert(context); + context = new mlir::edsc::ScopedContext( + static_cast(function.function)); + } + + void exit(py::object, py::object, py::object) { + delete context; + context = nullptr; + } + + PythonFunction function; + mlir::edsc::ScopedContext *context; +}; + struct PythonExpr { PythonExpr() : expr{nullptr} {} PythonExpr(const PythonBindable &bindable); @@ -189,6 +209,92 @@ struct PythonExpr { edsc_expr_t expr; }; +struct PythonValueHandle { + PythonValueHandle(PythonType type) + : value(mlir::Type::getFromOpaquePointer(type.type)) {} + PythonValueHandle(const PythonValueHandle &other) = default; + PythonValueHandle(const mlir::edsc::ValueHandle &other) : value(other) {} + operator ValueHandle() const { return value; } + operator ValueHandle &() { return value; } + + std::string str() const { + return std::to_string(reinterpret_cast(value.getValue())); + } + + mlir::edsc::ValueHandle value; +}; + +struct PythonLoopContext { + PythonLoopContext(PythonValueHandle lb, PythonValueHandle ub, int64_t step) + : lb(lb), ub(ub), step(step) {} + PythonLoopContext(const PythonLoopContext &) = delete; + PythonLoopContext(PythonLoopContext &&) = default; + PythonLoopContext &operator=(const PythonLoopContext &) = delete; + PythonLoopContext &operator=(PythonLoopContext &&) = default; + ~PythonLoopContext() { assert(!builder && "did not exit from the context"); } + + PythonValueHandle enter() { + ValueHandle iv(lb.value.getType()); + builder = new LoopBuilder(&iv, lb.value, ub.value, step); + return iv; + } + + void exit(py::object, py::object, py::object) { + (*builder)({}); // exit from the builder's scope. + delete builder; + builder = nullptr; + } + + PythonValueHandle lb, ub; + int64_t step; + LoopBuilder *builder = nullptr; +}; + +struct PythonLoopNestContext { + PythonLoopNestContext(const std::vector &lbs, + const std::vector &ubs, + const std::vector steps) + : lbs(lbs), ubs(ubs), steps(steps) { + assert(lbs.size() == ubs.size() && lbs.size() == steps.size() && + "expected the same number of lower, upper bounds, and steps"); + } + PythonLoopNestContext(const PythonLoopNestContext &) = delete; + PythonLoopNestContext(PythonLoopNestContext &&) = default; + PythonLoopNestContext &operator=(const PythonLoopNestContext &) = delete; + PythonLoopNestContext &operator=(PythonLoopNestContext &&) = default; + ~PythonLoopNestContext() { + assert(!builder && "did not exit from the context"); + } + + std::vector enter() { + if (steps.empty()) + return {}; + + auto type = mlir_type_t(lbs.front().value.getType().getAsOpaquePointer()); + std::vector handles(steps.size(), + PythonValueHandle(type)); + std::vector handlePtrs; + handlePtrs.reserve(steps.size()); + for (auto &h : handles) + handlePtrs.push_back(&h.value); + builder = new LoopNestBuilder( + handlePtrs, std::vector(lbs.begin(), lbs.end()), + std::vector(ubs.begin(), ubs.end()), steps); + return handles; + } + + void exit(py::object, py::object, py::object) { + (*builder)({}); // exit from the builder's scope. + delete builder; + builder = nullptr; + } + + std::vector lbs; + std::vector ubs; + std::vector steps; + LoopNestBuilder *builder = nullptr; +}; + struct PythonBindable : public PythonExpr { explicit PythonBindable(const PythonType &type) : PythonExpr(edsc_expr_t{makeBindable(type.type)}) {} @@ -228,6 +334,17 @@ struct PythonBlock { edsc_block_t blk; }; +struct PythonBlockHandle { + PythonBlockHandle() : value(nullptr) {} + PythonBlockHandle(const PythonBlockHandle &other) = default; + PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {} + operator mlir::edsc::BlockHandle() const { return value; } + + std::string str() const { return "^block"; } + + mlir::edsc::BlockHandle value; +}; + struct PythonAttribute { PythonAttribute() : attr(nullptr) {} PythonAttribute(const mlir_attr_t &a) : attr(a) {} @@ -650,6 +767,26 @@ PYBIND11_MODULE(pybind, m) { makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps), makeCStmts(owningStmts, stmts))); }); + + py::class_( + m, "LoopContext", "A context for building the body of a 'for' loop") + .def(py::init()) + .def("__enter__", &PythonLoopContext::enter) + .def("__exit__", &PythonLoopContext::exit); + + py::class_(m, "LoopNestContext", + "A context for building the body of a the " + "innermost loop in a nest of 'for' loops") + .def(py::init &, + const std::vector &, + const std::vector &>()) + .def("__enter__", &PythonLoopNestContext::enter) + .def("__exit__", &PythonLoopNestContext::exit); + + m.def("IdxCst", [](int64_t val) -> PythonValueHandle { + return ValueHandle(index_t(val)); + }); + m.def("Max", [](const py::list &args) { SmallVector owning; return PythonMaxExpr(::Max(makeCExprs(owning, args))); @@ -818,6 +955,38 @@ PYBIND11_MODULE(pybind, m) { .def("__enter__", &ContextManager::enter) .def("__exit__", &ContextManager::exit); + py::class_( + m, "FunctionContext", "A wrapper around mlir::edsc::ScopedContext") + .def(py::init()) + .def("__enter__", &PythonFunctionContext::enter) + .def("__exit__", &PythonFunctionContext::exit); + + { + using namespace mlir::edsc::op; + py::class_(m, "ValueHandle", + "A wrapper around mlir::edsc::ValueHandle") + .def(py::init()) + .def(py::init()) + .def("__add__", + [](PythonValueHandle lhs, PythonValueHandle rhs) + -> PythonValueHandle { return lhs.value + rhs.value; }) + .def("__sub__", + [](PythonValueHandle lhs, PythonValueHandle rhs) + -> PythonValueHandle { return lhs.value - rhs.value; }) + .def("__mul__", + [](PythonValueHandle lhs, PythonValueHandle rhs) + -> PythonValueHandle { return lhs.value * rhs.value; }) + .def("__div__", + [](PythonValueHandle lhs, PythonValueHandle rhs) + -> PythonValueHandle { return lhs.value / rhs.value; }) + .def("__truediv__", + [](PythonValueHandle lhs, PythonValueHandle rhs) + -> PythonValueHandle { return lhs.value / rhs.value; }) + .def("__mod__", + [](PythonValueHandle lhs, PythonValueHandle rhs) + -> PythonValueHandle { return lhs.value % rhs.value; }); + } + py::class_( m, "MLIRFunctionEmitter", "An MLIRFunctionEmitter is used to fill an empty function body. This is " diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index e2b0588..b760cf6 100644 --- a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -17,6 +17,59 @@ class EdscTest(unittest.TestCase): self.f32Type = self.module.make_scalar_type("f32") self.indexType = self.module.make_index_type() + def testLoopContext(self): + fun = self.module.make_function("foo", [], []) + with E.FunctionContext(fun): + lhs = E.IdxCst(0) + rhs = E.IdxCst(42) + with E.LoopContext(lhs, rhs, 1) as i: + lhs + rhs + i + with E.LoopContext(rhs, rhs + rhs, 2) as j: + x = i + j + code = str(fun) + # TODO(zinenko,ntv): use FileCheck for these tests + self.assertIn( + ' "for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)} : () -> () {\n', + code) + self.assertIn(" ^bb1(%i0: index):", code) + self.assertIn( + ' "for"(%c42, %2) {lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> () {\n', + code) + self.assertIn(" ^bb2(%i1: index):", code) + self.assertIn( + ' %3 = "affine.apply"(%i0, %i1) {map: (d0, d1) -> (d0 + d1)} : (index, index) -> index', + code) + + def testLoopNestContext(self): + fun = self.module.make_function("foo", [], []) + with E.FunctionContext(fun): + lbs = [E.IdxCst(i) for i in range(4)] + ubs = [E.IdxCst(10 * i + 5) for i in range(4)] + with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i, j, k, l): + i + j + k + l + + code = str(fun) + self.assertIn( + ' "for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (5)} : () -> () {\n', + code) + self.assertIn(" ^bb1(%i0: index):", code) + self.assertIn( + ' "for"() {lower_bound: () -> (1), step: 3 : index, upper_bound: () -> (15)} : () -> () {\n', + code) + self.assertIn(" ^bb2(%i1: index):", code) + self.assertIn( + ' "for"() {lower_bound: () -> (2), step: 5 : index, upper_bound: () -> (25)} : () -> () {\n', + code) + self.assertIn(" ^bb3(%i2: index):", code) + self.assertIn( + ' "for"() {lower_bound: () -> (3), step: 7 : index, upper_bound: () -> (35)} : () -> () {\n', + code) + self.assertIn(" ^bb4(%i3: index):", code) + self.assertIn( + ' %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index', + code) + + def testBindables(self): with E.ContextManager(): i = E.Expr(E.Bindable(self.i32Type)) diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index d5d8c66..f013270 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -93,6 +93,18 @@ private: /// BlockHandle. class NestedBuilder { protected: + NestedBuilder() = default; + NestedBuilder(const NestedBuilder &) = delete; + NestedBuilder(NestedBuilder &&other) : bodyScope(other.bodyScope) { + other.bodyScope = nullptr; + } + + NestedBuilder &operator=(const NestedBuilder &) = delete; + NestedBuilder &operator=(NestedBuilder &&other) { + std::swap(bodyScope, other.bodyScope); + return *this; + } + /// Enter an mlir::Block and setup a ScopedContext to insert instructions at /// the end of it. Since we cannot use c++ language-level scoping to implement /// scoping itself, we use enter/exit pairs of instructions. @@ -138,6 +150,11 @@ public: /// *only* way to capture the loop induction variable. LoopBuilder(ValueHandle *iv, ArrayRef lbHandles, ArrayRef ubHandles, int64_t step); + LoopBuilder(const LoopBuilder &) = delete; + LoopBuilder(LoopBuilder &&) = default; + + LoopBuilder &operator=(const LoopBuilder &) = delete; + LoopBuilder &operator=(LoopBuilder &&) = default; /// The only purpose of this operator is to serve as a sequence point so that /// the evaluation of `stmts` (which build IR snippets in a scoped fashion) is -- 2.7.4