#include "mlir-c/Core.h"
#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
#include "mlir/EDSC/MLIREmitter.h"
#include "mlir/EDSC/Types.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
mlir::edsc::ValueHandle value;
};
+struct PythonBlockHandle {
+ PythonBlockHandle() : value(nullptr) {}
+ PythonBlockHandle(const PythonBlockHandle &other) = default;
+ PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {}
+ operator mlir::edsc::BlockHandle() const { return value; }
+
+ PythonValueHandle arg(int index) { return arguments[index]; }
+
+ std::string str() {
+ std::string s;
+ llvm::raw_string_ostream os(s);
+ value.getBlock()->print(os);
+ return os.str();
+ }
+
+ mlir::edsc::BlockHandle value;
+ std::vector<mlir::edsc::ValueHandle> arguments;
+};
+
struct PythonLoopContext {
PythonLoopContext(PythonValueHandle lb, PythonValueHandle ub, int64_t step)
: lb(lb), ub(ub), step(step) {}
LoopNestBuilder *builder = nullptr;
};
+struct PythonBlockAppender {
+ PythonBlockAppender(const PythonBlockHandle &handle) : handle(handle) {}
+ PythonBlockHandle handle;
+};
+
+struct PythonBlockContext {
+public:
+ PythonBlockContext() {
+ createBlockBuilder();
+ clearBuilder();
+ }
+ PythonBlockContext(const std::vector<PythonType> &argTypes) {
+ handle.arguments.reserve(argTypes.size());
+ for (const auto &t : argTypes) {
+ auto type =
+ Type::getFromOpaquePointer(reinterpret_cast<const void *>(t.type));
+ handle.arguments.emplace_back(type);
+ }
+ createBlockBuilder();
+ clearBuilder();
+ }
+ PythonBlockContext(const PythonBlockAppender &a) : handle(a.handle) {}
+ PythonBlockContext(const PythonBlockContext &) = delete;
+ PythonBlockContext(PythonBlockContext &&) = default;
+ PythonBlockContext &operator=(const PythonBlockContext &) = delete;
+ PythonBlockContext &operator=(PythonBlockContext &&) = default;
+ ~PythonBlockContext() {
+ assert(!builder && "did not exit from the block context");
+ }
+
+ // EDSC maintain an implicit stack of builders (mostly for keeping track of
+ // insretion points); every operation gets inserted using the top-of-the-stack
+ // builder. Creating a new EDSC Builder automatically puts it on the stack,
+ // effectively entering the block for it.
+ void createBlockBuilder() {
+ if (handle.value.getBlock()) {
+ builder = new BlockBuilder(handle.value, mlir::edsc::Append());
+ } else {
+ std::vector<ValueHandle *> args;
+ args.reserve(handle.arguments.size());
+ for (auto &a : handle.arguments)
+ args.push_back(&a);
+ builder = new BlockBuilder(&handle.value, args);
+ }
+ }
+
+ PythonBlockHandle enter() {
+ createBlockBuilder();
+ return handle;
+ }
+
+ void exit(py::object, py::object, py::object) { clearBuilder(); }
+
+ PythonBlockHandle getHandle() { return handle; }
+
+ // EDSC maintain an implicit stack of builders (mostly for keeping track of
+ // insretion points); every operation gets inserted using the top-of-the-stack
+ // builder. Calling operator() on a builder pops the builder from the stack,
+ // effectively resetting the insertion point to its position before we entered
+ // the block.
+ void clearBuilder() {
+ (*builder)({}); // exit from the builder's scope.
+ delete builder;
+ builder = nullptr;
+ }
+
+ PythonBlockHandle handle;
+ BlockBuilder *builder = nullptr;
+};
+
struct PythonBindable : public PythonExpr {
explicit PythonBindable(const PythonType &type)
: PythonExpr(edsc_expr_t{makeBindable(type.type)}) {}
edsc_block_t blk;
};
-struct PythonBlockHandle {
- PythonBlockHandle() : value(nullptr) {}
- PythonBlockHandle(const PythonBlockHandle &other) = default;
- PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {}
- operator mlir::edsc::BlockHandle() const { return value; }
-
- std::string str() const { return "^block"; }
-
- mlir::edsc::BlockHandle value;
-};
-
struct PythonAttribute {
PythonAttribute() : attr(nullptr) {}
PythonAttribute(const mlir_attr_t &a) : attr(a) {}
m.def("IdxCst", [](int64_t val) -> PythonValueHandle {
return ValueHandle(index_t(val));
});
+ m.def("appendTo", [](const PythonBlockHandle &handle) {
+ return PythonBlockAppender(handle);
+ });
+ m.def(
+ "ret",
+ [](const std::vector<PythonValueHandle> &args) {
+ std::vector<ValueHandle> values(args.begin(), args.end());
+ intrinsics::RETURN(values);
+ return PythonValueHandle(nullptr);
+ },
+ py::arg("args") = std::vector<PythonValueHandle>());
+ m.def(
+ "br",
+ [](const PythonBlockHandle &dest,
+ const std::vector<PythonValueHandle> &args) {
+ std::vector<ValueHandle> values(args.begin(), args.end());
+ intrinsics::BR(dest, values);
+ return PythonValueHandle(nullptr);
+ },
+ py::arg("dest"), py::arg("args") = std::vector<PythonValueHandle>());
m.def("Max", [](const py::list &args) {
SmallVector<edsc_expr_t, 8> owning;
-> PythonValueHandle { return lhs.value % rhs.value; });
}
+ py::class_<PythonBlockAppender>(
+ m, "BlockAppender",
+ "A dummy class signaling BlockContext to append IR to the given block "
+ "instead of creating a new block")
+ .def(py::init<const PythonBlockHandle &>());
+ py::class_<PythonBlockHandle>(m, "BlockHandle",
+ "A wrapper around mlir::edsc::BlockHandle")
+ .def(py::init<PythonBlockHandle>())
+ .def("arg", &PythonBlockHandle::arg);
+
+ py::class_<PythonBlockContext>(m, "BlockContext",
+ "A wrapper around mlir::edsc::BlockBuilder")
+ .def(py::init<>())
+ .def(py::init<const std::vector<PythonType> &>())
+ .def(py::init<const PythonBlockAppender &>())
+ .def("__enter__", &PythonBlockContext::enter)
+ .def("__exit__", &PythonBlockContext::exit)
+ .def("handle", &PythonBlockContext::getHandle);
+
py::class_<MLIRFunctionEmitter>(
m, "MLIRFunctionEmitter",
"An MLIRFunctionEmitter is used to fill an empty function body. This is "
' %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index',
code)
+ def testBlockContext(self):
+ fun = self.module.make_function("foo", [], [])
+ with E.FunctionContext(fun):
+ cst = E.IdxCst(42)
+ with E.BlockContext():
+ cst + cst
+ code = str(fun)
+ # Find positions of instructions and make sure they are in the block we
+ # put them by comparing those positions.
+ # TODO(zinenko,ntv): this (and tests below) should use FileCheck instead.
+ c42pos = code.find("%c42 = constant 42 : index")
+ bb1pos = code.find("^bb1:")
+ c84pos = code.find('%0 = "affine.apply"() {map: () -> (84)} : () -> index')
+ self.assertNotEqual(c42pos, -1)
+ self.assertNotEqual(bb1pos, -1)
+ self.assertNotEqual(c84pos, -1)
+ self.assertGreater(bb1pos, c42pos)
+ self.assertLess(bb1pos, c84pos)
+
+ def testBlockContextAppend(self):
+ fun = self.module.make_function("foo", [], [])
+ with E.FunctionContext(fun):
+ E.IdxCst(41)
+ with E.BlockContext() as b:
+ blk = b # save block handle for later
+ E.IdxCst(0)
+ E.IdxCst(42)
+ with E.BlockContext(E.appendTo(blk)):
+ E.IdxCst(1)
+ code = str(fun)
+ # Find positions of instructions and make sure they are in the block we put
+ # them by comparing those positions.
+ c41pos = code.find("%c41 = constant 41 : index")
+ c42pos = code.find("%c42 = constant 42 : index")
+ bb1pos = code.find("^bb1:")
+ c0pos = code.find("%c0 = constant 0 : index")
+ c1pos = code.find("%c1 = constant 1 : index")
+ self.assertNotEqual(c41pos, -1)
+ self.assertNotEqual(c42pos, -1)
+ self.assertNotEqual(bb1pos, -1)
+ self.assertNotEqual(c0pos, -1)
+ self.assertNotEqual(c1pos, -1)
+ self.assertGreater(bb1pos, c41pos)
+ self.assertGreater(bb1pos, c42pos)
+ self.assertLess(bb1pos, c0pos)
+ self.assertLess(bb1pos, c1pos)
+
+ def testBlockContextStandalone(self):
+ fun = self.module.make_function("foo", [], [])
+ with E.FunctionContext(fun):
+ blk1 = E.BlockContext()
+ blk2 = E.BlockContext()
+ with blk1:
+ E.IdxCst(0)
+ with blk2:
+ E.IdxCst(56)
+ E.IdxCst(57)
+ E.IdxCst(41)
+ with blk1:
+ E.IdxCst(1)
+ E.IdxCst(42)
+ code = str(fun)
+ # Find positions of instructions and make sure they are in the block we put
+ # them by comparing those positions.
+ c41pos = code.find(" %c41 = constant 41 : index")
+ c42pos = code.find(" %c42 = constant 42 : index")
+ bb1pos = code.find("^bb1:")
+ c0pos = code.find(" %c0 = constant 0 : index")
+ c1pos = code.find(" %c1 = constant 1 : index")
+ bb2pos = code.find("^bb2:")
+ c56pos = code.find(" %c56 = constant 56 : index")
+ c57pos = code.find(" %c57 = constant 57 : index")
+ self.assertNotEqual(c41pos, -1)
+ self.assertNotEqual(c42pos, -1)
+ self.assertNotEqual(bb1pos, -1)
+ self.assertNotEqual(c0pos, -1)
+ self.assertNotEqual(c1pos, -1)
+ self.assertNotEqual(bb2pos, -1)
+ self.assertNotEqual(c56pos, -1)
+ self.assertNotEqual(c57pos, -1)
+ self.assertGreater(bb1pos, c41pos)
+ self.assertGreater(bb1pos, c42pos)
+ self.assertLess(bb1pos, c0pos)
+ self.assertLess(bb1pos, c1pos)
+ self.assertGreater(bb2pos, c0pos)
+ self.assertGreater(bb2pos, c1pos)
+ self.assertGreater(bb2pos, bb1pos)
+ self.assertLess(bb2pos, c56pos)
+ self.assertLess(bb2pos, c57pos)
+
+
+ def testBlockArguments(self):
+ fun = self.module.make_function("foo", [], [])
+ with E.FunctionContext(fun):
+ E.IdxCst(42)
+ with E.BlockContext([self.f32Type, self.f32Type]) as b:
+ b.arg(0) + b.arg(1)
+ code = str(fun)
+ self.assertIn("%c42 = constant 42 : index", code)
+ self.assertIn("^bb1(%0: f32, %1: f32):", code)
+ self.assertIn(" %2 = addf %0, %1 : f32", code)
+
+ def testBr(self):
+ fun = self.module.make_function("foo", [], [])
+ with E.FunctionContext(fun):
+ with E.BlockContext() as b:
+ blk = b
+ E.ret()
+ E.br(blk)
+ code = str(fun)
+ self.assertIn(" br ^bb1", code)
+ self.assertIn("^bb1:", code)
+ self.assertIn(" return", code)
+
+ def testBrDeclaration(self):
+ fun = self.module.make_function("foo", [], [])
+ with E.FunctionContext(fun):
+ blk = E.BlockContext()
+ E.br(blk.handle())
+ with blk:
+ E.ret()
+ code = str(fun)
+ self.assertIn(" br ^bb1", code)
+ self.assertIn("^bb1:", code)
+ self.assertIn(" return", code)
+
+ def testBrArgs(self):
+ fun = self.module.make_function("foo", [], [])
+ with E.FunctionContext(fun):
+ # Create an infinite loop.
+ with E.BlockContext([self.indexType, self.indexType]) as b:
+ E.br(b, [b.arg(1), b.arg(0)])
+ E.br(b, [E.IdxCst(0), E.IdxCst(1)])
+ code = str(fun)
+ self.assertIn(" %c0 = constant 0 : index", code)
+ self.assertIn(" %c1 = constant 1 : index", code)
+ self.assertIn(" br ^bb1(%c0, %c1 : index, index)", code)
+ self.assertIn("^bb1(%0: index, %1: index):", code)
+ self.assertIn(" br ^bb1(%1, %0 : index, index)", code)
+
+ def testRet(self):
+ fun = self.module.make_function("foo", [], [self.indexType, self.indexType])
+ with E.FunctionContext(fun):
+ c42 = E.IdxCst(42)
+ c0 = E.IdxCst(0)
+ E.ret([c42, c0])
+ code = str(fun)
+ self.assertIn(" %c42 = constant 42 : index", code)
+ self.assertIn(" %c0 = constant 0 : index", code)
+ self.assertIn(" return %c42, %c0 : index, index", code)
+
def testBindables(self):
with E.ContextManager():