Add basic JIT Python Bindings
authorMehdi Amini <joker.eph@gmail.com>
Tue, 23 Feb 2021 01:56:01 +0000 (01:56 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Wed, 3 Mar 2021 18:19:40 +0000 (18:19 +0000)
This offers the ability to create a JIT and invoke a function by passing
ctypes pointers to the argument and the result.

Differential Revision: https://reviews.llvm.org/D97523

14 files changed:
mlir/include/mlir-c/Bindings/Python/Interop.h
mlir/include/mlir-c/ExecutionEngine.h
mlir/lib/Bindings/Python/CMakeLists.txt
mlir/lib/Bindings/Python/Conversions/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Bindings/Python/Conversions/Conversions.cpp [new file with mode: 0644]
mlir/lib/Bindings/Python/ExecutionEngine.cpp [new file with mode: 0644]
mlir/lib/Bindings/Python/ExecutionEngine.h [new file with mode: 0644]
mlir/lib/Bindings/Python/MainModule.cpp
mlir/lib/Bindings/Python/mlir/__init__.py
mlir/lib/Bindings/Python/mlir/conversions/__init__.py [new file with mode: 0644]
mlir/lib/Bindings/Python/mlir/execution_engine.py [new file with mode: 0644]
mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
mlir/test/Bindings/Python/execution_engine.py [new file with mode: 0644]
mlir/test/CMakeLists.txt

index 506479e..d853159 100644 (file)
@@ -25,6 +25,7 @@
 
 #include "mlir-c/AffineExpr.h"
 #include "mlir-c/AffineMap.h"
+#include "mlir-c/ExecutionEngine.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Pass.h"
@@ -33,6 +34,8 @@
 #define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE                                   \
+  "mlir.execution_engine.ExecutionEngine._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
@@ -261,6 +264,27 @@ static inline MlirIntegerSet mlirPythonCapsuleToIntegerSet(PyObject *capsule) {
   return integerSet;
 }
 
+/** Creates a capsule object encapsulating the raw C-API MlirExecutionEngine.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the set in any way. */
+static inline PyObject *
+mlirPythonExecutionEngineToCapsule(MlirExecutionEngine jit) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(jit),
+                       MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE, NULL);
+}
+
+/** Extracts an MlirExecutionEngine from a capsule as produced from
+ * mlirPythonIntegerSetToCapsule. If the capsule is not of the right type, then
+ * a null set is returned (as checked via mlirExecutionEngineIsNull). In such a
+ * case, the Python APIs will have already set an error. */
+static inline MlirExecutionEngine
+mlirPythonCapsuleToExecutionEngine(PyObject *capsule) {
+  void *ptr =
+      PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE);
+  MlirExecutionEngine jit = {ptr};
+  return jit;
+}
+
 #ifdef __cplusplus
 }
 #endif
index 02d41dc..c256357 100644 (file)
@@ -56,6 +56,11 @@ static inline bool mlirExecutionEngineIsNull(MlirExecutionEngine jit) {
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirExecutionEngineInvokePacked(
     MlirExecutionEngine jit, MlirStringRef name, void **arguments);
 
+/// Lookup a native function in the execution engine by name, returns nullptr
+/// if the name can't be looked-up.
+MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
+                                                   MlirStringRef name);
+
 #ifdef __cplusplus
 }
 #endif
index 951aa78..199b30d 100644 (file)
@@ -8,11 +8,12 @@ add_custom_target(MLIRBindingsPythonExtension)
 set(PY_SRC_FILES
   mlir/__init__.py
   mlir/_dlloader.py
-  mlir/ir.py
+  mlir/conversions/__init__.py
   mlir/dialects/__init__.py
   mlir/dialects/_linalg.py
   mlir/dialects/_builtin.py
   mlir/ir.py
+  mlir/execution_engine.py
   mlir/passmanager.py
   mlir/transforms/__init__.py
 )
@@ -74,6 +75,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
     IRModules.cpp
     PybindUtils.cpp
     Pass.cpp
+    ExecutionEngine.cpp
 )
 add_dependencies(MLIRBindingsPythonExtension MLIRCoreBindingsPythonExtension)
 
@@ -114,3 +116,4 @@ if (NOT LLVM_ENABLE_IDE)
 endif()
 
 add_subdirectory(Transforms)
+add_subdirectory(Conversions)
diff --git a/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt b/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt
new file mode 100644 (file)
index 0000000..ad2aeef
--- /dev/null
@@ -0,0 +1,10 @@
+################################################################################
+# Build python extension
+################################################################################
+
+add_mlir_python_extension(MLIRConversionsBindingsPythonExtension _mlirConversions
+  INSTALL_DIR
+    python
+  SOURCES
+  Conversions.cpp
+)
diff --git a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp b/mlir/lib/Bindings/Python/Conversions/Conversions.cpp
new file mode 100644 (file)
index 0000000..f8b3b20
--- /dev/null
@@ -0,0 +1,24 @@
+//===- Conversions.cpp - Pybind module for the Conversionss library -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Conversion.h"
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+PYBIND11_MODULE(_mlirConversions, m) {
+  m.doc() = "MLIR Conversions library";
+
+  // Register all the passes in the Conversions library on load.
+  mlirRegisterConversionPasses();
+}
diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp
new file mode 100644 (file)
index 0000000..f6f52e2
--- /dev/null
@@ -0,0 +1,87 @@
+//===- ExecutionEngine.cpp - Python MLIR ExecutionEngine Bindings ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "ExecutionEngine.h"
+
+#include "IRModules.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/ExecutionEngine.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+namespace {
+
+/// Owning Wrapper around an ExecutionEngine.
+class PyExecutionEngine {
+public:
+  PyExecutionEngine(MlirExecutionEngine executionEngine)
+      : executionEngine(executionEngine) {}
+  PyExecutionEngine(PyExecutionEngine &&other)
+      : executionEngine(other.executionEngine) {
+    other.executionEngine.ptr = nullptr;
+  }
+  ~PyExecutionEngine() {
+    if (!mlirExecutionEngineIsNull(executionEngine))
+      mlirExecutionEngineDestroy(executionEngine);
+  }
+  MlirExecutionEngine get() { return executionEngine; }
+
+  void release() { executionEngine.ptr = nullptr; }
+  pybind11::object getCapsule() {
+    return py::reinterpret_steal<py::object>(
+        mlirPythonExecutionEngineToCapsule(get()));
+  }
+
+  static pybind11::object createFromCapsule(pybind11::object capsule) {
+    MlirExecutionEngine rawPm =
+        mlirPythonCapsuleToExecutionEngine(capsule.ptr());
+    if (mlirExecutionEngineIsNull(rawPm))
+      throw py::error_already_set();
+    return py::cast(PyExecutionEngine(rawPm), py::return_value_policy::move);
+  }
+
+private:
+  MlirExecutionEngine executionEngine;
+};
+
+} // anonymous namespace
+
+/// Create the `mlir.execution_engine` module here.
+void mlir::python::populateExecutionEngineSubmodule(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the top-level PassManager
+  //----------------------------------------------------------------------------
+  py::class_<PyExecutionEngine>(m, "ExecutionEngine")
+      .def(py::init<>([](PyModule &module) {
+             MlirExecutionEngine executionEngine =
+                 mlirExecutionEngineCreate(module.get());
+             if (mlirExecutionEngineIsNull(executionEngine))
+               throw std::runtime_error(
+                   "Failure while creating the ExecutionEngine.");
+             return new PyExecutionEngine(executionEngine);
+           }),
+           "Create a new ExecutionEngine instance for the given Module. The "
+           "module must "
+           "contain only dialects that can be translated to LLVM.")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyExecutionEngine::getCapsule)
+      .def("_testing_release", &PyExecutionEngine::release,
+           "Releases (leaks) the backing ExecutionEngine (for testing purpose)")
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyExecutionEngine::createFromCapsule)
+      .def(
+          "raw_lookup",
+          [](PyExecutionEngine &executionEngine, const std::string &func) {
+            auto *res = mlirExecutionEngineLookup(
+                executionEngine.get(),
+                mlirStringRefCreate(func.c_str(), func.size()));
+            return (int64_t)res;
+          },
+          "Lookup function `func` in the ExecutionEngine.");
+}
diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.h b/mlir/lib/Bindings/Python/ExecutionEngine.h
new file mode 100644 (file)
index 0000000..cc61648
--- /dev/null
@@ -0,0 +1,22 @@
+//===- ExecutionEngine.h - ExecutionEngine submodule of pybind module -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H
+#define MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H
+
+#include "PybindUtils.h"
+
+namespace mlir {
+namespace python {
+
+void populateExecutionEngineSubmodule(pybind11::module &m);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H
index 1f4b69d..9bfe8b0 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "PybindUtils.h"
 
+#include "ExecutionEngine.h"
 #include "Globals.h"
 #include "IRModules.h"
 #include "Pass.h"
@@ -216,4 +217,9 @@ PYBIND11_MODULE(_mlir, m) {
   auto passModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
   populatePassManagerSubmodule(passModule);
+
+  // Define and populate ExecutionEngine submodule.
+  auto executionEngineModule =
+      m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
+  populateExecutionEngineSubmodule(executionEngineModule);
 }
index 5ae8151..3eb4eb7 100644 (file)
@@ -10,6 +10,7 @@
 
 __all__ = [
   "ir",
+  "execution_engine",
   "passmanager",
 ]
 
@@ -61,7 +62,7 @@ def _reexport_cext(cext_module_name, target_module_name):
 
 # Import sub-modules. Since these may import from here, this must come after
 # any exported definitions.
-from . import ir, passmanager
+from . import ir, execution_engine, passmanager
 
 # Add our 'dialects' parent module to the search path for implementations.
 _cext.globals.append_dialect_search_prefix("mlir.dialects")
diff --git a/mlir/lib/Bindings/Python/mlir/conversions/__init__.py b/mlir/lib/Bindings/Python/mlir/conversions/__init__.py
new file mode 100644 (file)
index 0000000..2171343
--- /dev/null
@@ -0,0 +1,8 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Expose the corresponding C-Extension module with a well-known name at this
+# level.
+from .. import _load_extension
+_cextConversions = _load_extension("_mlirConversions")
diff --git a/mlir/lib/Bindings/Python/mlir/execution_engine.py b/mlir/lib/Bindings/Python/mlir/execution_engine.py
new file mode 100644 (file)
index 0000000..15a874a
--- /dev/null
@@ -0,0 +1,31 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Simply a wrapper around the extension module of the same name.
+from . import _cext
+import ctypes
+
+class ExecutionEngine(_cext.execution_engine.ExecutionEngine):
+
+  def lookup(self, name):
+    """Lookup a function emitted with the `llvm.emit_c_interface`
+    attribute and returns a ctype callable.
+    Raise a RuntimeError if the function isn't found.
+    """
+    func = self.raw_lookup("_mlir_ciface_" + name)
+    if not func:
+      raise RuntimeError("Unknown function " + name)
+    prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+    return prototype(func)
+
+  def invoke(self, name, *ctypes_args):
+    """Invoke a function with the list of ctypes arguments.
+    All arguments must be pointers.
+    Raise a RuntimeError if the function isn't found.
+    """
+    func = self.lookup(name)
+    packed_args = (ctypes.c_void_p * len(ctypes_args))()
+    for argNum in range(len(ctypes_args)):
+      packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
+    func(packed_args)
index 1cffddc..49722e2 100644 (file)
@@ -10,6 +10,7 @@
 #include "mlir/CAPI/ExecutionEngine.h"
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
+#include "mlir/Target/LLVMIR.h"
 #include "llvm/Support/TargetSelect.h"
 
 using namespace mlir;
@@ -22,6 +23,7 @@ extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op) {
   }();
   (void)init_once;
 
+  mlir::registerLLVMDialectTranslation(*unwrap(op)->getContext());
   auto jitOrError = ExecutionEngine::create(unwrap(op));
   if (!jitOrError) {
     consumeError(jitOrError.takeError());
@@ -44,3 +46,11 @@ mlirExecutionEngineInvokePacked(MlirExecutionEngine jit, MlirStringRef name,
     return wrap(failure());
   return wrap(success());
 }
+
+extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
+                                           MlirStringRef name) {
+  auto expectedFPtr = unwrap(jit)->lookup(unwrap(name));
+  if (!expectedFPtr)
+    return nullptr;
+  return reinterpret_cast<void *>(*expectedFPtr);
+}
diff --git a/mlir/test/Bindings/Python/execution_engine.py b/mlir/test/Bindings/Python/execution_engine.py
new file mode 100644 (file)
index 0000000..0706ea4
--- /dev/null
@@ -0,0 +1,99 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.execution_engine import *
+
+# Log everything to stderr and flush so that we have a unified stream to match
+# errors/info emitted by MLIR to stderr.
+def log(*args):
+  print(*args, file=sys.stderr)
+  sys.stderr.flush()
+
+def run(f):
+  log("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+
+# Verify capsule interop.
+# CHECK-LABEL: TEST: testCapsule
+def testCapsule():
+  with Context():
+    module = Module.parse(r"""
+llvm.func @none() {
+  llvm.return
+}
+    """)
+    execution_engine = ExecutionEngine(module)
+    execution_engine_capsule = execution_engine._CAPIPtr
+    # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
+    log(repr(execution_engine_capsule))
+    execution_engine._testing_release()
+    execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
+    # CHECK: _mlir.execution_engine.ExecutionEngine
+    log(repr(execution_engine1))
+
+run(testCapsule)
+
+# Test invalid ExecutionEngine creation
+# CHECK-LABEL: TEST: testInvalidModule
+def testInvalidModule():
+  with Context():
+    # Builtin function
+    module = Module.parse(r"""
+    func @foo() { return }
+    """)
+    # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
+    try:
+      execution_engine = ExecutionEngine(module)
+    except RuntimeError as e:
+      log("Got RuntimeError: ", e)
+
+run(testInvalidModule)
+
+def lowerToLLVM(module):
+  import mlir.conversions
+  pm = PassManager.parse("convert-std-to-llvm")
+  pm.run(module)
+  return module
+
+# Test simple ExecutionEngine execution
+# CHECK-LABEL: TEST: testInvokeVoid
+def testInvokeVoid():
+  with Context():
+    module = Module.parse(r"""
+func @void() attributes { llvm.emit_c_interface } {
+  return
+}
+    """)    
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    # Nothing to check other than no exception thrown here.
+    execution_engine.invoke("void")
+
+run(testInvokeVoid)
+
+
+# Test argument passing and result with a simple float addition.
+# CHECK-LABEL: TEST: testInvokeFloatAdd
+def testInvokeFloatAdd():
+  with Context():
+    module = Module.parse(r"""
+func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
+  %add = std.addf %arg0, %arg1 : f32
+  return %add : f32
+}
+    """)
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    # Prepare arguments: two input floats and one result.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    arg0 = c_float_p(42.)
+    arg1 = c_float_p(2.)
+    res = c_float_p(-1.)
+    execution_engine.invoke("add", arg0, arg1, res)
+    # CHECK: 42.0 + 2.0 = 44.0
+    log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))
+
+run(testInvokeFloatAdd)
index d50946e..1c972d1 100644 (file)
@@ -119,6 +119,7 @@ if(MLIR_BINDINGS_PYTHON_ENABLED)
     MLIRBindingsPythonExtension
     MLIRBindingsPythonTestOps
     MLIRTransformsBindingsPythonExtension
+    MLIRConversionsBindingsPythonExtension
   )
 endif()