Cleanup linalg integration test
authorNicolas Vasilache <ntv@google.com>
Wed, 15 May 2019 16:26:27 +0000 (09:26 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:43:13 +0000 (13:43 -0700)
    This CL performs post-commit cleanups.
    It adds the ability to specify which shared libraries to load dynamically in ExecutionEngine. The linalg integration test is updated to use a shared library.
    Additional minor cleanups related to LLVM lowering of Linalg are also included.

--

PiperOrigin-RevId: 248346589

15 files changed:
mlir/bindings/python/test/test_py2and3.py
mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
mlir/include/mlir/LLVMIR/LLVMLowering.h
mlir/include/mlir/Linalg/IR/CMakeLists.txt
mlir/lib/ExecutionEngine/ExecutionEngine.cpp
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/CMakeLists.txt
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/test/CMakeLists.txt
mlir/test/lit.cfg.py
mlir/test/lit.site.cfg.py.in
mlir/test/mlir-cpu-runner/CMakeLists.txt
mlir/test/mlir-cpu-runner/linalg_integration_test.mlir [moved from mlir/test/mlir-cpu-runner/simple_linalg.mlir with 94% similarity]
mlir/test/mlir-cpu-runner/sdot.cpp [moved from mlir/test/mlir-cpu-runner/mlir-blas-cpu-runner.cpp with 67% similarity]
mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp

index 0832c02..7b424fd 100644 (file)
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-# RUN: $(dirname %s)/test_edsc %s | FileCheck %s
+# RUN: %p/test_edsc %s | FileCheck %s
 """Python2 and 3 test for the MLIR EDSC Python bindings"""
 
 import google_mlir.bindings.python.pybind as E
index 102f22e..3a1ef8e 100644 (file)
@@ -32,7 +32,7 @@
 namespace llvm {
 template <typename T> class Expected;
 class Module;
-}
+} // namespace llvm
 
 namespace mlir {
 
@@ -61,16 +61,21 @@ public:
   /// runs it on the MLIR module.  If `transformer` is
   /// provided, it will be called on the LLVM module during JIT-compilation and
   /// can be used, e.g., for reporting or optimization.
+  /// If `sharedLibPaths` are provided, the underlying JIT-compilation will open
+  /// and link the shared libraries for symbol resolution.
   static llvm::Expected<std::unique_ptr<ExecutionEngine>>
   create(Module *m, PassManager *pm,
-         std::function<llvm::Error(llvm::Module *)> transformer = {});
+         std::function<llvm::Error(llvm::Module *)> transformer = {},
+         ArrayRef<StringRef> sharedLibPaths = {});
 
   /// Creates an execution engine for the given module.  If `transformer` is
   /// provided, it will be called on the LLVM module during JIT-compilation and
   /// can be used, e.g., for reporting or optimization.
+  /// If `sharedLibPaths` are provided, the underlying JIT-compilation will open
+  /// and link the shared libraries for symbol resolution.
   static llvm::Expected<std::unique_ptr<ExecutionEngine>>
-  create(Module *m,
-         std::function<llvm::Error(llvm::Module *)> transformer = {});
+  create(Module *m, std::function<llvm::Error(llvm::Module *)> transformer = {},
+         ArrayRef<StringRef> sharedLibPaths = {});
 
   /// Looks up a packed-argument function with the given name and returns a
   /// pointer to it.  Propagates errors in case of failure.
index 02bc816..4b35bbd 100644 (file)
@@ -28,6 +28,7 @@
 
 namespace llvm {
 class IntegerType;
+class LLVMContext;
 class Module;
 class Type;
 }
@@ -51,6 +52,9 @@ public:
   /// to each of the MLIR types converted with `convertType`.
   Type packFunctionResults(ArrayRef<Type> types);
 
+  /// Returns the LLVM context.
+  llvm::LLVMContext &getLLVMContext();
+
 protected:
   /// Create a set of converters that live in the pass object by passing them a
   /// reference to the LLVM IR dialect.  Store the module associated with the
index d3ed75c..b0c7266 100644 (file)
@@ -2,3 +2,7 @@ set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
 mlir_tablegen(LinalgOps.h.inc -gen-op-decls)
 mlir_tablegen(LinalgOps.cpp.inc -gen-op-defs)
 add_public_tablegen_target(MLIRLinalgOpsIncGen)
+set(LLVM_TARGET_DEFINITIONS LinalgLibraryOps.td)
+mlir_tablegen(LinalgLibraryOps.h.inc -gen-op-decls)
+mlir_tablegen(LinalgLibraryOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRLinalgLibraryOpsIncGen)
index b10fcb8..1e761ad 100644 (file)
@@ -131,7 +131,8 @@ public:
   // Setup the object layer to use our custom memory manager in order to
   // resolve calls to library functions present in the process.
   OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder,
-         llvm::DataLayout layout, IRTransformer transform)
+         llvm::DataLayout layout, IRTransformer transform,
+         ArrayRef<StringRef> sharedLibPaths)
       : irTransformer(transform),
         objectLayer(
             session,
@@ -144,11 +145,12 @@ public:
         threadSafeCtx(llvm::make_unique<llvm::LLVMContext>()) {
     session.getMainJITDylib().setGenerator(
         SearchGenerator(layout.getGlobalPrefix()));
+    loadLibraries(sharedLibPaths);
   }
 
   // Create a JIT engine for the current host.
   static Expected<std::unique_ptr<OrcJIT>>
-  createDefault(IRTransformer transformer) {
+  createDefault(IRTransformer transformer, ArrayRef<StringRef> sharedLibPaths) {
     auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost();
     if (!machineBuilder)
       return machineBuilder.takeError();
@@ -158,7 +160,8 @@ public:
       return dataLayout.takeError();
 
     return llvm::make_unique<OrcJIT>(std::move(*machineBuilder),
-                                     std::move(*dataLayout), transformer);
+                                     std::move(*dataLayout), transformer,
+                                     sharedLibPaths);
   }
 
   // Add an LLVM module to the main library managed by the JIT engine.
@@ -190,6 +193,10 @@ private:
     };
   }
 
+  // Iterate over shareLibPaths and load the corresponding libraries for symbol
+  // resolution.
+  void loadLibraries(ArrayRef<StringRef> sharedLibPaths);
+
   IRTransformer irTransformer;
   llvm::orc::ExecutionSession session;
   llvm::orc::RTDyldObjectLinkingLayer objectLayer;
@@ -202,6 +209,29 @@ private:
 } // end namespace impl
 } // namespace mlir
 
+void mlir::impl::OrcJIT::loadLibraries(ArrayRef<StringRef> sharedLibPaths) {
+  for (auto libPath : sharedLibPaths) {
+    auto mb = llvm::MemoryBuffer::getFile(libPath);
+    if (!mb) {
+      llvm::errs() << "Could not create MemoryBuffer for: " << libPath << " "
+                   << mb.getError().message() << "\n";
+      continue;
+    }
+    auto &JD = session.createJITDylib(libPath);
+    auto loaded = llvm::orc::DynamicLibrarySearchGenerator::Load(
+        libPath.data(), dataLayout.getGlobalPrefix());
+    if (!loaded) {
+      llvm::errs() << "Could not load: " << libPath << " " << loaded.takeError()
+                   << "\n";
+      continue;
+    }
+    JD.setGenerator(loaded.get());
+    auto res = objectLayer.add(JD, std::move(mb.get()));
+    if (res)
+      llvm::errs() << "Could not add: " << libPath << " " << res << "\n";
+  }
+}
+
 // Wrap a string into an llvm::StringError.
 static inline Error make_string_error(const llvm::Twine &message) {
   return llvm::make_error<llvm::StringError>(message.str(),
@@ -318,11 +348,12 @@ void packFunctionArguments(llvm::Module *module) {
 // Out of line for PIMPL unique_ptr.
 ExecutionEngine::~ExecutionEngine() = default;
 
-Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
-    Module *m, PassManager *pm,
-    std::function<llvm::Error(llvm::Module *)> transformer) {
+Expected<std::unique_ptr<ExecutionEngine>>
+ExecutionEngine::create(Module *m, PassManager *pm,
+                        std::function<llvm::Error(llvm::Module *)> transformer,
+                        ArrayRef<StringRef> sharedLibPaths) {
   auto engine = llvm::make_unique<ExecutionEngine>();
-  auto expectedJIT = impl::OrcJIT::createDefault(transformer);
+  auto expectedJIT = impl::OrcJIT::createDefault(transformer, sharedLibPaths);
   if (!expectedJIT)
     return expectedJIT.takeError();
 
@@ -345,12 +376,14 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
   return std::move(engine);
 }
 
-Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
-    Module *m, std::function<llvm::Error(llvm::Module *)> transformer) {
+Expected<std::unique_ptr<ExecutionEngine>>
+ExecutionEngine::create(Module *m,
+                        std::function<llvm::Error(llvm::Module *)> transformer,
+                        ArrayRef<StringRef> sharedLibPaths) {
   // Construct and run the default MLIR pipeline.
   PassManager manager;
   getDefaultPasses(manager, {});
-  return create(m, &manager, transformer);
+  return create(m, &manager, transformer, sharedLibPaths);
 }
 
 Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
index f3fb30a..e371bf0 100644 (file)
 
 using namespace mlir;
 
+// Get the LLVM context.
+llvm::LLVMContext &LLVMLowering::getLLVMContext() {
+  return module->getContext();
+}
+
 // Wrap the given LLVM IR type into an LLVM IR dialect type.
 Type LLVMLowering::wrap(llvm::Type *llvmType) {
   return LLVM::LLVMType::get(llvmDialect->getContext(), llvmType);
index 2488c44..8fa7be6 100644 (file)
@@ -9,4 +9,4 @@ add_llvm_library(MLIRLinalg
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Linalg
   )
-add_dependencies(MLIRLinalg MLIRLinalgOpsIncGen)
+add_dependencies(MLIRLinalg MLIRLinalgOpsIncGen MLIRLinalgLibraryOpsIncGen)
index 6c4d5c2..2f5a108 100644 (file)
@@ -166,18 +166,15 @@ public:
     auto indexType = IndexType::get(op->getContext());
     auto voidPtrTy = LLVM::LLVMType::get(
         op->getContext(),
-        lowering.convertType(IntegerType::get(8, op->getContext()))
-            .cast<LLVM::LLVMType>()
-            .getUnderlyingType()
-            ->getPointerTo());
+        llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
     auto int64Ty = lowering.convertType(operands[0]->getType());
     // Insert the `malloc` declaration if it is not already present.
-    Function *mallocFunc =
-        op->getFunction()->getModule()->getNamedFunction("malloc");
+    auto *module = op->getFunction()->getModule();
+    Function *mallocFunc = module->getNamedFunction("malloc");
     if (!mallocFunc) {
       auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
       mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
-      op->getFunction()->getModule()->getFunctions().push_back(mallocFunc);
+      module->getFunctions().push_back(mallocFunc);
     }
 
     // Get MLIR types for injecting element pointer.
@@ -225,17 +222,14 @@ public:
                                   FuncBuilder &rewriter) const override {
     auto voidPtrTy = LLVM::LLVMType::get(
         op->getContext(),
-        lowering.convertType(IntegerType::get(8, op->getContext()))
-            .cast<LLVM::LLVMType>()
-            .getUnderlyingType()
-            ->getPointerTo());
+        llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
     // Insert the `free` declaration if it is not already present.
-    Function *freeFunc =
-        op->getFunction()->getModule()->getNamedFunction("free");
+    auto *module = op->getFunction()->getModule();
+    Function *freeFunc = module->getNamedFunction("free");
     if (!freeFunc) {
       auto freeType = rewriter.getFunctionType(voidPtrTy, {});
       freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
-      op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
+      module->getFunctions().push_back(freeFunc);
     }
 
     // Get MLIR types for extracting element pointer.
index a17c309..c695852 100644 (file)
@@ -4,6 +4,10 @@ llvm_canonicalize_cmake_booleans(
   LLVM_BUILD_EXAMPLES
   )
 
+# Passed to lit.site.cfg.py.in to set up the path where to find the libraries
+# for linalg integration tests.
+set(MLIR_LINALG_INTEGRATION_TEST_LIB_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
+
 configure_lit_site_cfg(
   ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
   ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
@@ -20,14 +24,13 @@ configure_lit_site_cfg(
 set(MLIR_TEST_DEPENDS
   FileCheck count not
   MLIRUnitTests
-  mlir-blas-cpu-runner
   mlir-cpu-runner
   mlir-opt
   mlir-tblgen
   mlir-translate
+  sdot
   )
 
-
 if(LLVM_BUILD_EXAMPLES)
   list(APPEND MLIR_TEST_DEPENDS
     linalg1-opt
index fc97f11..6aa6d81 100644 (file)
@@ -62,6 +62,7 @@ tools.extend([
     ToolSubst('toy-ch3', unresolved='ignore'),
     ToolSubst('toy-ch4', unresolved='ignore'),
     ToolSubst('toy-ch5', unresolved='ignore'),
+    ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'),
 ])
 
 llvm_config.add_tool_substitutions(tools, tool_dirs)
index c701b04..cdfa047 100644 (file)
@@ -30,6 +30,7 @@ config.host_arch = "@HOST_ARCH@"
 config.mlir_src_root = "@MLIR_SOURCE_DIR@"
 config.mlir_obj_root = "@MLIR_BINARY_DIR@"
 config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
+config.linalg_test_lib_dir = "@MLIR_LINALG_INTEGRATION_TEST_LIB_DIR@"
 config.build_examples = @LLVM_BUILD_EXAMPLES@
 
 # Support substitution of the tools_dir with user parameters. This is
index dac832e..452019e 100644 (file)
@@ -1,27 +1,2 @@
-set(LIBS
-  MLIRAffineOps
-  MLIRAnalysis
-  MLIREDSC
-  MLIRExecutionEngine
-  MLIRIR
-  MLIRLLVMIR
-  MLIRParser
-  MLIRTargetLLVMIR
-  MLIRTransforms
-  MLIRSupport
-  MLIRCPURunnerLib
-  LLVMCore
-  LLVMSupport
-)
-add_executable(mlir-blas-cpu-runner
-  mlir-blas-cpu-runner.cpp
-)
-llvm_update_compile_flags(mlir-blas-cpu-runner)
-whole_archive_link(mlir-blas-cpu-runner
-  MLIRLLVMIR
-  MLIRStandardOps
-  MLIRTargetLLVMIR
-  MLIRTransforms
-  MLIRTranslation
-)
-target_link_libraries(mlir-blas-cpu-runner MLIRIR ${LIBS})
+add_llvm_library(sdot SHARED sdot.cpp)
+
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-blas-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s
+// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e entry1 -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libsdot.so | FileCheck %s
 
 func @cblas_sdot(!llvm.i64, !llvm<"float*">, !llvm.i64, !llvm<"float*">, !llvm.i64) -> !llvm.float
 
similarity index 67%
rename from mlir/test/mlir-cpu-runner/mlir-blas-cpu-runner.cpp
rename to mlir/test/mlir-cpu-runner/sdot.cpp
index 52969bd..8336355 100644 (file)
@@ -1,4 +1,4 @@
-//===- mlir-blas-cpu-runner.cpp - MLIR CPU Execution Driver + Blas Support ===//
+//===- sdot.cpp - Simple sdot Blas Function -------------------------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
 // limitations under the License.
 // =============================================================================
 //
-// Main entry point.
+// Sdot implementation.
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/DynamicLibrary.h"
-
-#ifdef WITH_LAPACK
-#include "lapack/cblas.h"
-#else
 extern "C" float cblas_sdot(const int N, const float *X, const int incX,
                             const float *Y, const int incY) {
   float res = 0.0f;
@@ -32,16 +26,3 @@ extern "C" float cblas_sdot(const int N, const float *X, const int incX,
     res += X[i * incX] * Y[i * incY];
   return res;
 }
-#endif
-
-extern int run(int argc, char **argv);
-
-void addSymbols() {
-  using llvm::sys::DynamicLibrary;
-  DynamicLibrary::AddSymbol("cblas_sdot", (void *)(&cblas_sdot));
-}
-
-int main(int argc, char **argv) {
-  addSymbols();
-  return run(argc, argv);
-}
index d3e65dc..ba487ea 100644 (file)
@@ -79,6 +79,12 @@ static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
 static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
                                  llvm::cl::cat(optFlags));
 
+static llvm::cl::OptionCategory clOptionsCategory("linking options");
+static llvm::cl::list<std::string>
+    clSharedLibs("shared-libs", llvm::cl::desc("Libraries to link dynamically"),
+                 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+                 llvm::cl::cat(clOptionsCategory));
+
 static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
                                               MLIRContext *context) {
   // Set up the input file.
@@ -156,7 +162,9 @@ static Error compileAndExecuteFunctionWithMemRefs(
   if (!expectedArguments)
     return expectedArguments.takeError();
 
-  auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
+  SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
+  auto expectedEngine =
+      mlir::ExecutionEngine::create(module, transformer, libs);
   if (!expectedEngine)
     return expectedEngine.takeError();
 
@@ -193,7 +201,9 @@ static Error compileAndExecuteSingleFloatReturnFunction(
   if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
     return make_string_error("only single llvm.f32 function result supported");
 
-  auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
+  SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
+  auto expectedEngine =
+      mlir::ExecutionEngine::create(module, transformer, libs);
   if (!expectedEngine)
     return expectedEngine.takeError();