[OpenMP][JIT] Cleanup JIT interface, caching, and races
authorJohannes Doerfert <johannes@jdoerfert.de>
Wed, 4 Jan 2023 19:33:44 +0000 (11:33 -0800)
committerJohannes Doerfert <johannes@jdoerfert.de>
Sun, 15 Jan 2023 19:43:50 +0000 (11:43 -0800)
The JIT interface was somewhat irregular as it used multiple global
functions. It also did not cache the results of the JIT, hence multiple
GPU systems would perform the work multiple times. Finally, there might
have been races on the state if we have multi-threaded initialization of
different embedded images, or one image initialized on multiple devices.

This patch tries to rectify all of the above. The JITEngine is now a
part of the GenericPluginTy and tied to one target triple. To support
multiple "ComputeUnitKind"s (previously confusingly called Arch or
[M]CPU) and to avoid re-jitting for the same ComputeUnitKind, we keep a
map of JIT results per ComputeUnitKind. All interaction with the JIT
happens through the JITEngine directly, two functions are exposed. Both
use (shared) locks to avoid races and cache the result. All JIT-related
environment variables are now defined together.

Differential Revision: https://reviews.llvm.org/D141081

openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.cpp
openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.h
openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
openmp/libomptarget/plugins-nextgen/generic-elf-64bit/src/rtl.cpp

index 4eefd19..14efe84 100644 (file)
@@ -1530,7 +1530,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     char GPUName[64];
     if (auto Err = getDeviceAttr(HSA_AGENT_INFO_NAME, GPUName))
       return Err;
-    Arch = GPUName;
+    ComputeUnitKind = GPUName;
 
     // Get the wavefront size.
     uint32_t WavefrontSize = 0;
@@ -1669,7 +1669,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     INFO(OMP_INFOTYPE_PLUGIN_KERNEL, getDeviceId(),
          "Using `%s` to link JITed amdgcn ouput.", LLDPath.c_str());
 
-    std::string MCPU = "-plugin-opt=mcpu=" + getArch();
+    std::string MCPU = "-plugin-opt=mcpu=" + getComputeUnitKind();
 
     StringRef Args[] = {LLDPath,
                         "-flavor",
@@ -1692,7 +1692,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
         MemoryBuffer::getFileOrSTDIN(LinkerOutputFilePath.data()).get());
   }
 
-  std::string getArch() const override { return Arch; }
+  /// See GenericDeviceTy::getComputeUnitKind().
+  std::string getComputeUnitKind() const override { return ComputeUnitKind; }
 
   /// Allocate and construct an AMDGPU kernel.
   Expected<GenericKernelTy *>
@@ -2096,7 +2097,7 @@ private:
   hsa_agent_t Agent;
 
   /// The GPU architecture.
-  std::string Arch;
+  std::string ComputeUnitKind;
 
   /// Reference to the host device.
   AMDHostDeviceTy &HostDevice;
@@ -2244,7 +2245,7 @@ private:
 /// Class implementing the AMDGPU-specific functionalities of the plugin.
 struct AMDGPUPluginTy final : public GenericPluginTy {
   /// Create an AMDGPU plugin and initialize the AMDGPU driver.
-  AMDGPUPluginTy() : GenericPluginTy(), HostDevice(nullptr) {}
+  AMDGPUPluginTy() : GenericPluginTy(getTripleArch()), HostDevice(nullptr) {}
 
   /// This class should not be copied.
   AMDGPUPluginTy(const AMDGPUPluginTy &) = delete;
index aa0e599..4382135 100644 (file)
 #include "JIT.h"
 #include "Debug.h"
 
+#include "PluginInterface.h"
 #include "Utilities.h"
 #include "omptarget.h"
 
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
 #include "llvm/CodeGen/CommandFlags.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/IR/LLVMContext.h"
@@ -28,7 +28,6 @@
 #include "llvm/Object/IRObjectFile.h"
 #include "llvm/Passes/OptimizationLevel.h"
 #include "llvm/Passes/PassBuilder.h"
-#include "llvm/Support/Error.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Target/TargetOptions.h"
 
 #include <mutex>
+#include <shared_mutex>
 #include <system_error>
 
 using namespace llvm;
 using namespace llvm::object;
 using namespace omp;
+using namespace omp::target;
 
 static codegen::RegisterCodeGenFlags RCGF;
 
 namespace {
+
+/// A map from a bitcode image start address to its corresponding triple. If the
+/// image is not in the map, it is not a bitcode image.
+DenseMap<void *, Triple::ArchType> BitcodeImageMap;
+std::shared_mutex BitcodeImageMapMutex;
+
 std::once_flag InitFlag;
 
 void init(Triple TT) {
@@ -70,10 +77,8 @@ void init(Triple TT) {
     JITTargetInitialized = true;
   }
 #endif
-  if (!JITTargetInitialized) {
-    FAILURE_MESSAGE("unsupported JIT target: %s\n", TT.str().c_str());
-    abort();
-  }
+  if (!JITTargetInitialized)
+    return;
 
   // Initialize passes
   PassRegistry &Registry = *PassRegistry::getPassRegistry();
@@ -125,9 +130,9 @@ createModuleFromMemoryBuffer(std::unique_ptr<MemoryBuffer> &MB,
   return std::move(Mod);
 }
 Expected<std::unique_ptr<Module>>
-createModuleFromImage(__tgt_device_image *Image, LLVMContext &Context) {
-  StringRef Data((const char *)Image->ImageStart,
-                 target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
+createModuleFromImage(const __tgt_device_image &Image, LLVMContext &Context) {
+  StringRef Data((const char *)Image.ImageStart,
+                 target::getPtrDiff(Image.ImageEnd, Image.ImageStart));
   std::unique_ptr<MemoryBuffer> MB = MemoryBuffer::getMemBuffer(
       Data, /* BufferName */ "", /* RequiresNullTerminator */ false);
   return createModuleFromMemoryBuffer(MB, Context);
@@ -192,44 +197,11 @@ createTargetMachine(Module &M, std::string CPU, unsigned OptLevel) {
   return std::move(TM);
 }
 
-///
-class JITEngine {
-public:
-  JITEngine(Triple::ArchType TA, std::string MCpu)
-      : TT(Triple::getArchTypeName(TA)), CPU(MCpu),
-        ReplacementModuleFileName("LIBOMPTARGET_JIT_REPLACEMENT_MODULE"),
-        PreOptIRModuleFileName("LIBOMPTARGET_JIT_PRE_OPT_IR_MODULE"),
-        PostOptIRModuleFileName("LIBOMPTARGET_JIT_POST_OPT_IR_MODULE") {
-    std::call_once(InitFlag, init, TT);
-  }
-
-  /// Run jit compilation. It is expected to get a memory buffer containing the
-  /// generated device image that could be loaded to the device directly.
-  Expected<std::unique_ptr<MemoryBuffer>>
-  run(__tgt_device_image *Image, unsigned OptLevel,
-      jit::PostProcessingFn PostProcessing);
-
-private:
-  /// Run backend, which contains optimization and code generation.
-  Expected<std::unique_ptr<MemoryBuffer>> backend(Module &M, unsigned OptLevel);
-
-  /// Run optimization pipeline.
-  void opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
-           unsigned OptLevel);
-
-  /// Run code generation.
-  void codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
-               raw_pwrite_stream &OS);
-
-  LLVMContext Context;
-  const Triple TT;
-  const std::string CPU;
+} // namespace
 
-  /// Control environment variables.
-  target::StringEnvar ReplacementModuleFileName;
-  target::StringEnvar PreOptIRModuleFileName;
-  target::StringEnvar PostOptIRModuleFileName;
-};
+JITEngine::JITEngine(Triple::ArchType TA) : TT(Triple::getArchTypeName(TA)) {
+  std::call_once(InitFlag, init, TT);
+}
 
 void JITEngine::opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
                     unsigned OptLevel) {
@@ -274,18 +246,19 @@ void JITEngine::codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII,
   PM.run(M);
 }
 
-Expected<std::unique_ptr<MemoryBuffer>> JITEngine::backend(Module &M,
-                                                           unsigned OptLevel) {
+Expected<std::unique_ptr<MemoryBuffer>>
+JITEngine::backend(Module &M, const std::string &ComputeUnitKind,
+                   unsigned OptLevel) {
 
   auto RemarksFileOrErr = setupLLVMOptimizationRemarks(
-      Context, /* RemarksFilename */ "", /* RemarksPasses */ "",
+      M.getContext(), /* RemarksFilename */ "", /* RemarksPasses */ "",
       /* RemarksFormat */ "", /* RemarksWithHotness */ false);
   if (Error E = RemarksFileOrErr.takeError())
     return std::move(E);
   if (*RemarksFileOrErr)
     (*RemarksFileOrErr)->keep();
 
-  auto TMOrErr = createTargetMachine(M, CPU, OptLevel);
+  auto TMOrErr = createTargetMachine(M, ComputeUnitKind, OptLevel);
   if (!TMOrErr)
     return TMOrErr.takeError();
 
@@ -323,14 +296,23 @@ Expected<std::unique_ptr<MemoryBuffer>> JITEngine::backend(Module &M,
   return MemoryBuffer::getMemBufferCopy(OS.str());
 }
 
-Expected<std::unique_ptr<MemoryBuffer>>
-JITEngine::run(__tgt_device_image *Image, unsigned OptLevel,
-               jit::PostProcessingFn PostProcessing) {
+Expected<const __tgt_device_image *>
+JITEngine::compile(const __tgt_device_image &Image,
+                   const std::string &ComputeUnitKind,
+                   PostProcessingFn PostProcessing) {
+  std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);
+
+  // Check if we JITed this image for the given compute unit kind before.
+  ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
+  if (__tgt_device_image *JITedImage = CUI.TgtImageMap.lookup(&Image))
+    return JITedImage;
+
   Module *Mod = nullptr;
   // Check if the user replaces the module at runtime or we read it from the
   // image.
+  // TODO: Allow the user to specify images per device (Arch + ComputeUnitKind).
   if (!ReplacementModuleFileName.isPresent()) {
-    auto ModOrErr = createModuleFromImage(Image, Context);
+    auto ModOrErr = createModuleFromImage(Image, CUI.Context);
     if (!ModOrErr)
       return ModOrErr.takeError();
     Mod = ModOrErr->release();
@@ -341,44 +323,65 @@ JITEngine::run(__tgt_device_image *Image, unsigned OptLevel,
       return createStringError(MBOrErr.getError(),
                                "Could not read replacement module from %s\n",
                                ReplacementModuleFileName.get().c_str());
-    auto ModOrErr = createModuleFromMemoryBuffer(MBOrErr.get(), Context);
+    auto ModOrErr = createModuleFromMemoryBuffer(MBOrErr.get(), CUI.Context);
     if (!ModOrErr)
       return ModOrErr.takeError();
     Mod = ModOrErr->release();
   }
 
-  auto MBOrError = backend(*Mod, OptLevel);
+  auto MBOrError = backend(*Mod, ComputeUnitKind, JITOptLevel);
   if (!MBOrError)
     return MBOrError.takeError();
 
-  return PostProcessing(std::move(*MBOrError));
+  auto ImageMBOrErr = PostProcessing(std::move(*MBOrError));
+  if (!ImageMBOrErr)
+    return ImageMBOrErr.takeError();
+
+  CUI.JITImages.push_back(std::move(*ImageMBOrErr));
+  __tgt_device_image *&JITedImage = CUI.TgtImageMap[&Image];
+  JITedImage = new __tgt_device_image();
+  *JITedImage = Image;
+
+  auto &ImageMB = CUI.JITImages.back();
+
+  JITedImage->ImageStart = (void *)ImageMB->getBufferStart();
+  JITedImage->ImageEnd = (void *)ImageMB->getBufferEnd();
+
+  return JITedImage;
 }
 
-/// A map from a bitcode image start address to its corresponding triple. If the
-/// image is not in the map, it is not a bitcode image.
-DenseMap<void *, Triple::ArchType> BitcodeImageMap;
+Expected<const __tgt_device_image *>
+JITEngine::process(const __tgt_device_image &Image,
+                   target::plugin::GenericDeviceTy &Device) {
+  const std::string &ComputeUnitKind = Device.getComputeUnitKind();
 
-/// Output images generated from LLVM backend.
-SmallVector<std::unique_ptr<MemoryBuffer>, 4> JITImages;
+  PostProcessingFn PostProcessing = [&Device](std::unique_ptr<MemoryBuffer> MB)
+      -> Expected<std::unique_ptr<MemoryBuffer>> {
+    return Device.doJITPostProcessing(std::move(MB));
+  };
 
-/// A list of __tgt_device_image images.
-std::list<__tgt_device_image> TgtImages;
-} // namespace
+  {
+    std::shared_lock<std::shared_mutex> SharedLock(BitcodeImageMapMutex);
+    auto Itr = BitcodeImageMap.find(Image.ImageStart);
+    if (Itr != BitcodeImageMap.end() && Itr->second == TT.getArch())
+      return compile(Image, ComputeUnitKind, PostProcessing);
+  }
 
-namespace llvm {
-namespace omp {
-namespace jit {
-bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA) {
+  return &Image;
+}
+
+bool JITEngine::checkBitcodeImage(const __tgt_device_image &Image) {
   TimeTraceScope TimeScope("Check bitcode image");
+  std::lock_guard<std::shared_mutex> Lock(BitcodeImageMapMutex);
 
   {
-    auto Itr = BitcodeImageMap.find(Image->ImageStart);
-    if (Itr != BitcodeImageMap.end() && Itr->second == TA)
+    auto Itr = BitcodeImageMap.find(Image.ImageStart);
+    if (Itr != BitcodeImageMap.end() && Itr->second == TT.getArch())
       return true;
   }
 
-  StringRef Data(reinterpret_cast<const char *>(Image->ImageStart),
-                 target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
+  StringRef Data(reinterpret_cast<const char *>(Image.ImageStart),
+                 target::getPtrDiff(Image.ImageEnd, Image.ImageStart));
   std::unique_ptr<MemoryBuffer> MB = MemoryBuffer::getMemBuffer(
       Data, /* BufferName */ "", /* RequiresNullTerminator */ false);
   if (!MB)
@@ -391,37 +394,8 @@ bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA) {
   }
 
   auto ActualTriple = FOrErr->TheReader.getTargetTriple();
+  auto BitcodeTA = Triple(ActualTriple).getArch();
+  BitcodeImageMap[Image.ImageStart] = BitcodeTA;
 
-  if (Triple(ActualTriple).getArch() == TA) {
-    BitcodeImageMap[Image->ImageStart] = TA;
-    return true;
-  }
-
-  return false;
+  return BitcodeTA == TT.getArch();
 }
-
-Expected<__tgt_device_image *> compile(__tgt_device_image *Image,
-                                       Triple::ArchType TA, std::string MCPU,
-                                       unsigned OptLevel,
-                                       PostProcessingFn PostProcessing) {
-  JITEngine J(TA, MCPU);
-
-  auto ImageMBOrErr = J.run(Image, OptLevel, PostProcessing);
-  if (!ImageMBOrErr)
-    return ImageMBOrErr.takeError();
-
-  JITImages.push_back(std::move(*ImageMBOrErr));
-  TgtImages.push_back(*Image);
-
-  auto &ImageMB = JITImages.back();
-  auto *NewImage = &TgtImages.back();
-
-  NewImage->ImageStart = (void *)ImageMB->getBufferStart();
-  NewImage->ImageEnd = (void *)ImageMB->getBufferEnd();
-
-  return NewImage;
-}
-
-} // namespace jit
-} // namespace omp
-} // namespace llvm
index 73483ce..0c51810 100644 (file)
 #ifndef OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_JIT_H
 #define OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_JIT_H
 
+#include "Utilities.h"
+
+#include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Triple.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Module.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Target/TargetMachine.h"
 
 #include <functional>
 #include <memory>
+#include <shared_mutex>
 #include <string>
 
 struct __tgt_device_image;
@@ -25,25 +32,84 @@ namespace llvm {
 class MemoryBuffer;
 
 namespace omp {
-namespace jit {
-
-/// Function type for a callback that will be called after the backend is
-/// called.
-using PostProcessingFn = std::function<Expected<std::unique_ptr<MemoryBuffer>>(
-    std::unique_ptr<MemoryBuffer>)>;
-
-/// Check if \p Image contains bitcode with triple \p Triple.
-bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA);
-
-/// Compile the bitcode image \p Image and generate the binary image that can be
-/// loaded to the target device of the triple \p Triple architecture \p MCpu. \p
-/// PostProcessing will be called after codegen to handle cases such as assember
-/// as an external tool.
-Expected<__tgt_device_image *> compile(__tgt_device_image *Image,
-                                       Triple::ArchType TA, std::string MCpu,
-                                       unsigned OptLevel,
-                                       PostProcessingFn PostProcessing);
-} // namespace jit
+namespace target {
+namespace plugin {
+struct GenericDeviceTy;
+} // namespace plugin
+
+/// The JIT infrastructure and caching mechanism.
+struct JITEngine {
+  /// Function type for a callback that will be called after the backend is
+  /// called.
+  using PostProcessingFn =
+      std::function<Expected<std::unique_ptr<MemoryBuffer>>(
+          std::unique_ptr<MemoryBuffer>)>;
+
+  JITEngine(Triple::ArchType TA);
+
+  /// Run jit compilation if \p Image is a bitcode image, otherwise simply
+  /// return \p Image. It is expected to return a memory buffer containing the
+  /// generated device image that could be loaded to the device directly.
+  Expected<const __tgt_device_image *>
+  process(const __tgt_device_image &Image,
+          target::plugin::GenericDeviceTy &Device);
+
+  /// Return true if \p Image is a bitcode image that can be JITed for the given
+  /// architecture.
+  bool checkBitcodeImage(const __tgt_device_image &Image);
+
+private:
+  /// Compile the bitcode image \p Image and generate the binary image that can
+  /// be loaded to the target device of the triple \p Triple architecture \p
+  /// MCpu. \p PostProcessing will be called after codegen to handle cases such
+  /// as assember as an external tool.
+  Expected<const __tgt_device_image *>
+  compile(const __tgt_device_image &Image, const std::string &ComputeUnitKind,
+          PostProcessingFn PostProcessing);
+
+  /// Run backend, which contains optimization and code generation.
+  Expected<std::unique_ptr<MemoryBuffer>>
+  backend(Module &M, const std::string &ComputeUnitKind, unsigned OptLevel);
+
+  /// Run optimization pipeline.
+  void opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
+           unsigned OptLevel);
+
+  /// Run code generation.
+  void codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
+               raw_pwrite_stream &OS);
+
+  /// The target triple used by the JIT.
+  const Triple TT;
+
+  struct ComputeUnitInfo {
+    /// LLVM Context in which the modules will be constructed.
+    LLVMContext Context;
+
+    /// Output images generated from LLVM backend.
+    SmallVector<std::unique_ptr<MemoryBuffer>, 4> JITImages;
+
+    /// A map of embedded IR images to JITed images.
+    DenseMap<const __tgt_device_image *, __tgt_device_image *> TgtImageMap;
+  };
+
+  /// Map from (march) "CPUs" (e.g., sm_80, or gfx90a), which we call compute
+  /// units as they are not CPUs, to the image information we cached for them.
+  StringMap<ComputeUnitInfo> ComputeUnitMap;
+  std::mutex ComputeUnitMapMutex;
+
+  /// Control environment variables.
+  target::StringEnvar ReplacementModuleFileName =
+      target::StringEnvar("LIBOMPTARGET_JIT_REPLACEMENT_MODULE");
+  target::StringEnvar PreOptIRModuleFileName =
+      target::StringEnvar("LIBOMPTARGET_JIT_PRE_OPT_IR_MODULE");
+  target::StringEnvar PostOptIRModuleFileName =
+      target::StringEnvar("LIBOMPTARGET_JIT_POST_OPT_IR_MODULE");
+  target::UInt32Envar JITOptLevel =
+      target::UInt32Envar("LIBOMPTARGET_JIT_OPT_LEVEL", 3);
+};
+
+} // namespace target
 } // namespace omp
 } // namespace llvm
 
index 3ff7e09..96800fe 100644 (file)
@@ -212,12 +212,22 @@ Error GenericDeviceTy::deinit() {
 
 Expected<__tgt_target_table *>
 GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
-                            const __tgt_device_image *TgtImage) {
-  DP("Load data from image " DPxMOD "\n", DPxPTR(TgtImage->ImageStart));
+                            const __tgt_device_image *InputTgtImage) {
+  assert(InputTgtImage && "Expected non-null target image");
+  DP("Load data from image " DPxMOD "\n", DPxPTR(InputTgtImage->ImageStart));
+
+  auto PostJITImageOrErr = Plugin.getJIT().process(*InputTgtImage, *this);
+  if (!PostJITImageOrErr) {
+    auto Err = PostJITImageOrErr.takeError();
+    REPORT("Failure to jit IR image %p on device %d: %s\n", InputTgtImage,
+           DeviceId, toString(std::move(Err)).data());
+    return nullptr;
+  }
 
   // Load the binary and allocate the image object. Use the next available id
   // for the image id, which is the number of previously loaded images.
-  auto ImageOrErr = loadBinaryImpl(TgtImage, LoadedImages.size());
+  auto ImageOrErr =
+      loadBinaryImpl(PostJITImageOrErr.get(), LoadedImages.size());
   if (!ImageOrErr)
     return ImageOrErr.takeError();
 
@@ -668,7 +678,7 @@ int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *TgtImage) {
   if (elf_check_machine(TgtImage, Plugin::get().getMagicElfBits()))
     return true;
 
-  return jit::checkBitcodeImage(TgtImage, Plugin::get().getTripleArch());
+  return Plugin::get().getJIT().checkBitcodeImage(*TgtImage);
 }
 
 int32_t __tgt_rtl_is_valid_binary_info(__tgt_device_image *TgtImage,
@@ -745,34 +755,6 @@ __tgt_target_table *__tgt_rtl_load_binary(int32_t DeviceId,
   GenericPluginTy &Plugin = Plugin::get();
   GenericDeviceTy &Device = Plugin.getDevice(DeviceId);
 
-  // If it is a bitcode image, we have to jit the binary image before loading to
-  // the device.
-  {
-    // TODO: Move this (at least the environment variable) into the JIT.h.
-    UInt32Envar JITOptLevel("LIBOMPTARGET_JIT_OPT_LEVEL", 3);
-    Triple::ArchType TA = Plugin.getTripleArch();
-    std::string Arch = Device.getArch();
-
-    jit::PostProcessingFn PostProcessing =
-        [&Device](std::unique_ptr<MemoryBuffer> MB)
-        -> Expected<std::unique_ptr<MemoryBuffer>> {
-      return Device.doJITPostProcessing(std::move(MB));
-    };
-
-    if (jit::checkBitcodeImage(TgtImage, TA)) {
-      auto TgtImageOrErr =
-          jit::compile(TgtImage, TA, Arch, JITOptLevel, PostProcessing);
-      if (!TgtImageOrErr) {
-        auto Err = TgtImageOrErr.takeError();
-        REPORT("Failure to jit binary image from bitcode image %p on device "
-               "%d: %s\n",
-               TgtImage, DeviceId, toString(std::move(Err)).data());
-        return nullptr;
-      }
-
-      TgtImage = *TgtImageOrErr;
-    }
-  }
 
   auto TableOrErr = Device.loadBinary(Plugin, TgtImage);
   if (!TableOrErr) {
index 752ee2e..d65209d 100644 (file)
@@ -21,6 +21,7 @@
 #include "Debug.h"
 #include "DeviceEnvironment.h"
 #include "GlobalHandler.h"
+#include "JIT.h"
 #include "MemoryManager.h"
 #include "Utilities.h"
 #include "omptarget.h"
@@ -37,6 +38,7 @@
 namespace llvm {
 namespace omp {
 namespace target {
+
 namespace plugin {
 
 struct GenericPluginTy;
@@ -378,10 +380,8 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   }
   uint32_t getDynamicMemorySize() const { return OMPX_SharedMemorySize; }
 
-  /// Get target architecture.
-  virtual std::string getArch() const {
-    return "unknown";
-  }
+  /// Get target compute unit kind (e.g., sm_80, or gfx908).
+  virtual std::string getComputeUnitKind() const { return "unknown"; }
 
   /// Post processing after jit backend. The ownership of \p MB will be taken.
   virtual Expected<std::unique_ptr<MemoryBuffer>>
@@ -513,8 +513,8 @@ protected:
 struct GenericPluginTy {
 
   /// Construct a plugin instance.
-  GenericPluginTy()
-      : RequiresFlags(OMP_REQ_UNDEFINED), GlobalHandler(nullptr) {}
+  GenericPluginTy(Triple::ArchType TA)
+      : RequiresFlags(OMP_REQ_UNDEFINED), GlobalHandler(nullptr), JIT(TA) {}
 
   virtual ~GenericPluginTy() {}
 
@@ -543,9 +543,7 @@ struct GenericPluginTy {
   virtual uint16_t getMagicElfBits() const = 0;
 
   /// Get the target triple of this plugin.
-  virtual Triple::ArchType getTripleArch() const {
-    return Triple::ArchType::UnknownArch;
-  }
+  virtual Triple::ArchType getTripleArch() const = 0;
 
   /// Allocate a structure using the internal allocator.
   template <typename Ty> Ty *allocate() {
@@ -558,6 +556,10 @@ struct GenericPluginTy {
     return *GlobalHandler;
   }
 
+  /// Get the reference to the JIT used for all devices connected to this
+  /// plugin.
+  JITEngine &getJIT() { return JIT; }
+
   /// Get the OpenMP requires flags set for this plugin.
   int64_t getRequiresFlags() const { return RequiresFlags; }
 
@@ -609,6 +611,9 @@ private:
 
   /// Internal allocator for different structures.
   BumpPtrAllocator Allocator;
+
+  /// The JIT engine shared by all devices connected to this plugin.
+  JITEngine JIT;
 };
 
 /// Class for simplifying the getter operation of the plugin. Anywhere on the
index cb5f004..cfc97b6 100644 (file)
@@ -784,8 +784,10 @@ struct CUDADeviceTy : public GenericDeviceTy {
     return Plugin::check(Res, "Error in cuDeviceGetAttribute: %s");
   }
 
-  /// See GenericDeviceTy::getArch().
-  std::string getArch() const override { return ComputeCapability.str(); }
+  /// See GenericDeviceTy::getComputeUnitKind().
+  std::string getComputeUnitKind() const override {
+    return ComputeCapability.str();
+  }
 
 private:
   using CUDAStreamManagerTy = GenericDeviceResourceManagerTy<CUDAStreamRef>;
@@ -867,7 +869,7 @@ public:
 /// Class implementing the CUDA-specific functionalities of the plugin.
 struct CUDAPluginTy final : public GenericPluginTy {
   /// Create a CUDA plugin.
-  CUDAPluginTy() : GenericPluginTy() {}
+  CUDAPluginTy() : GenericPluginTy(getTripleArch()) {}
 
   /// This class should not be copied.
   CUDAPluginTy(const CUDAPluginTy &) = delete;
index ed6897a..ab5f9e3 100644 (file)
@@ -340,7 +340,7 @@ public:
 /// Class implementing the plugin functionalities for GenELF64.
 struct GenELF64PluginTy final : public GenericPluginTy {
   /// Create the GenELF64 plugin.
-  GenELF64PluginTy() : GenericPluginTy() {}
+  GenELF64PluginTy() : GenericPluginTy(getTripleArch()) {}
 
   /// This class should not be copied.
   GenELF64PluginTy(const GenELF64PluginTy &) = delete;