Rename MlirExecutionEngine lookup to lookupPacked
authorTres Popp <tpopp@google.com>
Mon, 22 Nov 2021 09:37:42 +0000 (10:37 +0100)
committerTres Popp <tpopp@google.com>
Mon, 22 Nov 2021 13:12:09 +0000 (14:12 +0100)
The purpose of the change is to make clear whether the user is
retrieving the original function or the wrapper function, in line with
the invoke commands. This new functionality is useful for users that
already have defined their own packed interface, so they do not want the
extra layer of indirection, or for users wanting to the look at the
resulting primary function rather than the wrapper function.

All locations, except the python bindings now have a `lookupPacked`
method that matches the original `lookup` functionality. `lookup`
still exists, but with new semantics.

- `lookup` returns the function with a given name. If `bool f(int,int)`
is compiled, `lookup` will return a reference to `bool(*f)(int,int)`.
- `lookupPacked` returns the packed wrapper of the function with the
given name. If `bool f(int,int)` is compiled, `lookupPacked` will return
`void(*mlir_f)(void**)`.

Differential Revision: https://reviews.llvm.org/D114352

mlir/include/mlir-c/ExecutionEngine.h
mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
mlir/lib/ExecutionEngine/ExecutionEngine.cpp
mlir/lib/ExecutionEngine/JitRunner.cpp

index bb45452..cd3df8e 100644 (file)
@@ -62,6 +62,11 @@ static inline bool mlirExecutionEngineIsNull(MlirExecutionEngine jit) {
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirExecutionEngineInvokePacked(
     MlirExecutionEngine jit, MlirStringRef name, void **arguments);
 
+/// Lookup the wrapper of the native function in the execution engine with the
+/// given name, returns nullptr if the function can't be looked-up.
+MLIR_CAPI_EXPORTED void *
+mlirExecutionEngineLookupPacked(MlirExecutionEngine jit, MlirStringRef name);
+
 /// Lookup a native function in the execution engine by name, returns nullptr
 /// if the name can't be looked-up.
 MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
index 5c9f0fa..a1cb088 100644 (file)
@@ -97,9 +97,14 @@ public:
          bool enableGDBNotificationListener = true,
          bool enablePerfNotificationListener = true);
 
-  /// Looks up a packed-argument function with the given name and returns a
-  /// pointer to it.  Propagates errors in case of failure.
-  llvm::Expected<void (*)(void **)> lookup(StringRef name) const;
+  /// Looks up a packed-argument function wrapping the function with the given
+  /// name and returns a pointer to it. Propagates errors in case of failure.
+  llvm::Expected<void (*)(void **)> lookupPacked(StringRef name) const;
+
+  /// Looks up the original function with the given name and returns a
+  /// pointer to it. This is not necesarily a packed function. Propagates
+  /// errors in case of failure.
+  llvm::Expected<void *> lookup(StringRef name) const;
 
   /// Invokes the function with the given name passing it the list of opaque
   /// pointers to the actual arguments.
index 07c3516..c49d990 100644 (file)
@@ -100,7 +100,7 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) {
       .def(
           "raw_lookup",
           [](PyExecutionEngine &executionEngine, const std::string &func) {
-            auto *res = mlirExecutionEngineLookup(
+            auto *res = mlirExecutionEngineLookupPacked(
                 executionEngine.get(),
                 mlirStringRefCreate(func.c_str(), func.size()));
             return reinterpret_cast<uintptr_t>(res);
index a9bb09e..604cc45 100644 (file)
@@ -75,6 +75,14 @@ mlirExecutionEngineInvokePacked(MlirExecutionEngine jit, MlirStringRef name,
   return wrap(success());
 }
 
+extern "C" void *mlirExecutionEngineLookupPacked(MlirExecutionEngine jit,
+                                                 MlirStringRef name) {
+  auto expectedFPtr = unwrap(jit)->lookupPacked(unwrap(name));
+  if (!expectedFPtr)
+    return nullptr;
+  return reinterpret_cast<void *>(*expectedFPtr);
+}
+
 extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
                                            MlirStringRef name) {
   auto expectedFPtr = unwrap(jit)->lookup(unwrap(name));
index 9d66085..19459f2 100644 (file)
@@ -328,8 +328,16 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
   return std::move(engine);
 }
 
-Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
-  auto expectedSymbol = jit->lookup(makePackedFunctionName(name));
+Expected<void (*)(void **)>
+ExecutionEngine::lookupPacked(StringRef name) const {
+  auto result = lookup(makePackedFunctionName(name));
+  if (!result)
+    return result.takeError();
+  return reinterpret_cast<void (*)(void **)>(result.get());
+}
+
+Expected<void *> ExecutionEngine::lookup(StringRef name) const {
+  auto expectedSymbol = jit->lookup(name);
 
   // JIT lookup may return an Error referring to strings stored internally by
   // the JIT. If the Error outlives the ExecutionEngine, it would want have a
@@ -346,7 +354,7 @@ Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
   }
 
   auto rawFPtr = expectedSymbol->getAddress();
-  auto fptr = reinterpret_cast<void (*)(void **)>(rawFPtr);
+  auto fptr = reinterpret_cast<void *>(rawFPtr);
   if (!fptr)
     return make_string_error("looked up function is null");
   return fptr;
@@ -354,7 +362,7 @@ Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
 
 Error ExecutionEngine::invokePacked(StringRef name,
                                     MutableArrayRef<void *> args) {
-  auto expectedFPtr = lookup(name);
+  auto expectedFPtr = lookupPacked(name);
   if (!expectedFPtr)
     return expectedFPtr.takeError();
   auto fptr = *expectedFPtr;
index cf32c08..5ecce56 100644 (file)
@@ -216,7 +216,7 @@ static Error compileAndExecute(Options &options, ModuleOp module,
   auto engine = std::move(*expectedEngine);
   engine->registerSymbols(runtimeSymbolMap);
 
-  auto expectedFPtr = engine->lookup(entryPoint);
+  auto expectedFPtr = engine->lookupPacked(entryPoint);
   if (!expectedFPtr)
     return expectedFPtr.takeError();