static llvm::cl::opt<std::string> mainFuncType(
"entry-point-result",
llvm::cl::desc("Textual description of the function type to be called"),
- llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
+ llvm::cl::value_desc("f32 | memrefs | void"), llvm::cl::init("memrefs"));
static llvm::cl::OptionCategory optFlags("opt-like flags");
return manager.run(module);
}
+// JIT-compile the given module and run "entryPoint" with "args" as arguments.
+static Error
+compileAndExecute(ModuleOp module, StringRef entryPoint,
+ std::function<llvm::Error(llvm::Module *)> transformer,
+ void **args) {
+ SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
+ auto expectedEngine =
+ mlir::ExecutionEngine::create(module, transformer, libs);
+ if (!expectedEngine)
+ return expectedEngine.takeError();
+
+ auto engine = std::move(*expectedEngine);
+ auto expectedFPtr = engine->lookup(entryPoint);
+ if (!expectedFPtr)
+ return expectedFPtr.takeError();
+ void (*fptr)(void **) = *expectedFPtr;
+ (*fptr)(args);
+
+ return Error::success();
+}
+
+static Error compileAndExecuteVoidFunction(
+ ModuleOp module, StringRef entryPoint,
+ std::function<llvm::Error(llvm::Module *)> transformer) {
+ FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
+ if (!mainFunction || mainFunction.getBlocks().empty())
+ return make_string_error("entry point not found");
+ void *empty = nullptr;
+ return compileAndExecute(module, entryPoint, transformer, &empty);
+}
+
static Error compileAndExecuteFunctionWithMemRefs(
ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
if (failed(convertAffineStandardToLLVMIR(module)))
return make_string_error("conversion to the LLVM IR dialect failed");
- SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
- auto expectedEngine =
- mlir::ExecutionEngine::create(module, transformer, libs);
- if (!expectedEngine)
- return expectedEngine.takeError();
+ if (auto error = compileAndExecute(module, entryPoint, transformer,
+ expectedArguments->data()))
+ return error;
- auto engine = std::move(*expectedEngine);
- auto expectedFPtr = engine->lookup(entryPoint);
- if (!expectedFPtr)
- return expectedFPtr.takeError();
- void (*fptr)(void **) = *expectedFPtr;
- (*fptr)(expectedArguments->data());
printMemRefArguments(argTypes, resTypes, *expectedArguments);
freeMemRefArguments(*expectedArguments);
-
return Error::success();
}
if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
return make_string_error("only single llvm.f32 function result supported");
- SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
- auto expectedEngine =
- mlir::ExecutionEngine::create(module, transformer, libs);
- if (!expectedEngine)
- return expectedEngine.takeError();
-
- auto engine = std::move(*expectedEngine);
- auto expectedFPtr = engine->lookup(entryPoint);
- if (!expectedFPtr)
- return expectedFPtr.takeError();
- void (*fptr)(void **) = *expectedFPtr;
-
float res;
struct {
void *data;
} data;
data.data = &res;
- (*fptr)((void **)&data);
+ if (auto error =
+ compileAndExecute(module, entryPoint, transformer, (void **)&data))
+ return error;
// Intentional printing of the output so we can test.
llvm::outs() << res;
auto transformer = mlir::makeLLVMPassesTransformer(
passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
- auto error = mainFuncType.getValue() == "f32"
- ? compileAndExecuteSingleFloatReturnFunction(
- m.get(), mainFuncName.getValue(), transformer)
- : compileAndExecuteFunctionWithMemRefs(
- m.get(), mainFuncName.getValue(), transformer);
+
+ Error error = make_string_error("unsupported function type");
+ if (mainFuncType.getValue() == "f32")
+ error = compileAndExecuteSingleFloatReturnFunction(
+ m.get(), mainFuncName.getValue(), transformer);
+ else if (mainFuncType.getValue() == "memrefs")
+ error = compileAndExecuteFunctionWithMemRefs(
+ m.get(), mainFuncName.getValue(), transformer);
+ else if (mainFuncType.getValue() == "void")
+ error = compileAndExecuteVoidFunction(m.get(), mainFuncName.getValue(),
+ transformer);
+
int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error),
[&exitCode](const llvm::ErrorInfoBase &info) {