#include <unordered_map>
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir-c/Core.h"
+#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Builders.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/MLIREmitter.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Types.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/ExecutionEngine/ExecutionEngine.h"
mlir::edsc::ScopedEDSCContext *context;
};
+struct PythonFunctionContext {
+ PythonFunctionContext(PythonFunction f) : function(f) {}
+
+ void enter() {
+ assert(function.function && "function is not set up");
+ assert(context);
+ context = new mlir::edsc::ScopedContext(
+ static_cast<mlir::Function *>(function.function));
+ }
+
+ void exit(py::object, py::object, py::object) {
+ delete context;
+ context = nullptr;
+ }
+
+ PythonFunction function;
+ mlir::edsc::ScopedContext *context;
+};
+
struct PythonExpr {
PythonExpr() : expr{nullptr} {}
PythonExpr(const PythonBindable &bindable);
edsc_expr_t expr;
};
+struct PythonValueHandle {
+ PythonValueHandle(PythonType type)
+ : value(mlir::Type::getFromOpaquePointer(type.type)) {}
+ PythonValueHandle(const PythonValueHandle &other) = default;
+ PythonValueHandle(const mlir::edsc::ValueHandle &other) : value(other) {}
+ operator ValueHandle() const { return value; }
+ operator ValueHandle &() { return value; }
+
+ std::string str() const {
+ return std::to_string(reinterpret_cast<intptr_t>(value.getValue()));
+ }
+
+ mlir::edsc::ValueHandle value;
+};
+
+struct PythonLoopContext {
+ PythonLoopContext(PythonValueHandle lb, PythonValueHandle ub, int64_t step)
+ : lb(lb), ub(ub), step(step) {}
+ PythonLoopContext(const PythonLoopContext &) = delete;
+ PythonLoopContext(PythonLoopContext &&) = default;
+ PythonLoopContext &operator=(const PythonLoopContext &) = delete;
+ PythonLoopContext &operator=(PythonLoopContext &&) = default;
+ ~PythonLoopContext() { assert(!builder && "did not exit from the context"); }
+
+ PythonValueHandle enter() {
+ ValueHandle iv(lb.value.getType());
+ builder = new LoopBuilder(&iv, lb.value, ub.value, step);
+ return iv;
+ }
+
+ void exit(py::object, py::object, py::object) {
+ (*builder)({}); // exit from the builder's scope.
+ delete builder;
+ builder = nullptr;
+ }
+
+ PythonValueHandle lb, ub;
+ int64_t step;
+ LoopBuilder *builder = nullptr;
+};
+
+struct PythonLoopNestContext {
+ PythonLoopNestContext(const std::vector<PythonValueHandle> &lbs,
+ const std::vector<PythonValueHandle> &ubs,
+ const std::vector<int64_t> steps)
+ : lbs(lbs), ubs(ubs), steps(steps) {
+ assert(lbs.size() == ubs.size() && lbs.size() == steps.size() &&
+ "expected the same number of lower, upper bounds, and steps");
+ }
+ PythonLoopNestContext(const PythonLoopNestContext &) = delete;
+ PythonLoopNestContext(PythonLoopNestContext &&) = default;
+ PythonLoopNestContext &operator=(const PythonLoopNestContext &) = delete;
+ PythonLoopNestContext &operator=(PythonLoopNestContext &&) = default;
+ ~PythonLoopNestContext() {
+ assert(!builder && "did not exit from the context");
+ }
+
+ std::vector<PythonValueHandle> enter() {
+ if (steps.empty())
+ return {};
+
+ auto type = mlir_type_t(lbs.front().value.getType().getAsOpaquePointer());
+ std::vector<PythonValueHandle> handles(steps.size(),
+ PythonValueHandle(type));
+ std::vector<ValueHandle *> handlePtrs;
+ handlePtrs.reserve(steps.size());
+ for (auto &h : handles)
+ handlePtrs.push_back(&h.value);
+ builder = new LoopNestBuilder(
+ handlePtrs, std::vector<ValueHandle>(lbs.begin(), lbs.end()),
+ std::vector<ValueHandle>(ubs.begin(), ubs.end()), steps);
+ return handles;
+ }
+
+ void exit(py::object, py::object, py::object) {
+ (*builder)({}); // exit from the builder's scope.
+ delete builder;
+ builder = nullptr;
+ }
+
+ std::vector<PythonValueHandle> lbs;
+ std::vector<PythonValueHandle> ubs;
+ std::vector<int64_t> steps;
+ LoopNestBuilder *builder = nullptr;
+};
+
struct PythonBindable : public PythonExpr {
explicit PythonBindable(const PythonType &type)
: PythonExpr(edsc_expr_t{makeBindable(type.type)}) {}
edsc_block_t blk;
};
+struct PythonBlockHandle {
+ PythonBlockHandle() : value(nullptr) {}
+ PythonBlockHandle(const PythonBlockHandle &other) = default;
+ PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {}
+ operator mlir::edsc::BlockHandle() const { return value; }
+
+ std::string str() const { return "^block"; }
+
+ mlir::edsc::BlockHandle value;
+};
+
struct PythonAttribute {
PythonAttribute() : attr(nullptr) {}
PythonAttribute(const mlir_attr_t &a) : attr(a) {}
makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps),
makeCStmts(owningStmts, stmts)));
});
+
+ py::class_<PythonLoopContext>(
+ m, "LoopContext", "A context for building the body of a 'for' loop")
+ .def(py::init<PythonValueHandle, PythonValueHandle, int64_t>())
+ .def("__enter__", &PythonLoopContext::enter)
+ .def("__exit__", &PythonLoopContext::exit);
+
+ py::class_<PythonLoopNestContext>(m, "LoopNestContext",
+ "A context for building the body of a the "
+ "innermost loop in a nest of 'for' loops")
+ .def(py::init<const std::vector<PythonValueHandle> &,
+ const std::vector<PythonValueHandle> &,
+ const std::vector<int64_t> &>())
+ .def("__enter__", &PythonLoopNestContext::enter)
+ .def("__exit__", &PythonLoopNestContext::exit);
+
+ m.def("IdxCst", [](int64_t val) -> PythonValueHandle {
+ return ValueHandle(index_t(val));
+ });
+
m.def("Max", [](const py::list &args) {
SmallVector<edsc_expr_t, 8> owning;
return PythonMaxExpr(::Max(makeCExprs(owning, args)));
.def("__enter__", &ContextManager::enter)
.def("__exit__", &ContextManager::exit);
+ py::class_<PythonFunctionContext>(
+ m, "FunctionContext", "A wrapper around mlir::edsc::ScopedContext")
+ .def(py::init<PythonFunction>())
+ .def("__enter__", &PythonFunctionContext::enter)
+ .def("__exit__", &PythonFunctionContext::exit);
+
+ {
+ using namespace mlir::edsc::op;
+ py::class_<PythonValueHandle>(m, "ValueHandle",
+ "A wrapper around mlir::edsc::ValueHandle")
+ .def(py::init<PythonType>())
+ .def(py::init<PythonValueHandle>())
+ .def("__add__",
+ [](PythonValueHandle lhs, PythonValueHandle rhs)
+ -> PythonValueHandle { return lhs.value + rhs.value; })
+ .def("__sub__",
+ [](PythonValueHandle lhs, PythonValueHandle rhs)
+ -> PythonValueHandle { return lhs.value - rhs.value; })
+ .def("__mul__",
+ [](PythonValueHandle lhs, PythonValueHandle rhs)
+ -> PythonValueHandle { return lhs.value * rhs.value; })
+ .def("__div__",
+ [](PythonValueHandle lhs, PythonValueHandle rhs)
+ -> PythonValueHandle { return lhs.value / rhs.value; })
+ .def("__truediv__",
+ [](PythonValueHandle lhs, PythonValueHandle rhs)
+ -> PythonValueHandle { return lhs.value / rhs.value; })
+ .def("__mod__",
+ [](PythonValueHandle lhs, PythonValueHandle rhs)
+ -> PythonValueHandle { return lhs.value % rhs.value; });
+ }
+
py::class_<MLIRFunctionEmitter>(
m, "MLIRFunctionEmitter",
"An MLIRFunctionEmitter is used to fill an empty function body. This is "
self.f32Type = self.module.make_scalar_type("f32")
self.indexType = self.module.make_index_type()
+ def testLoopContext(self):
+ fun = self.module.make_function("foo", [], [])
+ with E.FunctionContext(fun):
+ lhs = E.IdxCst(0)
+ rhs = E.IdxCst(42)
+ with E.LoopContext(lhs, rhs, 1) as i:
+ lhs + rhs + i
+ with E.LoopContext(rhs, rhs + rhs, 2) as j:
+ x = i + j
+ code = str(fun)
+ # TODO(zinenko,ntv): use FileCheck for these tests
+ self.assertIn(
+ ' "for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)} : () -> () {\n',
+ code)
+ self.assertIn(" ^bb1(%i0: index):", code)
+ self.assertIn(
+ ' "for"(%c42, %2) {lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> () {\n',
+ code)
+ self.assertIn(" ^bb2(%i1: index):", code)
+ self.assertIn(
+ ' %3 = "affine.apply"(%i0, %i1) {map: (d0, d1) -> (d0 + d1)} : (index, index) -> index',
+ code)
+
+ def testLoopNestContext(self):
+ fun = self.module.make_function("foo", [], [])
+ with E.FunctionContext(fun):
+ lbs = [E.IdxCst(i) for i in range(4)]
+ ubs = [E.IdxCst(10 * i + 5) for i in range(4)]
+ with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i, j, k, l):
+ i + j + k + l
+
+ code = str(fun)
+ self.assertIn(
+ ' "for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (5)} : () -> () {\n',
+ code)
+ self.assertIn(" ^bb1(%i0: index):", code)
+ self.assertIn(
+ ' "for"() {lower_bound: () -> (1), step: 3 : index, upper_bound: () -> (15)} : () -> () {\n',
+ code)
+ self.assertIn(" ^bb2(%i1: index):", code)
+ self.assertIn(
+ ' "for"() {lower_bound: () -> (2), step: 5 : index, upper_bound: () -> (25)} : () -> () {\n',
+ code)
+ self.assertIn(" ^bb3(%i2: index):", code)
+ self.assertIn(
+ ' "for"() {lower_bound: () -> (3), step: 7 : index, upper_bound: () -> (35)} : () -> () {\n',
+ code)
+ self.assertIn(" ^bb4(%i3: index):", code)
+ self.assertIn(
+ ' %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index',
+ code)
+
+
def testBindables(self):
with E.ContextManager():
i = E.Expr(E.Bindable(self.i32Type))