Pipe Linalg to a cblas call via mlir-cpu-runner
authorNicolas Vasilache <ntv@google.com>
Mon, 13 May 2019 17:59:04 +0000 (10:59 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:39:02 +0000 (13:39 -0700)
    This CL extends the execution engine to allow the additional resolution of symbols names
    that have been registered explicitly. This allows linking static library symbols that have not been explicitly exported with the -rdynamic linking flag (which is deemed too intrusive).

--

PiperOrigin-RevId: 247969504

mlir/lib/ExecutionEngine/ExecutionEngine.cpp
mlir/test/CMakeLists.txt
mlir/test/mlir-cpu-runner/CMakeLists.txt [new file with mode: 0644]
mlir/test/mlir-cpu-runner/mlir-blas-cpu-runner.cpp [new file with mode: 0644]
mlir/test/mlir-cpu-runner/simple_linalg.mlir
mlir/tools/mlir-cpu-runner/CMakeLists.txt
mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp [new file with mode: 0644]
mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp

index ece17f3..b10fcb8 100644 (file)
@@ -75,6 +75,52 @@ private:
 
 namespace mlir {
 namespace impl {
+
+/// Wrapper class around DynamicLibrarySearchGenerator to allow searching
+/// in-process symbols that have not been explicitly exported.
+/// This first tries to resolve a symbol by using DynamicLibrarySearchGenerator.
+/// For symbols that are not found this way, it then uses
+///   `llvm::sys::DynamicLibrary::SearchForAddressOfSymbol` to extract symbols
+/// that have been explicitly added with `llvm::sys::DynamicLibrary::AddSymbol`,
+/// previously.
+class SearchGenerator {
+public:
+  SearchGenerator(char GlobalPrefix)
+      : defaultGenerator(cantFail(
+            llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
+                GlobalPrefix))) {}
+
+  // This function forwards to DynamicLibrarySearchGenerator::operator() and
+  // adds an extra resolution for names explicitly registered via
+  // `llvm::sys::DynamicLibrary::AddSymbol`.
+  Expected<llvm::orc::SymbolNameSet>
+  operator()(llvm::orc::JITDylib &JD, const llvm::orc::SymbolNameSet &Names) {
+    auto res = defaultGenerator(JD, Names);
+    if (!res)
+      return res;
+    llvm::orc::SymbolMap newSymbols;
+    for (auto &Name : Names) {
+      if (res.get().count(Name) > 0)
+        continue;
+      res.get().insert(Name);
+      auto addedSymbolAddress =
+          llvm::sys::DynamicLibrary::SearchForAddressOfSymbol(*Name);
+      if (!addedSymbolAddress)
+        continue;
+      llvm::JITEvaluatedSymbol Sym(
+          reinterpret_cast<uintptr_t>(addedSymbolAddress),
+          llvm::JITSymbolFlags::Exported);
+      newSymbols[Name] = Sym;
+    }
+    if (!newSymbols.empty())
+      cantFail(JD.define(absoluteSymbols(std::move(newSymbols))));
+    return res;
+  }
+
+private:
+  llvm::orc::DynamicLibrarySearchGenerator defaultGenerator;
+};
+
 // Simple layered Orc JIT compilation engine.
 class OrcJIT {
 public:
@@ -82,8 +128,8 @@ public:
 
   // Construct a JIT engine for the target host defined by `machineBuilder`,
   // using the data layout provided as `dataLayout`.
-  // Setup the object layer to use our custom memory manager in order to resolve
-  // calls to library functions present in the process.
+  // Setup the object layer to use our custom memory manager in order to
+  // resolve calls to library functions present in the process.
   OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder,
          llvm::DataLayout layout, IRTransformer transform)
       : irTransformer(transform),
@@ -97,8 +143,7 @@ public:
         dataLayout(layout), mangler(session, this->dataLayout),
         threadSafeCtx(llvm::make_unique<llvm::LLVMContext>()) {
     session.getMainJITDylib().setGenerator(
-        cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
-            layout.getGlobalPrefix())));
+        SearchGenerator(layout.getGlobalPrefix()));
   }
 
   // Create a JIT engine for the current host.
@@ -130,8 +175,8 @@ public:
 
 private:
   // Wrap the `irTransformer` into a function that can be called by the
-  // IRTranformLayer.  If `irTransformer` is not set up, return the module as is
-  // without errors.
+  // IRTranformLayer.  If `irTransformer` is not set up, return the module as
+  // is without errors.
   llvm::orc::IRTransformLayer::TransformFunction makeIRTransformFunction() {
     return [this](llvm::orc::ThreadSafeModule module,
                   const llvm::orc::MaterializationResponsibility &resp)
index 4a775a6..a17c309 100644 (file)
@@ -1,8 +1,9 @@
+add_subdirectory(mlir-cpu-runner)
+
 llvm_canonicalize_cmake_booleans(
   LLVM_BUILD_EXAMPLES
   )
 
-
 configure_lit_site_cfg(
   ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
   ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
@@ -19,6 +20,7 @@ configure_lit_site_cfg(
 set(MLIR_TEST_DEPENDS
   FileCheck count not
   MLIRUnitTests
+  mlir-blas-cpu-runner
   mlir-cpu-runner
   mlir-opt
   mlir-tblgen
diff --git a/mlir/test/mlir-cpu-runner/CMakeLists.txt b/mlir/test/mlir-cpu-runner/CMakeLists.txt
new file mode 100644 (file)
index 0000000..dac832e
--- /dev/null
@@ -0,0 +1,27 @@
+set(LIBS
+  MLIRAffineOps
+  MLIRAnalysis
+  MLIREDSC
+  MLIRExecutionEngine
+  MLIRIR
+  MLIRLLVMIR
+  MLIRParser
+  MLIRTargetLLVMIR
+  MLIRTransforms
+  MLIRSupport
+  MLIRCPURunnerLib
+  LLVMCore
+  LLVMSupport
+)
+add_executable(mlir-blas-cpu-runner
+  mlir-blas-cpu-runner.cpp
+)
+llvm_update_compile_flags(mlir-blas-cpu-runner)
+whole_archive_link(mlir-blas-cpu-runner
+  MLIRLLVMIR
+  MLIRStandardOps
+  MLIRTargetLLVMIR
+  MLIRTransforms
+  MLIRTranslation
+)
+target_link_libraries(mlir-blas-cpu-runner MLIRIR ${LIBS})
diff --git a/mlir/test/mlir-cpu-runner/mlir-blas-cpu-runner.cpp b/mlir/test/mlir-cpu-runner/mlir-blas-cpu-runner.cpp
new file mode 100644 (file)
index 0000000..52969bd
--- /dev/null
@@ -0,0 +1,47 @@
+//===- mlir-blas-cpu-runner.cpp - MLIR CPU Execution Driver + Blas Support ===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Main entry point.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/DynamicLibrary.h"
+
+#ifdef WITH_LAPACK
+#include "lapack/cblas.h"
+#else
+extern "C" float cblas_sdot(const int N, const float *X, const int incX,
+                            const float *Y, const int incY) {
+  float res = 0.0f;
+  for (int i = 0; i < N; ++i)
+    res += X[i * incX] * Y[i * incY];
+  return res;
+}
+#endif
+
+extern int run(int argc, char **argv);
+
+void addSymbols() {
+  using llvm::sys::DynamicLibrary;
+  DynamicLibrary::AddSymbol("cblas_sdot", (void *)(&cblas_sdot));
+}
+
+int main(int argc, char **argv) {
+  addSymbols();
+  return run(argc, argv);
+}
index 119cea6..5731c4d 100644 (file)
@@ -1,8 +1,29 @@
-// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s
+// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-blas-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s
 
-func @linalg_dot(!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
-                 !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
-                 !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
+func @cblas_sdot(!llvm.i64, !llvm<"float*">, !llvm.i64, !llvm<"float*">, !llvm.i64) -> !llvm.float
+
+func @linalg_dot(%arg0 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
+                 %arg1 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
+                 %arg2 : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
+  %n = llvm.extractvalue %arg0[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+
+  %x0 = llvm.extractvalue %arg0[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+  %x1 = llvm.extractvalue %arg0[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+  %x = llvm.getelementptr %x0[%x1] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+
+  %inc_x = llvm.extractvalue %arg0[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+
+  %y0 = llvm.extractvalue %arg1[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+  %y1 = llvm.extractvalue %arg1[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+  %y = llvm.getelementptr %y0[%y1] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+
+  %inc_y = llvm.extractvalue %arg1[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+
+  %res = llvm.call @cblas_sdot(%n, %x, %inc_x, %y, %inc_y) : (!llvm.i64, !llvm<"float*">, !llvm.i64, !llvm<"float*">, !llvm.i64) -> (!llvm.float)
+  %0 = llvm.extractvalue %arg2[0] : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">
+  %old = llvm.load %0 : !llvm<"float*">
+  %new = llvm.fadd %res, %old : !llvm.float
+  llvm.store %new, %0 : !llvm<"float*">
   return
 }
 
@@ -41,18 +62,21 @@ func @entry1() -> f32 {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %c16 = constant 16 : index
-  %f0 = constant 0.00000e+00 : f32
-  %f1 = constant 0.00000e+00 : f32
+  %f10 = constant 10.00000e+00 : f32
+  %f1 = constant 1.00000e+00 : f32
   %f2 = constant 2.00000e+00 : f32
 
   %A = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer<f32>)
   %B = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer<f32>)
-  %C = call @alloc_filled_f32(%c1, %f0) : (index, f32) -> (!linalg.buffer<f32>)
+  %C = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer<f32>)
+
   %res = call @dot(%A, %B, %C) : (!linalg.buffer<f32>, !linalg.buffer<f32>, !linalg.buffer<f32>) -> (f32)
+
   linalg.buffer_dealloc %C : !linalg.buffer<f32>
   linalg.buffer_dealloc %B : !linalg.buffer<f32>
   linalg.buffer_dealloc %A : !linalg.buffer<f32>
+
   return %res : f32
 }
 
-// CHECK: 0.{{0+}}e+00
\ No newline at end of file
+// CHECK: 4.2{{0+}}e+01
\ No newline at end of file
index 844e8db..9460264 100644 (file)
@@ -1,3 +1,8 @@
+set(LLVM_OPTIONAL_SOURCES
+  mlir-cpu-runner-lib.cpp
+  mlir-cpu-runner.cpp
+  )
+
 set(LIBS
   MLIRAffineOps
   MLIRAnalysis
@@ -12,9 +17,15 @@ set(LIBS
   LLVMCore
   LLVMSupport
 )
+
+add_llvm_library(MLIRCPURunnerLib
+  mlir-cpu-runner-lib.cpp
+)
+target_link_libraries(MLIRCPURunnerLib ${LIBS})
+
 add_executable(mlir-cpu-runner
   mlir-cpu-runner.cpp
 )
 llvm_update_compile_flags(mlir-cpu-runner)
 whole_archive_link(mlir-cpu-runner MLIRLLVMIR MLIRStandardOps MLIRTargetLLVMIR MLIRTransforms MLIRTranslation)
-target_link_libraries(mlir-cpu-runner MLIRIR ${LIBS})
+target_link_libraries(mlir-cpu-runner MLIRIR ${LIBS} MLIRCPURunnerLib)
diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp
new file mode 100644 (file)
index 0000000..d3e65dc
--- /dev/null
@@ -0,0 +1,279 @@
+//===- mlir-cpu-runner-lib.cpp - MLIR CPU Execution Driver Library --------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is a command line utility that executes an MLIR file on the CPU by
+// translating MLIR to LLVM IR before JIT-compiling and executing the latter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/MemRefUtils.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassNameParser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include <numeric>
+
+using namespace mlir;
+using llvm::Error;
+
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                                llvm::cl::desc("<input file>"),
+                                                llvm::cl::init("-"));
+static llvm::cl::opt<std::string>
+    initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
+              llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
+static llvm::cl::opt<std::string>
+    mainFuncName("e", llvm::cl::desc("The function to be called"),
+                 llvm::cl::value_desc("<function name>"),
+                 llvm::cl::init("main"));
+static llvm::cl::opt<std::string> mainFuncType(
+    "entry-point-result",
+    llvm::cl::desc("Textual description of the function type to be called"),
+    llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
+
+static llvm::cl::OptionCategory optFlags("opt-like flags");
+
+// CLI list of pass information
+static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
+    llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"),
+               llvm::cl::cat(optFlags));
+
+// CLI variables for -On options.
+static llvm::cl::opt<bool> optO0("O0", llvm::cl::desc("Run opt O0 passes"),
+                                 llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO1("O1", llvm::cl::desc("Run opt O1 passes"),
+                                 llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
+                                 llvm::cl::cat(optFlags));
+static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
+                                 llvm::cl::cat(optFlags));
+
+static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
+                                              MLIRContext *context) {
+  // Set up the input file.
+  std::string errorMessage;
+  auto file = openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
+
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
+  return std::unique_ptr<Module>(parseSourceFile(sourceMgr, context));
+}
+
+// Initialize the relevant subsystems of LLVM.
+static void initializeLLVM() {
+  llvm::InitializeNativeTarget();
+  llvm::InitializeNativeTargetAsmPrinter();
+}
+
+static inline Error make_string_error(const llvm::Twine &message) {
+  return llvm::make_error<llvm::StringError>(message.str(),
+                                             llvm::inconvertibleErrorCode());
+}
+
+static void printOneMemRef(Type t, void *val) {
+  auto memRefType = t.cast<MemRefType>();
+  auto shape = memRefType.getShape();
+  int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<int64_t>());
+  for (int64_t i = 0; i < size; ++i) {
+    llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
+  }
+  llvm::outs() << '\n';
+}
+
+static void printMemRefArguments(ArrayRef<Type> argTypes,
+                                 ArrayRef<Type> resTypes,
+                                 ArrayRef<void *> args) {
+  auto properArgs = args.take_front(argTypes.size());
+  for (const auto &kvp : llvm::zip(argTypes, properArgs)) {
+    auto type = std::get<0>(kvp);
+    auto val = std::get<1>(kvp);
+    printOneMemRef(type, val);
+  }
+
+  auto results = args.drop_front(argTypes.size());
+  for (const auto &kvp : llvm::zip(resTypes, results)) {
+    auto type = std::get<0>(kvp);
+    auto val = std::get<1>(kvp);
+    printOneMemRef(type, val);
+  }
+}
+
+static Error compileAndExecuteFunctionWithMemRefs(
+    Module *module, StringRef entryPoint,
+    std::function<llvm::Error(llvm::Module *)> transformer) {
+  Function *mainFunction = module->getNamedFunction(entryPoint);
+  if (!mainFunction || mainFunction->getBlocks().empty()) {
+    return make_string_error("entry point not found");
+  }
+
+  // Store argument and result types of the original function necessary to
+  // pretty print the results, because the function itself will be rewritten
+  // to use the LLVM dialect.
+  SmallVector<Type, 8> argTypes =
+      llvm::to_vector<8>(mainFunction->getType().getInputs());
+  SmallVector<Type, 8> resTypes =
+      llvm::to_vector<8>(mainFunction->getType().getResults());
+
+  float init = std::stof(initValue.getValue());
+
+  auto expectedArguments = allocateMemRefArguments(mainFunction, init);
+  if (!expectedArguments)
+    return expectedArguments.takeError();
+
+  auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
+  if (!expectedEngine)
+    return expectedEngine.takeError();
+
+  auto engine = std::move(*expectedEngine);
+  auto expectedFPtr = engine->lookup(entryPoint);
+  if (!expectedFPtr)
+    return expectedFPtr.takeError();
+  void (*fptr)(void **) = *expectedFPtr;
+  (*fptr)(expectedArguments->data());
+  printMemRefArguments(argTypes, resTypes, *expectedArguments);
+  freeMemRefArguments(*expectedArguments);
+
+  return Error::success();
+}
+
+static Error compileAndExecuteSingleFloatReturnFunction(
+    Module *module, StringRef entryPoint,
+    std::function<llvm::Error(llvm::Module *)> transformer) {
+  Function *mainFunction = module->getNamedFunction(entryPoint);
+  if (!mainFunction || mainFunction->isExternal()) {
+    return make_string_error("entry point not found");
+  }
+
+  if (!mainFunction->getType().getInputs().empty())
+    return make_string_error("function inputs not supported");
+
+  if (mainFunction->getType().getResults().size() != 1)
+    return make_string_error("only single f32 function result supported");
+
+  auto t = mainFunction->getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
+  if (!t)
+    return make_string_error("only single llvm.f32 function result supported");
+  auto *llvmTy = t.getUnderlyingType();
+  if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
+    return make_string_error("only single llvm.f32 function result supported");
+
+  auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
+  if (!expectedEngine)
+    return expectedEngine.takeError();
+
+  auto engine = std::move(*expectedEngine);
+  auto expectedFPtr = engine->lookup(entryPoint);
+  if (!expectedFPtr)
+    return expectedFPtr.takeError();
+  void (*fptr)(void **) = *expectedFPtr;
+
+  float res;
+  struct {
+    void *data;
+  } data;
+  data.data = &res;
+  (*fptr)((void **)&data);
+
+  // Intentional printing of the output so we can test.
+  llvm::outs() << res;
+
+  return Error::success();
+}
+
+int run(int argc, char **argv) {
+  llvm::PrettyStackTraceProgram x(argc, argv);
+  llvm::InitLLVM y(argc, argv);
+
+  initializeLLVM();
+  mlir::initializeLLVMPasses();
+
+  llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
+      optO0, optO1, optO2, optO3};
+
+  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
+
+  llvm::SmallVector<const llvm::PassInfo *, 4> passes;
+  llvm::Optional<unsigned> optLevel;
+  unsigned optCLIPosition = 0;
+  // Determine if there is an optimization flag present, and its CLI position
+  // (optCLIPosition).
+  for (unsigned j = 0; j < 4; ++j) {
+    auto &flag = optFlags[j].get();
+    if (flag) {
+      optLevel = j;
+      optCLIPosition = flag.getPosition();
+      break;
+    }
+  }
+  // Generate vector of pass information, plus the index at which we should
+  // insert any optimization passes in that vector (optPosition).
+  unsigned optPosition = 0;
+  for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
+    passes.push_back(llvmPasses[i]);
+    if (optCLIPosition < llvmPasses.getPosition(i)) {
+      optPosition = i;
+      optCLIPosition = UINT_MAX; // To ensure we never insert again
+    }
+  }
+
+  MLIRContext context;
+  auto m = parseMLIRInput(inputFilename, &context);
+  if (!m) {
+    llvm::errs() << "could not parse the input IR\n";
+    return 1;
+  }
+
+  auto transformer =
+      mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
+  auto error = mainFuncType.getValue() == "f32"
+                   ? compileAndExecuteSingleFloatReturnFunction(
+                         m.get(), mainFuncName.getValue(), transformer)
+                   : compileAndExecuteFunctionWithMemRefs(
+                         m.get(), mainFuncName.getValue(), transformer);
+  int exitCode = EXIT_SUCCESS;
+  llvm::handleAllErrors(std::move(error),
+                        [&exitCode](const llvm::ErrorInfoBase &info) {
+                          llvm::errs() << "Error: ";
+                          info.log(llvm::errs());
+                          llvm::errs() << '\n';
+                          exitCode = EXIT_FAILURE;
+                        });
+
+  return exitCode;
+}
index 5deadb0..43f4eab 100644 (file)
 // limitations under the License.
 // =============================================================================
 //
-// This is a command line utility that executes an MLIR file on the CPU by
-// translating MLIR to LLVM IR before JIT-compiling and executing the latter.
+// Main entry point to a command line utility that executes an MLIR file on the
+// CPU by  translating MLIR to LLVM IR before JIT-compiling and executing the
+// latter.
 //
 //===----------------------------------------------------------------------===//
+extern int run(int argc, char **argv);
 
-#include "mlir/ExecutionEngine/ExecutionEngine.h"
-#include "mlir/ExecutionEngine/MemRefUtils.h"
-#include "mlir/ExecutionEngine/OptUtils.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/LLVMIR/LLVMDialect.h"
-#include "mlir/Parser.h"
-#include "mlir/Support/FileUtilities.h"
-
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/LLVMContext.h"
-#include "llvm/IR/LegacyPassNameParser.h"
-#include "llvm/IR/Module.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/FileUtilities.h"
-#include "llvm/Support/InitLLVM.h"
-#include "llvm/Support/PrettyStackTrace.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/StringSaver.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Support/ToolOutputFile.h"
-#include <numeric>
-
-using namespace mlir;
-using llvm::Error;
-
-static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
-                                                llvm::cl::desc("<input file>"),
-                                                llvm::cl::init("-"));
-static llvm::cl::opt<std::string>
-    initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
-              llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
-static llvm::cl::opt<std::string>
-    mainFuncName("e", llvm::cl::desc("The function to be called"),
-                 llvm::cl::value_desc("<function name>"),
-                 llvm::cl::init("main"));
-static llvm::cl::opt<std::string> mainFuncType(
-    "entry-point-result",
-    llvm::cl::desc("Textual description of the function type to be called"),
-    llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
-
-static llvm::cl::OptionCategory optFlags("opt-like flags");
-
-// CLI list of pass information
-static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
-    llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"),
-               llvm::cl::cat(optFlags));
-
-// CLI variables for -On options.
-static llvm::cl::opt<bool> optO0("O0", llvm::cl::desc("Run opt O0 passes"),
-                                 llvm::cl::cat(optFlags));
-static llvm::cl::opt<bool> optO1("O1", llvm::cl::desc("Run opt O1 passes"),
-                                 llvm::cl::cat(optFlags));
-static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
-                                 llvm::cl::cat(optFlags));
-static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
-                                 llvm::cl::cat(optFlags));
-
-static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
-                                              MLIRContext *context) {
-  // Set up the input file.
-  std::string errorMessage;
-  auto file = openInputFile(inputFilename, &errorMessage);
-  if (!file) {
-    llvm::errs() << errorMessage << "\n";
-    return nullptr;
-  }
-
-  llvm::SourceMgr sourceMgr;
-  sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
-  return std::unique_ptr<Module>(parseSourceFile(sourceMgr, context));
-}
-
-// Initialize the relevant subsystems of LLVM.
-static void initializeLLVM() {
-  llvm::InitializeNativeTarget();
-  llvm::InitializeNativeTargetAsmPrinter();
-}
-
-static inline Error make_string_error(const llvm::Twine &message) {
-  return llvm::make_error<llvm::StringError>(message.str(),
-                                             llvm::inconvertibleErrorCode());
-}
-
-static void printOneMemRef(Type t, void *val) {
-  auto memRefType = t.cast<MemRefType>();
-  auto shape = memRefType.getShape();
-  int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
-                                 std::multiplies<int64_t>());
-  for (int64_t i = 0; i < size; ++i) {
-    llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
-  }
-  llvm::outs() << '\n';
-}
-
-static void printMemRefArguments(ArrayRef<Type> argTypes,
-                                 ArrayRef<Type> resTypes,
-                                 ArrayRef<void *> args) {
-  auto properArgs = args.take_front(argTypes.size());
-  for (const auto &kvp : llvm::zip(argTypes, properArgs)) {
-    auto type = std::get<0>(kvp);
-    auto val = std::get<1>(kvp);
-    printOneMemRef(type, val);
-  }
-
-  auto results = args.drop_front(argTypes.size());
-  for (const auto &kvp : llvm::zip(resTypes, results)) {
-    auto type = std::get<0>(kvp);
-    auto val = std::get<1>(kvp);
-    printOneMemRef(type, val);
-  }
-}
-
-static Error compileAndExecuteFunctionWithMemRefs(
-    Module *module, StringRef entryPoint,
-    std::function<llvm::Error(llvm::Module *)> transformer) {
-  Function *mainFunction = module->getNamedFunction(entryPoint);
-  if (!mainFunction || mainFunction->getBlocks().empty()) {
-    return make_string_error("entry point not found");
-  }
-
-  // Store argument and result types of the original function necessary to
-  // pretty print the results, because the function itself will be rewritten
-  // to use the LLVM dialect.
-  SmallVector<Type, 8> argTypes =
-      llvm::to_vector<8>(mainFunction->getType().getInputs());
-  SmallVector<Type, 8> resTypes =
-      llvm::to_vector<8>(mainFunction->getType().getResults());
-
-  float init = std::stof(initValue.getValue());
-
-  auto expectedArguments = allocateMemRefArguments(mainFunction, init);
-  if (!expectedArguments)
-    return expectedArguments.takeError();
-
-  auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
-  if (!expectedEngine)
-    return expectedEngine.takeError();
-
-  auto engine = std::move(*expectedEngine);
-  auto expectedFPtr = engine->lookup(entryPoint);
-  if (!expectedFPtr)
-    return expectedFPtr.takeError();
-  void (*fptr)(void **) = *expectedFPtr;
-  (*fptr)(expectedArguments->data());
-  printMemRefArguments(argTypes, resTypes, *expectedArguments);
-  freeMemRefArguments(*expectedArguments);
-
-  return Error::success();
-}
-
-static Error compileAndExecuteSingleFloatReturnFunction(
-    Module *module, StringRef entryPoint,
-    std::function<llvm::Error(llvm::Module *)> transformer) {
-  Function *mainFunction = module->getNamedFunction(entryPoint);
-  if (!mainFunction || mainFunction->isExternal()) {
-    return make_string_error("entry point not found");
-  }
-
-  if (!mainFunction->getType().getInputs().empty())
-    return make_string_error("function inputs not supported");
-
-  if (mainFunction->getType().getResults().size() != 1)
-    return make_string_error("only single f32 function result supported");
-
-  auto t = mainFunction->getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
-  if (!t)
-    return make_string_error("only single llvm.f32 function result supported");
-  auto *llvmTy = t.getUnderlyingType();
-  if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
-    return make_string_error("only single llvm.f32 function result supported");
-
-  auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
-  if (!expectedEngine)
-    return expectedEngine.takeError();
-
-  auto engine = std::move(*expectedEngine);
-  auto expectedFPtr = engine->lookup(entryPoint);
-  if (!expectedFPtr)
-    return expectedFPtr.takeError();
-  void (*fptr)(void **) = *expectedFPtr;
-
-  float res;
-  struct {
-    void *data;
-  } data;
-  data.data = &res;
-  (*fptr)((void **)&data);
-
-  // Intentional printing of the output so we can test.
-  llvm::outs() << res;
-
-  return Error::success();
-}
-
-int main(int argc, char **argv) {
-  llvm::PrettyStackTraceProgram x(argc, argv);
-  llvm::InitLLVM y(argc, argv);
-
-  initializeLLVM();
-  mlir::initializeLLVMPasses();
-
-  llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
-      optO0, optO1, optO2, optO3};
-
-  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
-
-  llvm::SmallVector<const llvm::PassInfo *, 4> passes;
-  llvm::Optional<unsigned> optLevel;
-  unsigned optCLIPosition = 0;
-  // Determine if there is an optimization flag present, and its CLI position
-  // (optCLIPosition).
-  for (unsigned j = 0; j < 4; ++j) {
-    auto &flag = optFlags[j].get();
-    if (flag) {
-      optLevel = j;
-      optCLIPosition = flag.getPosition();
-      break;
-    }
-  }
-  // Generate vector of pass information, plus the index at which we should
-  // insert any optimization passes in that vector (optPosition).
-  unsigned optPosition = 0;
-  for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
-    passes.push_back(llvmPasses[i]);
-    if (optCLIPosition < llvmPasses.getPosition(i)) {
-      optPosition = i;
-      optCLIPosition = UINT_MAX; // To ensure we never insert again
-    }
-  }
-
-  MLIRContext context;
-  auto m = parseMLIRInput(inputFilename, &context);
-  if (!m) {
-    llvm::errs() << "could not parse the input IR\n";
-    return 1;
-  }
-
-  auto transformer =
-      mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
-  auto error = mainFuncType.getValue() == "f32"
-                   ? compileAndExecuteSingleFloatReturnFunction(
-                         m.get(), mainFuncName.getValue(), transformer)
-                   : compileAndExecuteFunctionWithMemRefs(
-                         m.get(), mainFuncName.getValue(), transformer);
-  int exitCode = EXIT_SUCCESS;
-  llvm::handleAllErrors(std::move(error),
-                        [&exitCode](const llvm::ErrorInfoBase &info) {
-                          llvm::errs() << "Error: ";
-                          info.log(llvm::errs());
-                          llvm::errs() << '\n';
-                          exitCode = EXIT_FAILURE;
-                        });
-
-  return exitCode;
-}
+int main(int argc, char **argv) { return run(argc, argv); }