namespace mlir {
namespace impl {
+
+/// Wrapper class around DynamicLibrarySearchGenerator to allow searching
+/// in-process symbols that have not been explicitly exported.
+/// This first tries to resolve a symbol by using DynamicLibrarySearchGenerator.
+/// For symbols that are not found this way, it then uses
+/// `llvm::sys::DynamicLibrary::SearchForAddressOfSymbol` to extract symbols
+/// that have been explicitly added with `llvm::sys::DynamicLibrary::AddSymbol`,
+/// previously.
+class SearchGenerator {
+public:
+ SearchGenerator(char GlobalPrefix)
+ : defaultGenerator(cantFail(
+ llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
+ GlobalPrefix))) {}
+
+ // This function forwards to DynamicLibrarySearchGenerator::operator() and
+ // adds an extra resolution for names explicitly registered via
+ // `llvm::sys::DynamicLibrary::AddSymbol`.
+ Expected<llvm::orc::SymbolNameSet>
+ operator()(llvm::orc::JITDylib &JD, const llvm::orc::SymbolNameSet &Names) {
+ auto res = defaultGenerator(JD, Names);
+ if (!res)
+ return res;
+ llvm::orc::SymbolMap newSymbols;
+ for (auto &Name : Names) {
+ if (res.get().count(Name) > 0)
+ continue;
+ res.get().insert(Name);
+ auto addedSymbolAddress =
+ llvm::sys::DynamicLibrary::SearchForAddressOfSymbol(*Name);
+ if (!addedSymbolAddress)
+ continue;
+ llvm::JITEvaluatedSymbol Sym(
+ reinterpret_cast<uintptr_t>(addedSymbolAddress),
+ llvm::JITSymbolFlags::Exported);
+ newSymbols[Name] = Sym;
+ }
+ if (!newSymbols.empty())
+ cantFail(JD.define(absoluteSymbols(std::move(newSymbols))));
+ return res;
+ }
+
+private:
+ llvm::orc::DynamicLibrarySearchGenerator defaultGenerator;
+};
+
// Simple layered Orc JIT compilation engine.
class OrcJIT {
public:
// Construct a JIT engine for the target host defined by `machineBuilder`,
// using the data layout provided as `dataLayout`.
- // Setup the object layer to use our custom memory manager in order to resolve
- // calls to library functions present in the process.
+ // Setup the object layer to use our custom memory manager in order to
+ // resolve calls to library functions present in the process.
OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder,
llvm::DataLayout layout, IRTransformer transform)
: irTransformer(transform),
dataLayout(layout), mangler(session, this->dataLayout),
threadSafeCtx(llvm::make_unique<llvm::LLVMContext>()) {
session.getMainJITDylib().setGenerator(
- cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
- layout.getGlobalPrefix())));
+ SearchGenerator(layout.getGlobalPrefix()));
}
// Create a JIT engine for the current host.
private:
// Wrap the `irTransformer` into a function that can be called by the
- // IRTranformLayer. If `irTransformer` is not set up, return the module as is
- // without errors.
+ // IRTranformLayer. If `irTransformer` is not set up, return the module as
+ // is without errors.
llvm::orc::IRTransformLayer::TransformFunction makeIRTransformFunction() {
return [this](llvm::orc::ThreadSafeModule module,
const llvm::orc::MaterializationResponsibility &resp)
+add_subdirectory(mlir-cpu-runner)
+
llvm_canonicalize_cmake_booleans(
LLVM_BUILD_EXAMPLES
)
-
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
set(MLIR_TEST_DEPENDS
FileCheck count not
MLIRUnitTests
+ mlir-blas-cpu-runner
mlir-cpu-runner
mlir-opt
mlir-tblgen
--- /dev/null
+set(LIBS
+ MLIRAffineOps
+ MLIRAnalysis
+ MLIREDSC
+ MLIRExecutionEngine
+ MLIRIR
+ MLIRLLVMIR
+ MLIRParser
+ MLIRTargetLLVMIR
+ MLIRTransforms
+ MLIRSupport
+ MLIRCPURunnerLib
+ LLVMCore
+ LLVMSupport
+)
+add_executable(mlir-blas-cpu-runner
+ mlir-blas-cpu-runner.cpp
+)
+llvm_update_compile_flags(mlir-blas-cpu-runner)
+whole_archive_link(mlir-blas-cpu-runner
+ MLIRLLVMIR
+ MLIRStandardOps
+ MLIRTargetLLVMIR
+ MLIRTransforms
+ MLIRTranslation
+)
+target_link_libraries(mlir-blas-cpu-runner MLIRIR ${LIBS})
--- /dev/null
+//===- mlir-blas-cpu-runner.cpp - MLIR CPU Execution Driver + Blas Support ===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Main entry point.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/DynamicLibrary.h"
+
+#ifdef WITH_LAPACK
+#include "lapack/cblas.h"
+#else
+extern "C" float cblas_sdot(const int N, const float *X, const int incX,
+ const float *Y, const int incY) {
+ float res = 0.0f;
+ for (int i = 0; i < N; ++i)
+ res += X[i * incX] * Y[i * incY];
+ return res;
+}
+#endif
+
+extern int run(int argc, char **argv);
+
+void addSymbols() {
+ using llvm::sys::DynamicLibrary;
+ DynamicLibrary::AddSymbol("cblas_sdot", (void *)(&cblas_sdot));
+}
+
+int main(int argc, char **argv) {
+ addSymbols();
+ return run(argc, argv);
+}
-// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s
+// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-blas-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s
-func @linalg_dot(!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
- !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
- !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
+func @cblas_sdot(!llvm.i64, !llvm<"float*">, !llvm.i64, !llvm<"float*">, !llvm.i64) -> !llvm.float
+
+func @linalg_dot(%arg0 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
+ %arg1 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
+ %arg2 : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
+ %n = llvm.extractvalue %arg0[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+
+ %x0 = llvm.extractvalue %arg0[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+ %x1 = llvm.extractvalue %arg0[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+ %x = llvm.getelementptr %x0[%x1] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+
+ %inc_x = llvm.extractvalue %arg0[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+
+ %y0 = llvm.extractvalue %arg1[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+ %y1 = llvm.extractvalue %arg1[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+ %y = llvm.getelementptr %y0[%y1] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+
+ %inc_y = llvm.extractvalue %arg1[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+
+ %res = llvm.call @cblas_sdot(%n, %x, %inc_x, %y, %inc_y) : (!llvm.i64, !llvm<"float*">, !llvm.i64, !llvm<"float*">, !llvm.i64) -> (!llvm.float)
+ %0 = llvm.extractvalue %arg2[0] : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">
+ %old = llvm.load %0 : !llvm<"float*">
+ %new = llvm.fadd %res, %old : !llvm.float
+ llvm.store %new, %0 : !llvm<"float*">
return
}
%c0 = constant 0 : index
%c1 = constant 1 : index
%c16 = constant 16 : index
- %f0 = constant 0.00000e+00 : f32
- %f1 = constant 0.00000e+00 : f32
+ %f10 = constant 10.00000e+00 : f32
+ %f1 = constant 1.00000e+00 : f32
%f2 = constant 2.00000e+00 : f32
%A = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer<f32>)
%B = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer<f32>)
- %C = call @alloc_filled_f32(%c1, %f0) : (index, f32) -> (!linalg.buffer<f32>)
+ %C = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer<f32>)
+
%res = call @dot(%A, %B, %C) : (!linalg.buffer<f32>, !linalg.buffer<f32>, !linalg.buffer<f32>) -> (f32)
+
linalg.buffer_dealloc %C : !linalg.buffer<f32>
linalg.buffer_dealloc %B : !linalg.buffer<f32>
linalg.buffer_dealloc %A : !linalg.buffer<f32>
+
return %res : f32
}
-// CHECK: 0.{{0+}}e+00
\ No newline at end of file
+// CHECK: 4.2{{0+}}e+01
\ No newline at end of file
+set(LLVM_OPTIONAL_SOURCES
+ mlir-cpu-runner-lib.cpp
+ mlir-cpu-runner.cpp
+ )
+
set(LIBS
MLIRAffineOps
MLIRAnalysis
LLVMCore
LLVMSupport
)
+
+add_llvm_library(MLIRCPURunnerLib
+ mlir-cpu-runner-lib.cpp
+)
+target_link_libraries(MLIRCPURunnerLib ${LIBS})
+
add_executable(mlir-cpu-runner
mlir-cpu-runner.cpp
)
llvm_update_compile_flags(mlir-cpu-runner)
whole_archive_link(mlir-cpu-runner MLIRLLVMIR MLIRStandardOps MLIRTargetLLVMIR MLIRTransforms MLIRTranslation)
-target_link_libraries(mlir-cpu-runner MLIRIR ${LIBS})
+target_link_libraries(mlir-cpu-runner MLIRIR ${LIBS} MLIRCPURunnerLib)
--- /dev/null
+//===- mlir-cpu-runner-lib.cpp - MLIR CPU Execution Driver Library --------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is a command line utility that executes an MLIR file on the CPU by
+// translating MLIR to LLVM IR before JIT-compiling and executing the latter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/MemRefUtils.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassNameParser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include <numeric>
+
+using namespace mlir;
+using llvm::Error;
+
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+ llvm::cl::desc("<input file>"),
+ llvm::cl::init("-"));
+static llvm::cl::opt<std::string>
+ initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
+ llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
+static llvm::cl::opt<std::string>
+ mainFuncName("e", llvm::cl::desc("The function to be called"),
+ llvm::cl::value_desc("<function name>"),
+ llvm::cl::init("main"));
+static llvm::cl::opt<std::string> mainFuncType(
+ "entry-point-result",
+ llvm::cl::desc("Textual description of the function type to be called"),
+ llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
+
+static llvm::cl::OptionCategory optFlags("opt-like flags");
+
+// CLI list of pass information
+static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
+ llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"),
+ llvm::cl::cat(optFlags));
+
+// CLI variables for -On options.
+static llvm::cl::opt<bool> optO0("O0", llvm::cl::desc("Run opt O0 passes"),
+ llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO1("O1", llvm::cl::desc("Run opt O1 passes"),
+ llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
+ llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
+ llvm::cl::cat(optFlags));
+
+static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
+ MLIRContext *context) {
+ // Set up the input file.
+ std::string errorMessage;
+ auto file = openInputFile(inputFilename, &errorMessage);
+ if (!file) {
+ llvm::errs() << errorMessage << "\n";
+ return nullptr;
+ }
+
+ llvm::SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
+ return std::unique_ptr<Module>(parseSourceFile(sourceMgr, context));
+}
+
+// Initialize the relevant subsystems of LLVM.
+static void initializeLLVM() {
+ llvm::InitializeNativeTarget();
+ llvm::InitializeNativeTargetAsmPrinter();
+}
+
+static inline Error make_string_error(const llvm::Twine &message) {
+ return llvm::make_error<llvm::StringError>(message.str(),
+ llvm::inconvertibleErrorCode());
+}
+
+static void printOneMemRef(Type t, void *val) {
+ auto memRefType = t.cast<MemRefType>();
+ auto shape = memRefType.getShape();
+ int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
+ std::multiplies<int64_t>());
+ for (int64_t i = 0; i < size; ++i) {
+ llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
+ }
+ llvm::outs() << '\n';
+}
+
+static void printMemRefArguments(ArrayRef<Type> argTypes,
+ ArrayRef<Type> resTypes,
+ ArrayRef<void *> args) {
+ auto properArgs = args.take_front(argTypes.size());
+ for (const auto &kvp : llvm::zip(argTypes, properArgs)) {
+ auto type = std::get<0>(kvp);
+ auto val = std::get<1>(kvp);
+ printOneMemRef(type, val);
+ }
+
+ auto results = args.drop_front(argTypes.size());
+ for (const auto &kvp : llvm::zip(resTypes, results)) {
+ auto type = std::get<0>(kvp);
+ auto val = std::get<1>(kvp);
+ printOneMemRef(type, val);
+ }
+}
+
+static Error compileAndExecuteFunctionWithMemRefs(
+ Module *module, StringRef entryPoint,
+ std::function<llvm::Error(llvm::Module *)> transformer) {
+ Function *mainFunction = module->getNamedFunction(entryPoint);
+ if (!mainFunction || mainFunction->getBlocks().empty()) {
+ return make_string_error("entry point not found");
+ }
+
+ // Store argument and result types of the original function necessary to
+ // pretty print the results, because the function itself will be rewritten
+ // to use the LLVM dialect.
+ SmallVector<Type, 8> argTypes =
+ llvm::to_vector<8>(mainFunction->getType().getInputs());
+ SmallVector<Type, 8> resTypes =
+ llvm::to_vector<8>(mainFunction->getType().getResults());
+
+ float init = std::stof(initValue.getValue());
+
+ auto expectedArguments = allocateMemRefArguments(mainFunction, init);
+ if (!expectedArguments)
+ return expectedArguments.takeError();
+
+ auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
+ if (!expectedEngine)
+ return expectedEngine.takeError();
+
+ auto engine = std::move(*expectedEngine);
+ auto expectedFPtr = engine->lookup(entryPoint);
+ if (!expectedFPtr)
+ return expectedFPtr.takeError();
+ void (*fptr)(void **) = *expectedFPtr;
+ (*fptr)(expectedArguments->data());
+ printMemRefArguments(argTypes, resTypes, *expectedArguments);
+ freeMemRefArguments(*expectedArguments);
+
+ return Error::success();
+}
+
+static Error compileAndExecuteSingleFloatReturnFunction(
+ Module *module, StringRef entryPoint,
+ std::function<llvm::Error(llvm::Module *)> transformer) {
+ Function *mainFunction = module->getNamedFunction(entryPoint);
+ if (!mainFunction || mainFunction->isExternal()) {
+ return make_string_error("entry point not found");
+ }
+
+ if (!mainFunction->getType().getInputs().empty())
+ return make_string_error("function inputs not supported");
+
+ if (mainFunction->getType().getResults().size() != 1)
+ return make_string_error("only single f32 function result supported");
+
+ auto t = mainFunction->getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
+ if (!t)
+ return make_string_error("only single llvm.f32 function result supported");
+ auto *llvmTy = t.getUnderlyingType();
+ if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
+ return make_string_error("only single llvm.f32 function result supported");
+
+ auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
+ if (!expectedEngine)
+ return expectedEngine.takeError();
+
+ auto engine = std::move(*expectedEngine);
+ auto expectedFPtr = engine->lookup(entryPoint);
+ if (!expectedFPtr)
+ return expectedFPtr.takeError();
+ void (*fptr)(void **) = *expectedFPtr;
+
+ float res;
+ struct {
+ void *data;
+ } data;
+ data.data = &res;
+ (*fptr)((void **)&data);
+
+ // Intentional printing of the output so we can test.
+ llvm::outs() << res;
+
+ return Error::success();
+}
+
+int run(int argc, char **argv) {
+ llvm::PrettyStackTraceProgram x(argc, argv);
+ llvm::InitLLVM y(argc, argv);
+
+ initializeLLVM();
+ mlir::initializeLLVMPasses();
+
+ llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
+ optO0, optO1, optO2, optO3};
+
+ llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
+
+ llvm::SmallVector<const llvm::PassInfo *, 4> passes;
+ llvm::Optional<unsigned> optLevel;
+ unsigned optCLIPosition = 0;
+ // Determine if there is an optimization flag present, and its CLI position
+ // (optCLIPosition).
+ for (unsigned j = 0; j < 4; ++j) {
+ auto &flag = optFlags[j].get();
+ if (flag) {
+ optLevel = j;
+ optCLIPosition = flag.getPosition();
+ break;
+ }
+ }
+ // Generate vector of pass information, plus the index at which we should
+ // insert any optimization passes in that vector (optPosition).
+ unsigned optPosition = 0;
+ for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
+ passes.push_back(llvmPasses[i]);
+ if (optCLIPosition < llvmPasses.getPosition(i)) {
+ optPosition = i;
+ optCLIPosition = UINT_MAX; // To ensure we never insert again
+ }
+ }
+
+ MLIRContext context;
+ auto m = parseMLIRInput(inputFilename, &context);
+ if (!m) {
+ llvm::errs() << "could not parse the input IR\n";
+ return 1;
+ }
+
+ auto transformer =
+ mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
+ auto error = mainFuncType.getValue() == "f32"
+ ? compileAndExecuteSingleFloatReturnFunction(
+ m.get(), mainFuncName.getValue(), transformer)
+ : compileAndExecuteFunctionWithMemRefs(
+ m.get(), mainFuncName.getValue(), transformer);
+ int exitCode = EXIT_SUCCESS;
+ llvm::handleAllErrors(std::move(error),
+ [&exitCode](const llvm::ErrorInfoBase &info) {
+ llvm::errs() << "Error: ";
+ info.log(llvm::errs());
+ llvm::errs() << '\n';
+ exitCode = EXIT_FAILURE;
+ });
+
+ return exitCode;
+}
// limitations under the License.
// =============================================================================
//
-// This is a command line utility that executes an MLIR file on the CPU by
-// translating MLIR to LLVM IR before JIT-compiling and executing the latter.
+// Main entry point to a command line utility that executes an MLIR file on the
+// CPU by translating MLIR to LLVM IR before JIT-compiling and executing the
+// latter.
//
//===----------------------------------------------------------------------===//
+extern int run(int argc, char **argv);
-#include "mlir/ExecutionEngine/ExecutionEngine.h"
-#include "mlir/ExecutionEngine/MemRefUtils.h"
-#include "mlir/ExecutionEngine/OptUtils.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/LLVMIR/LLVMDialect.h"
-#include "mlir/Parser.h"
-#include "mlir/Support/FileUtilities.h"
-
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/LLVMContext.h"
-#include "llvm/IR/LegacyPassNameParser.h"
-#include "llvm/IR/Module.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/FileUtilities.h"
-#include "llvm/Support/InitLLVM.h"
-#include "llvm/Support/PrettyStackTrace.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/StringSaver.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Support/ToolOutputFile.h"
-#include <numeric>
-
-using namespace mlir;
-using llvm::Error;
-
-static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
- llvm::cl::desc("<input file>"),
- llvm::cl::init("-"));
-static llvm::cl::opt<std::string>
- initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
- llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
-static llvm::cl::opt<std::string>
- mainFuncName("e", llvm::cl::desc("The function to be called"),
- llvm::cl::value_desc("<function name>"),
- llvm::cl::init("main"));
-static llvm::cl::opt<std::string> mainFuncType(
- "entry-point-result",
- llvm::cl::desc("Textual description of the function type to be called"),
- llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
-
-static llvm::cl::OptionCategory optFlags("opt-like flags");
-
-// CLI list of pass information
-static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
- llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"),
- llvm::cl::cat(optFlags));
-
-// CLI variables for -On options.
-static llvm::cl::opt<bool> optO0("O0", llvm::cl::desc("Run opt O0 passes"),
- llvm::cl::cat(optFlags));
-static llvm::cl::opt<bool> optO1("O1", llvm::cl::desc("Run opt O1 passes"),
- llvm::cl::cat(optFlags));
-static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
- llvm::cl::cat(optFlags));
-static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
- llvm::cl::cat(optFlags));
-
-static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
- MLIRContext *context) {
- // Set up the input file.
- std::string errorMessage;
- auto file = openInputFile(inputFilename, &errorMessage);
- if (!file) {
- llvm::errs() << errorMessage << "\n";
- return nullptr;
- }
-
- llvm::SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
- return std::unique_ptr<Module>(parseSourceFile(sourceMgr, context));
-}
-
-// Initialize the relevant subsystems of LLVM.
-static void initializeLLVM() {
- llvm::InitializeNativeTarget();
- llvm::InitializeNativeTargetAsmPrinter();
-}
-
-static inline Error make_string_error(const llvm::Twine &message) {
- return llvm::make_error<llvm::StringError>(message.str(),
- llvm::inconvertibleErrorCode());
-}
-
-static void printOneMemRef(Type t, void *val) {
- auto memRefType = t.cast<MemRefType>();
- auto shape = memRefType.getShape();
- int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
- std::multiplies<int64_t>());
- for (int64_t i = 0; i < size; ++i) {
- llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
- }
- llvm::outs() << '\n';
-}
-
-static void printMemRefArguments(ArrayRef<Type> argTypes,
- ArrayRef<Type> resTypes,
- ArrayRef<void *> args) {
- auto properArgs = args.take_front(argTypes.size());
- for (const auto &kvp : llvm::zip(argTypes, properArgs)) {
- auto type = std::get<0>(kvp);
- auto val = std::get<1>(kvp);
- printOneMemRef(type, val);
- }
-
- auto results = args.drop_front(argTypes.size());
- for (const auto &kvp : llvm::zip(resTypes, results)) {
- auto type = std::get<0>(kvp);
- auto val = std::get<1>(kvp);
- printOneMemRef(type, val);
- }
-}
-
-static Error compileAndExecuteFunctionWithMemRefs(
- Module *module, StringRef entryPoint,
- std::function<llvm::Error(llvm::Module *)> transformer) {
- Function *mainFunction = module->getNamedFunction(entryPoint);
- if (!mainFunction || mainFunction->getBlocks().empty()) {
- return make_string_error("entry point not found");
- }
-
- // Store argument and result types of the original function necessary to
- // pretty print the results, because the function itself will be rewritten
- // to use the LLVM dialect.
- SmallVector<Type, 8> argTypes =
- llvm::to_vector<8>(mainFunction->getType().getInputs());
- SmallVector<Type, 8> resTypes =
- llvm::to_vector<8>(mainFunction->getType().getResults());
-
- float init = std::stof(initValue.getValue());
-
- auto expectedArguments = allocateMemRefArguments(mainFunction, init);
- if (!expectedArguments)
- return expectedArguments.takeError();
-
- auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
- if (!expectedEngine)
- return expectedEngine.takeError();
-
- auto engine = std::move(*expectedEngine);
- auto expectedFPtr = engine->lookup(entryPoint);
- if (!expectedFPtr)
- return expectedFPtr.takeError();
- void (*fptr)(void **) = *expectedFPtr;
- (*fptr)(expectedArguments->data());
- printMemRefArguments(argTypes, resTypes, *expectedArguments);
- freeMemRefArguments(*expectedArguments);
-
- return Error::success();
-}
-
-static Error compileAndExecuteSingleFloatReturnFunction(
- Module *module, StringRef entryPoint,
- std::function<llvm::Error(llvm::Module *)> transformer) {
- Function *mainFunction = module->getNamedFunction(entryPoint);
- if (!mainFunction || mainFunction->isExternal()) {
- return make_string_error("entry point not found");
- }
-
- if (!mainFunction->getType().getInputs().empty())
- return make_string_error("function inputs not supported");
-
- if (mainFunction->getType().getResults().size() != 1)
- return make_string_error("only single f32 function result supported");
-
- auto t = mainFunction->getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
- if (!t)
- return make_string_error("only single llvm.f32 function result supported");
- auto *llvmTy = t.getUnderlyingType();
- if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
- return make_string_error("only single llvm.f32 function result supported");
-
- auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
- if (!expectedEngine)
- return expectedEngine.takeError();
-
- auto engine = std::move(*expectedEngine);
- auto expectedFPtr = engine->lookup(entryPoint);
- if (!expectedFPtr)
- return expectedFPtr.takeError();
- void (*fptr)(void **) = *expectedFPtr;
-
- float res;
- struct {
- void *data;
- } data;
- data.data = &res;
- (*fptr)((void **)&data);
-
- // Intentional printing of the output so we can test.
- llvm::outs() << res;
-
- return Error::success();
-}
-
-int main(int argc, char **argv) {
- llvm::PrettyStackTraceProgram x(argc, argv);
- llvm::InitLLVM y(argc, argv);
-
- initializeLLVM();
- mlir::initializeLLVMPasses();
-
- llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
- optO0, optO1, optO2, optO3};
-
- llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
-
- llvm::SmallVector<const llvm::PassInfo *, 4> passes;
- llvm::Optional<unsigned> optLevel;
- unsigned optCLIPosition = 0;
- // Determine if there is an optimization flag present, and its CLI position
- // (optCLIPosition).
- for (unsigned j = 0; j < 4; ++j) {
- auto &flag = optFlags[j].get();
- if (flag) {
- optLevel = j;
- optCLIPosition = flag.getPosition();
- break;
- }
- }
- // Generate vector of pass information, plus the index at which we should
- // insert any optimization passes in that vector (optPosition).
- unsigned optPosition = 0;
- for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
- passes.push_back(llvmPasses[i]);
- if (optCLIPosition < llvmPasses.getPosition(i)) {
- optPosition = i;
- optCLIPosition = UINT_MAX; // To ensure we never insert again
- }
- }
-
- MLIRContext context;
- auto m = parseMLIRInput(inputFilename, &context);
- if (!m) {
- llvm::errs() << "could not parse the input IR\n";
- return 1;
- }
-
- auto transformer =
- mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
- auto error = mainFuncType.getValue() == "f32"
- ? compileAndExecuteSingleFloatReturnFunction(
- m.get(), mainFuncName.getValue(), transformer)
- : compileAndExecuteFunctionWithMemRefs(
- m.get(), mainFuncName.getValue(), transformer);
- int exitCode = EXIT_SUCCESS;
- llvm::handleAllErrors(std::move(error),
- [&exitCode](const llvm::ErrorInfoBase &info) {
- llvm::errs() << "Error: ";
- info.log(llvm::errs());
- llvm::errs() << '\n';
- exitCode = EXIT_FAILURE;
- });
-
- return exitCode;
-}
+int main(int argc, char **argv) { return run(argc, argv); }