JitRunner: support entry functions returning void
authorAlex Zinenko <zinenko@google.com>
Tue, 20 Aug 2019 14:45:47 +0000 (07:45 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 20 Aug 2019 14:46:17 +0000 (07:46 -0700)
JitRunner can use as entry points functions that produce either a single
'!llvm.f32' value or a list of memrefs.  Memref support is legacy and was
introduced before MLIR could lower memref allocation and deallocation to
malloc/free calls so as to allocate the memory externally, and is likely to be
dropped in the future since it unconditionally runs affine+standard-to-llvm
lowering on the module instead of accepting the LLVM dialect.  CUDA runner
relies on memref-based flow in the runner without actually returning anything.
Introduce a runner flow to use functions that return void as entry points.

PiperOrigin-RevId: 264381686

mlir/lib/Support/JitRunner.cpp
mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir

index afa356e..f88d9b8 100644 (file)
@@ -70,7 +70,7 @@ static llvm::cl::opt<std::string>
 static llvm::cl::opt<std::string> mainFuncType(
     "entry-point-result",
     llvm::cl::desc("Textual description of the function type to be called"),
-    llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
+    llvm::cl::value_desc("f32 | memrefs | void"), llvm::cl::init("memrefs"));
 
 static llvm::cl::OptionCategory optFlags("opt-like flags");
 
@@ -166,6 +166,37 @@ static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
   return manager.run(module);
 }
 
+// JIT-compile the given module and run "entryPoint" with "args" as arguments.
+static Error
+compileAndExecute(ModuleOp module, StringRef entryPoint,
+                  std::function<llvm::Error(llvm::Module *)> transformer,
+                  void **args) {
+  SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
+  auto expectedEngine =
+      mlir::ExecutionEngine::create(module, transformer, libs);
+  if (!expectedEngine)
+    return expectedEngine.takeError();
+
+  auto engine = std::move(*expectedEngine);
+  auto expectedFPtr = engine->lookup(entryPoint);
+  if (!expectedFPtr)
+    return expectedFPtr.takeError();
+  void (*fptr)(void **) = *expectedFPtr;
+  (*fptr)(args);
+
+  return Error::success();
+}
+
+static Error compileAndExecuteVoidFunction(
+    ModuleOp module, StringRef entryPoint,
+    std::function<llvm::Error(llvm::Module *)> transformer) {
+  FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
+  if (!mainFunction || mainFunction.getBlocks().empty())
+    return make_string_error("entry point not found");
+  void *empty = nullptr;
+  return compileAndExecute(module, entryPoint, transformer, &empty);
+}
+
 static Error compileAndExecuteFunctionWithMemRefs(
     ModuleOp module, StringRef entryPoint,
     std::function<llvm::Error(llvm::Module *)> transformer) {
@@ -191,21 +222,12 @@ static Error compileAndExecuteFunctionWithMemRefs(
   if (failed(convertAffineStandardToLLVMIR(module)))
     return make_string_error("conversion to the LLVM IR dialect failed");
 
-  SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
-  auto expectedEngine =
-      mlir::ExecutionEngine::create(module, transformer, libs);
-  if (!expectedEngine)
-    return expectedEngine.takeError();
+  if (auto error = compileAndExecute(module, entryPoint, transformer,
+                                     expectedArguments->data()))
+    return error;
 
-  auto engine = std::move(*expectedEngine);
-  auto expectedFPtr = engine->lookup(entryPoint);
-  if (!expectedFPtr)
-    return expectedFPtr.takeError();
-  void (*fptr)(void **) = *expectedFPtr;
-  (*fptr)(expectedArguments->data());
   printMemRefArguments(argTypes, resTypes, *expectedArguments);
   freeMemRefArguments(*expectedArguments);
-
   return Error::success();
 }
 
@@ -230,24 +252,14 @@ static Error compileAndExecuteSingleFloatReturnFunction(
   if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
     return make_string_error("only single llvm.f32 function result supported");
 
-  SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
-  auto expectedEngine =
-      mlir::ExecutionEngine::create(module, transformer, libs);
-  if (!expectedEngine)
-    return expectedEngine.takeError();
-
-  auto engine = std::move(*expectedEngine);
-  auto expectedFPtr = engine->lookup(entryPoint);
-  if (!expectedFPtr)
-    return expectedFPtr.takeError();
-  void (*fptr)(void **) = *expectedFPtr;
-
   float res;
   struct {
     void *data;
   } data;
   data.data = &res;
-  (*fptr)((void **)&data);
+  if (auto error =
+          compileAndExecute(module, entryPoint, transformer, (void **)&data))
+    return error;
 
   // Intentional printing of the output so we can test.
   llvm::outs() << res;
@@ -320,11 +332,18 @@ int mlir::JitRunnerMain(
 
   auto transformer = mlir::makeLLVMPassesTransformer(
       passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
-  auto error = mainFuncType.getValue() == "f32"
-                   ? compileAndExecuteSingleFloatReturnFunction(
-                         m.get(), mainFuncName.getValue(), transformer)
-                   : compileAndExecuteFunctionWithMemRefs(
-                         m.get(), mainFuncName.getValue(), transformer);
+
+  Error error = make_string_error("unsupported function type");
+  if (mainFuncType.getValue() == "f32")
+    error = compileAndExecuteSingleFloatReturnFunction(
+        m.get(), mainFuncName.getValue(), transformer);
+  else if (mainFuncType.getValue() == "memrefs")
+    error = compileAndExecuteFunctionWithMemRefs(
+        m.get(), mainFuncName.getValue(), transformer);
+  else if (mainFuncType.getValue() == "void")
+    error = compileAndExecuteVoidFunction(m.get(), mainFuncName.getValue(),
+                                          transformer);
+
   int exitCode = EXIT_SUCCESS;
   llvm::handleAllErrors(std::move(error),
                         [&exitCode](const llvm::ErrorInfoBase &info) {
index 6610337..871327f 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext | FileCheck %s
+// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext --entry-point-result=void | FileCheck %s
 
 func @other_func(%arg0 : f32, %arg1 : memref<?xf32>) {
   %cst = constant 1 : index