[mlir] Introduce dialect interfaces for translation to LLVM IR
authorAlex Zinenko <zinenko@google.com>
Thu, 11 Feb 2021 14:01:33 +0000 (15:01 +0100)
committerAlex Zinenko <zinenko@google.com>
Fri, 12 Feb 2021 16:49:44 +0000 (17:49 +0100)
The existing approach to translation to the LLVM IR relies on a single
translation supporting the base LLVM dialect, extensible through inheritance to
support intrinsic-based dialects also derived from LLVM IR such as NVVM and
AVX512. This approach does not scale well as it requires additional
translations to be created for each new intrinsic-based dialect and does not
allow them to mix in the same module, contrary to the rest of the MLIR
infrastructure. Furthermore, OpenMP translation ingrained itself into the main
translation mechanism.

Start refactoring the translation to LLVM IR to operate using dialect
interfaces. Each dialect that contains ops translatable to LLVM IR can
implement the interface for translating them, and the top-level translation
driver can operate on interfaces without knowing about specific dialects.
Furthermore, the delayed dialect registration mechanism allows one to avoid a
dependency on LLVM IR in the dialect that is translated to it by implementing
the translation as a separate library and only registering it at the client
level.

This change introduces the new mechanism and factors out the translation of the
"main" LLVM dialect. The remaining dialects will follow suit.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96503

24 files changed:
mlir/examples/toy/Ch6/toyc.cpp
mlir/examples/toy/Ch7/toyc.cpp
mlir/include/mlir/Target/LLVMIR.h
mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h [new file with mode: 0644]
mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h [new file with mode: 0644]
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/CMakeLists.txt
mlir/lib/Target/LLVMIR/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp [new file with mode: 0644]
mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp
mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp
mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp
mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
mlir/unittests/ExecutionEngine/Invoke.cpp

index b800c0a..d717f69 100644 (file)
@@ -189,6 +189,9 @@ int dumpAST() {
 }
 
 int dumpLLVMIR(mlir::ModuleOp module) {
+  // Register the translation to LLVM IR with the MLIR context.
+  mlir::registerLLVMDialectTranslation(*module->getContext());
+
   // Convert the module to LLVM IR in a new LLVM IR context.
   llvm::LLVMContext llvmContext;
   auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
@@ -219,6 +222,10 @@ int runJit(mlir::ModuleOp module) {
   llvm::InitializeNativeTarget();
   llvm::InitializeNativeTargetAsmPrinter();
 
+  // Register the translation from MLIR to LLVM IR, which must happen before we
+  // can JIT-compile.
+  mlir::registerLLVMDialectTranslation(*module->getContext());
+
   // An optimization pipeline to use within the execution engine.
   auto optPipeline = mlir::makeOptimizingTransformer(
       /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
index 4fdb06d..3898e28 100644 (file)
@@ -190,6 +190,9 @@ int dumpAST() {
 }
 
 int dumpLLVMIR(mlir::ModuleOp module) {
+  // Register the translation to LLVM IR with the MLIR context.
+  mlir::registerLLVMDialectTranslation(*module->getContext());
+
   // Convert the module to LLVM IR in a new LLVM IR context.
   llvm::LLVMContext llvmContext;
   auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
@@ -220,6 +223,10 @@ int runJit(mlir::ModuleOp module) {
   llvm::InitializeNativeTarget();
   llvm::InitializeNativeTargetAsmPrinter();
 
+  // Register the translation from MLIR to LLVM IR, which must happen before we
+  // can JIT-compile.
+  mlir::registerLLVMDialectTranslation(*module->getContext());
+
   // An optimization pipeline to use within the execution engine.
   auto optPipeline = mlir::makeOptimizingTransformer(
       /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
index ffd1a4c..2050c63 100644 (file)
@@ -25,6 +25,7 @@ class Module;
 
 namespace mlir {
 
+class DialectRegistry;
 class OwningModuleRef;
 class MLIRContext;
 class ModuleOp;
@@ -45,6 +46,15 @@ OwningModuleRef
 translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
                         MLIRContext *context);
 
+/// Register the LLVM dialect and the translation from it to the LLVM IR in the
+/// given registry;
+void registerLLVMDialectTranslation(DialectRegistry &registry);
+
+/// Register the LLVM dialect and the translation from it in the registry
+/// associated with the given context. This checks if the interface is already
+/// registered and avoids double registation.
+void registerLLVMDialectTranslation(MLIRContext &context);
+
 } // namespace mlir
 
 #endif // MLIR_TARGET_LLVMIR_H
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h
new file mode 100644 (file)
index 0000000..8b72ced
--- /dev/null
@@ -0,0 +1,37 @@
+//===- LLVMToLLVMIRTranslation.h - LLVM Dialect to LLVM IR-------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the dialect interface for translating the LLVM dialect
+// to LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H
+
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+
+namespace mlir {
+
+/// Implementation of the dialect interface that converts operations beloning to
+/// the LLVM dialect to LLVM IR.
+class LLVMDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates the given operation to LLVM IR using the provided IR builder
+  /// and saving the state in `moduleTranslation`.
+  LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const final;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
new file mode 100644 (file)
index 0000000..0063bea
--- /dev/null
@@ -0,0 +1,68 @@
+//===- LLVMTranslationInterface.h - Translation to LLVM iface ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines dialect interfaces for translation to LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
+#define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace llvm {
+class IRBuilderBase;
+}
+
+namespace mlir {
+namespace LLVM {
+class ModuleTranslation;
+} // namespace LLVM
+
+/// Base class for dialect interfaces providing translation to LLVM IR.
+/// Dialects that can be translated should provide an implementation of this
+/// interface for the supported operations. The interface may be implemented in
+/// a separate library to avoid the "main" dialect library depending on LLVM IR.
+/// The interface can be attached using the delayed registration mechanism
+/// available in DialectRegistry.
+class LLVMTranslationDialectInterface
+    : public DialectInterface::Base<LLVMTranslationDialectInterface> {
+public:
+  LLVMTranslationDialectInterface(Dialect *dialect) : Base(dialect) {}
+
+  /// Hook for derived dialect interface to provide translation of the
+  /// operations to LLVM IR.
+  virtual LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const {
+    return failure();
+  }
+};
+
+/// Interface collection for translation to LLVM IR, dispatches to a concrete
+/// interface implementation based on the dialect to which the given op belongs.
+class LLVMTranslationInterface
+    : public DialectInterfaceCollection<LLVMTranslationDialectInterface> {
+public:
+  using Base::Base;
+
+  /// Translates the given operation to LLVM IR using the interface implemented
+  /// by the op's dialect.
+  virtual LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const {
+    if (const LLVMTranslationDialectInterface *iface = getInterfaceFor(op))
+      return iface->convertOperation(op, builder, moduleTranslation);
+    return failure();
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
index ebe9a7c..b15fcc3 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeTranslation.h"
 
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
@@ -134,6 +135,31 @@ public:
     return branchMapping.lookup(op);
   }
 
+  /// Converts the type from MLIR LLVM dialect to LLVM.
+  llvm::Type *convertType(Type type);
+
+  /// Looks up remapped a list of remapped values.
+  SmallVector<llvm::Value *, 8> lookupValues(ValueRange values);
+
+  /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
+  /// This currently supports integer, floating point, splat and dense element
+  /// attributes and combinations thereof.  In case of error, report it to `loc`
+  /// and return nullptr.
+  llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
+                                  Location loc);
+
+  /// Returns the MLIR context of the module being translated.
+  MLIRContext &getContext() { return *mlirModule->getContext(); }
+
+  /// Returns the LLVM context in which the IR is being constructed.
+  llvm::LLVMContext &getLLVMContext() { return llvmModule->getContext(); }
+
+  /// Finds an LLVM IR global value that corresponds to the given MLIR operation
+  /// defining a global value.
+  llvm::GlobalValue *lookupGlobal(Operation *op) {
+    return globalsMapping.lookup(op);
+  }
+
 protected:
   /// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
   /// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
@@ -158,16 +184,10 @@ protected:
   virtual LogicalResult convertOmpWsLoop(Operation &opInst,
                                          llvm::IRBuilder<> &builder);
 
-  /// Converts the type from MLIR LLVM dialect to LLVM.
-  llvm::Type *convertType(Type type);
-
   static std::unique_ptr<llvm::Module>
   prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
                     StringRef name);
 
-  /// A helper to look up remapped operands in the value remapping table.
-  SmallVector<llvm::Value *, 8> lookupValues(ValueRange values);
-
 private:
   /// Check whether the module contains only supported ops directly in its body.
   static LogicalResult checkSupportedModuleOps(Operation *m);
@@ -179,9 +199,6 @@ private:
   LogicalResult convertBlock(Block &bb, bool ignoreArguments,
                              llvm::IRBuilder<> &builder);
 
-  llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
-                                  Location loc);
-
   /// Original and translated module.
   Operation *mlirModule;
   std::unique_ptr<llvm::Module> llvmModule;
@@ -202,7 +219,8 @@ private:
   /// A stateful object used to translate types.
   TypeToLLVMIRTranslator typeTranslator;
 
-private:
+  LLVMTranslationInterface iface;
+
   /// Mappings between original and translated values, used for lookups.
   llvm::StringMap<llvm::Function *> functionMapping;
   DenseMap<Value, llvm::Value *> valueMapping;
index 51a0e78..72555ac 100644 (file)
@@ -1,4 +1,5 @@
 add_subdirectory(SPIRV)
+add_subdirectory(LLVMIR)
 
 add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation
   LLVMIR/DebugTranslation.cpp
@@ -39,6 +40,7 @@ add_mlir_translation_library(MLIRTargetAVX512
   MLIRIR
   MLIRLLVMAVX512
   MLIRLLVMIR
+  MLIRTargetLLVMIR
   MLIRTargetLLVMIRModuleTranslation
   )
 
@@ -54,6 +56,7 @@ add_mlir_translation_library(MLIRTargetLLVMIR
   IRReader
 
   LINK_LIBS PUBLIC
+  MLIRLLVMToLLVMIRTranslation
   MLIRTargetLLVMIRModuleTranslation
   )
 
@@ -73,6 +76,7 @@ add_mlir_translation_library(MLIRTargetArmNeon
   MLIRIR
   MLIRLLVMArmNeon
   MLIRLLVMIR
+  MLIRTargetLLVMIR
   MLIRTargetLLVMIRModuleTranslation
   )
 
@@ -92,6 +96,7 @@ add_mlir_translation_library(MLIRTargetArmSVE
   MLIRIR
   MLIRLLVMArmSVE
   MLIRLLVMIR
+  MLIRTargetLLVMIR
   MLIRTargetLLVMIRModuleTranslation
   )
 
@@ -112,6 +117,7 @@ add_mlir_translation_library(MLIRTargetNVVMIR
   MLIRIR
   MLIRLLVMIR
   MLIRNVVMIR
+  MLIRTargetLLVMIR
   MLIRTargetLLVMIRModuleTranslation
   )
 
@@ -132,5 +138,6 @@ add_mlir_translation_library(MLIRTargetROCDLIR
   MLIRIR
   MLIRLLVMIR
   MLIRROCDLIR
+  MLIRTargetLLVMIR
   MLIRTargetLLVMIRModuleTranslation
   )
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
new file mode 100644 (file)
index 0000000..0ca0f41
--- /dev/null
@@ -0,0 +1 @@
+add_subdirectory(Dialect)
index 476f365..bf8e248 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Target/LLVMIR.h"
 
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Translation.h"
 
@@ -35,6 +36,24 @@ mlir::translateModuleToLLVMIR(ModuleOp m, llvm::LLVMContext &llvmContext,
   return llvmModule;
 }
 
+void mlir::registerLLVMDialectTranslation(DialectRegistry &registry) {
+  registry.insert<LLVM::LLVMDialect>();
+  registry.addDialectInterface<LLVM::LLVMDialect,
+                               LLVMDialectLLVMIRTranslationInterface>();
+}
+
+void mlir::registerLLVMDialectTranslation(MLIRContext &context) {
+  auto *dialect = context.getLoadedDialect<LLVM::LLVMDialect>();
+  if (!dialect || dialect->getRegisteredInterface<
+                      LLVMDialectLLVMIRTranslationInterface>() == nullptr) {
+    DialectRegistry registry;
+    registry.insert<LLVM::LLVMDialect>();
+    registry.addDialectInterface<LLVM::LLVMDialect,
+                                 LLVMDialectLLVMIRTranslationInterface>();
+    context.appendDialectRegistry(registry);
+  }
+}
+
 namespace mlir {
 void registerToLLVMIRTranslation() {
   TranslateFromMLIRRegistration registration(
@@ -50,7 +69,8 @@ void registerToLLVMIRTranslation() {
         return success();
       },
       [](DialectRegistry &registry) {
-        registry.insert<LLVM::LLVMDialect, omp::OpenMPDialect>();
+        registry.insert<omp::OpenMPDialect>();
+        registerLLVMDialectTranslation(registry);
       });
 }
 } // namespace mlir
index 668d9d9..7aee913 100644 (file)
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Target/LLVMIR.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Translation.h"
 
@@ -68,6 +69,11 @@ protected:
 std::unique_ptr<llvm::Module>
 mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext,
                               StringRef name) {
+  // Register the translation to LLVM IR if nobody else did before. This may
+  // happen if this translation is called inside a pass pipeline that converts
+  // GPU dialects to binary blobs without translating the rest of the code.
+  registerLLVMDialectTranslation(*m->getContext());
+
   auto llvmModule = LLVM::ModuleTranslation::translateModule<ModuleTranslation>(
       m, llvmContext, name);
   if (!llvmModule)
@@ -111,7 +117,8 @@ void registerToNVVMIRTranslation() {
         return success();
       },
       [](DialectRegistry &registry) {
-        registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
+        registry.insert<NVVM::NVVMDialect>();
+        registerLLVMDialectTranslation(registry);
       });
 }
 } // namespace mlir
index c415787..7ebbd3f 100644 (file)
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Target/LLVMIR.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Translation.h"
 
@@ -77,11 +78,16 @@ protected:
 std::unique_ptr<llvm::Module>
 mlir::translateModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext,
                                StringRef name) {
-  // lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics)
+  // Register the translation to LLVM IR if nobody else did before. This may
+  // happen if this translation is called inside a pass pipeline that converts
+  // GPU dialects to binary blobs without translating the rest of the code.
+  registerLLVMDialectTranslation(*m->getContext());
+
+  // Lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics).
   auto llvmModule = LLVM::ModuleTranslation::translateModule<ModuleTranslation>(
       m, llvmContext, name);
 
-  // foreach GPU kernel
+  // Foreach GPU kernel:
   // 1. Insert AMDGPU_KERNEL calling convention.
   // 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
   for (auto func :
@@ -114,7 +120,8 @@ void registerToROCDLIRTranslation() {
         return success();
       },
       [](DialectRegistry &registry) {
-        registry.insert<ROCDL::ROCDLDialect, LLVM::LLVMDialect>();
+        registry.insert<ROCDL::ROCDLDialect>();
+        registerLLVMDialectTranslation(registry);
       });
 }
 } // namespace mlir
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
new file mode 100644 (file)
index 0000000..39d31dc
--- /dev/null
@@ -0,0 +1 @@
+add_subdirectory(LLVMIR)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt
new file mode 100644 (file)
index 0000000..2da7e95
--- /dev/null
@@ -0,0 +1,12 @@
+add_mlir_translation_library(MLIRLLVMToLLVMIRTranslation
+  LLVMToLLVMIRTranslation.cpp
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIRSupport
+  MLIRTargetLLVMIRModuleTranslation
+  )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
new file mode 100644 (file)
index 0000000..25d5294
--- /dev/null
@@ -0,0 +1,405 @@
+//===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR LLVM dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Operator.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
+
+/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
+static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) {
+  switch (p) {
+  case LLVM::ICmpPredicate::eq:
+    return llvm::CmpInst::Predicate::ICMP_EQ;
+  case LLVM::ICmpPredicate::ne:
+    return llvm::CmpInst::Predicate::ICMP_NE;
+  case LLVM::ICmpPredicate::slt:
+    return llvm::CmpInst::Predicate::ICMP_SLT;
+  case LLVM::ICmpPredicate::sle:
+    return llvm::CmpInst::Predicate::ICMP_SLE;
+  case LLVM::ICmpPredicate::sgt:
+    return llvm::CmpInst::Predicate::ICMP_SGT;
+  case LLVM::ICmpPredicate::sge:
+    return llvm::CmpInst::Predicate::ICMP_SGE;
+  case LLVM::ICmpPredicate::ult:
+    return llvm::CmpInst::Predicate::ICMP_ULT;
+  case LLVM::ICmpPredicate::ule:
+    return llvm::CmpInst::Predicate::ICMP_ULE;
+  case LLVM::ICmpPredicate::ugt:
+    return llvm::CmpInst::Predicate::ICMP_UGT;
+  case LLVM::ICmpPredicate::uge:
+    return llvm::CmpInst::Predicate::ICMP_UGE;
+  }
+  llvm_unreachable("incorrect comparison predicate");
+}
+
+static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
+  switch (p) {
+  case LLVM::FCmpPredicate::_false:
+    return llvm::CmpInst::Predicate::FCMP_FALSE;
+  case LLVM::FCmpPredicate::oeq:
+    return llvm::CmpInst::Predicate::FCMP_OEQ;
+  case LLVM::FCmpPredicate::ogt:
+    return llvm::CmpInst::Predicate::FCMP_OGT;
+  case LLVM::FCmpPredicate::oge:
+    return llvm::CmpInst::Predicate::FCMP_OGE;
+  case LLVM::FCmpPredicate::olt:
+    return llvm::CmpInst::Predicate::FCMP_OLT;
+  case LLVM::FCmpPredicate::ole:
+    return llvm::CmpInst::Predicate::FCMP_OLE;
+  case LLVM::FCmpPredicate::one:
+    return llvm::CmpInst::Predicate::FCMP_ONE;
+  case LLVM::FCmpPredicate::ord:
+    return llvm::CmpInst::Predicate::FCMP_ORD;
+  case LLVM::FCmpPredicate::ueq:
+    return llvm::CmpInst::Predicate::FCMP_UEQ;
+  case LLVM::FCmpPredicate::ugt:
+    return llvm::CmpInst::Predicate::FCMP_UGT;
+  case LLVM::FCmpPredicate::uge:
+    return llvm::CmpInst::Predicate::FCMP_UGE;
+  case LLVM::FCmpPredicate::ult:
+    return llvm::CmpInst::Predicate::FCMP_ULT;
+  case LLVM::FCmpPredicate::ule:
+    return llvm::CmpInst::Predicate::FCMP_ULE;
+  case LLVM::FCmpPredicate::une:
+    return llvm::CmpInst::Predicate::FCMP_UNE;
+  case LLVM::FCmpPredicate::uno:
+    return llvm::CmpInst::Predicate::FCMP_UNO;
+  case LLVM::FCmpPredicate::_true:
+    return llvm::CmpInst::Predicate::FCMP_TRUE;
+  }
+  llvm_unreachable("incorrect comparison predicate");
+}
+
+static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
+  switch (op) {
+  case LLVM::AtomicBinOp::xchg:
+    return llvm::AtomicRMWInst::BinOp::Xchg;
+  case LLVM::AtomicBinOp::add:
+    return llvm::AtomicRMWInst::BinOp::Add;
+  case LLVM::AtomicBinOp::sub:
+    return llvm::AtomicRMWInst::BinOp::Sub;
+  case LLVM::AtomicBinOp::_and:
+    return llvm::AtomicRMWInst::BinOp::And;
+  case LLVM::AtomicBinOp::nand:
+    return llvm::AtomicRMWInst::BinOp::Nand;
+  case LLVM::AtomicBinOp::_or:
+    return llvm::AtomicRMWInst::BinOp::Or;
+  case LLVM::AtomicBinOp::_xor:
+    return llvm::AtomicRMWInst::BinOp::Xor;
+  case LLVM::AtomicBinOp::max:
+    return llvm::AtomicRMWInst::BinOp::Max;
+  case LLVM::AtomicBinOp::min:
+    return llvm::AtomicRMWInst::BinOp::Min;
+  case LLVM::AtomicBinOp::umax:
+    return llvm::AtomicRMWInst::BinOp::UMax;
+  case LLVM::AtomicBinOp::umin:
+    return llvm::AtomicRMWInst::BinOp::UMin;
+  case LLVM::AtomicBinOp::fadd:
+    return llvm::AtomicRMWInst::BinOp::FAdd;
+  case LLVM::AtomicBinOp::fsub:
+    return llvm::AtomicRMWInst::BinOp::FSub;
+  }
+  llvm_unreachable("incorrect atomic binary operator");
+}
+
+static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
+  switch (ordering) {
+  case LLVM::AtomicOrdering::not_atomic:
+    return llvm::AtomicOrdering::NotAtomic;
+  case LLVM::AtomicOrdering::unordered:
+    return llvm::AtomicOrdering::Unordered;
+  case LLVM::AtomicOrdering::monotonic:
+    return llvm::AtomicOrdering::Monotonic;
+  case LLVM::AtomicOrdering::acquire:
+    return llvm::AtomicOrdering::Acquire;
+  case LLVM::AtomicOrdering::release:
+    return llvm::AtomicOrdering::Release;
+  case LLVM::AtomicOrdering::acq_rel:
+    return llvm::AtomicOrdering::AcquireRelease;
+  case LLVM::AtomicOrdering::seq_cst:
+    return llvm::AtomicOrdering::SequentiallyConsistent;
+  }
+  llvm_unreachable("incorrect atomic ordering");
+}
+
+static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
+  using llvmFMF = llvm::FastMathFlags;
+  using FuncT = void (llvmFMF::*)(bool);
+  const std::pair<FastmathFlags, FuncT> handlers[] = {
+      // clang-format off
+      {FastmathFlags::nnan,     &llvmFMF::setNoNaNs},
+      {FastmathFlags::ninf,     &llvmFMF::setNoInfs},
+      {FastmathFlags::nsz,      &llvmFMF::setNoSignedZeros},
+      {FastmathFlags::arcp,     &llvmFMF::setAllowReciprocal},
+      {FastmathFlags::contract, &llvmFMF::setAllowContract},
+      {FastmathFlags::afn,      &llvmFMF::setApproxFunc},
+      {FastmathFlags::reassoc,  &llvmFMF::setAllowReassoc},
+      {FastmathFlags::fast,     &llvmFMF::setFast},
+      // clang-format on
+  };
+  llvm::FastMathFlags ret;
+  auto fmf = op.fastmathFlags();
+  for (auto it : handlers)
+    if (bitEnumContains(fmf, it.first))
+      (ret.*(it.second))(true);
+  return ret;
+}
+
+namespace {
+/// Dispatcher functional object targeting different overloads of
+/// ModuleTranslation::mapValue.
+// TODO: this is only necessary for compatibility with the code emitted from
+// ODS, remove when ODS is updated (after all dialects have migrated to the new
+// translation mechanism).
+struct MapValueDispatcher {
+  explicit MapValueDispatcher(ModuleTranslation &mt) : moduleTranslation(mt) {}
+
+  llvm::Value *&operator()(mlir::Value v) {
+    return moduleTranslation.mapValue(v);
+  }
+
+  void operator()(mlir::Value m, llvm::Value *l) {
+    moduleTranslation.mapValue(m, l);
+  }
+
+  LLVM::ModuleTranslation &moduleTranslation;
+};
+} // end namespace
+
+static LogicalResult
+convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
+                     LLVM::ModuleTranslation &moduleTranslation) {
+  auto extractPosition = [](ArrayAttr attr) {
+    SmallVector<unsigned, 4> position;
+    position.reserve(attr.size());
+    for (Attribute v : attr)
+      position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
+    return position;
+  };
+
+  llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
+  if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
+    builder.setFastMathFlags(getFastmathFlags(fmf));
+
+  // TODO: these are necessary for compatibility with the code emitted from ODS,
+  // remove them when ODS is updated (after all dialects have migrated to the
+  // new translation mechanism).
+  MapValueDispatcher mapValue(moduleTranslation);
+  auto lookupValue = [&](mlir::Value v) {
+    return moduleTranslation.lookupValue(v);
+  };
+  auto convertType = [&](Type ty) { return moduleTranslation.convertType(ty); };
+  auto lookupValues = [&](ValueRange vs) {
+    return moduleTranslation.lookupValues(vs);
+  };
+  auto getLLVMConstant = [&](llvm::Type *ty, Attribute attr, Location loc) {
+    return moduleTranslation.getLLVMConstant(ty, attr, loc);
+  };
+
+#include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
+
+  // Emit function calls.  If the "callee" attribute is present, this is a
+  // direct function call and we also need to look up the remapped function
+  // itself.  Otherwise, this is an indirect call and the callee is the first
+  // operand, look it up as a normal value.  Return the llvm::Value representing
+  // the function result, which may be of llvm::VoidTy type.
+  auto convertCall = [&](Operation &op) -> llvm::Value * {
+    auto operands = moduleTranslation.lookupValues(op.getOperands());
+    ArrayRef<llvm::Value *> operandsRef(operands);
+    if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
+      return builder.CreateCall(
+          moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
+    auto *calleePtrType =
+        cast<llvm::PointerType>(operandsRef.front()->getType());
+    auto *calleeType =
+        cast<llvm::FunctionType>(calleePtrType->getElementType());
+    return builder.CreateCall(calleeType, operandsRef.front(),
+                              operandsRef.drop_front());
+  };
+
+  // Emit calls.  If the called function has a result, remap the corresponding
+  // value.  Note that LLVM IR dialect CallOp has either 0 or 1 result.
+  if (isa<LLVM::CallOp>(opInst)) {
+    llvm::Value *result = convertCall(opInst);
+    if (opInst.getNumResults() != 0) {
+      mapValue(opInst.getResult(0), result);
+      return success();
+    }
+    // Check that LLVM call returns void for 0-result functions.
+    return success(result->getType()->isVoidTy());
+  }
+
+  if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
+    // TODO: refactor function type creation which usually occurs in std-LLVM
+    // conversion.
+    SmallVector<Type, 8> operandTypes;
+    operandTypes.reserve(inlineAsmOp.operands().size());
+    for (auto t : inlineAsmOp.operands().getTypes())
+      operandTypes.push_back(t);
+
+    Type resultType;
+    if (inlineAsmOp.getNumResults() == 0) {
+      resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext());
+    } else {
+      assert(inlineAsmOp.getNumResults() == 1);
+      resultType = inlineAsmOp.getResultTypes()[0];
+    }
+    auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
+    llvm::InlineAsm *inlineAsmInst =
+        inlineAsmOp.asm_dialect().hasValue()
+            ? llvm::InlineAsm::get(
+                  static_cast<llvm::FunctionType *>(convertType(ft)),
+                  inlineAsmOp.asm_string(), inlineAsmOp.constraints(),
+                  inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack(),
+                  convertAsmDialectToLLVM(*inlineAsmOp.asm_dialect()))
+            : llvm::InlineAsm::get(
+                  static_cast<llvm::FunctionType *>(convertType(ft)),
+                  inlineAsmOp.asm_string(), inlineAsmOp.constraints(),
+                  inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack());
+    llvm::Value *result =
+        builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands()));
+    if (opInst.getNumResults() != 0)
+      mapValue(opInst.getResult(0), result);
+    return success();
+  }
+
+  if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
+    auto operands = lookupValues(opInst.getOperands());
+    ArrayRef<llvm::Value *> operandsRef(operands);
+    if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
+      builder.CreateInvoke(moduleTranslation.lookupFunction(attr.getValue()),
+                           moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
+                           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
+                           operandsRef);
+    } else {
+      auto *calleePtrType =
+          cast<llvm::PointerType>(operandsRef.front()->getType());
+      auto *calleeType =
+          cast<llvm::FunctionType>(calleePtrType->getElementType());
+      builder.CreateInvoke(calleeType, operandsRef.front(),
+                           moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
+                           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
+                           operandsRef.drop_front());
+    }
+    return success();
+  }
+
+  if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
+    llvm::Type *ty = convertType(lpOp.getType());
+    llvm::LandingPadInst *lpi =
+        builder.CreateLandingPad(ty, lpOp.getNumOperands());
+
+    // Add clauses
+    for (llvm::Value *operand : lookupValues(lpOp.getOperands())) {
+      // All operands should be constant - checked by verifier
+      if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
+        lpi->addClause(constOperand);
+    }
+    mapValue(lpOp.getResult(), lpi);
+    return success();
+  }
+
+  // Emit branches.  We need to look up the remapped blocks and ignore the block
+  // arguments that were transformed into PHI nodes.
+  if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
+    llvm::BranchInst *branch =
+        builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
+    moduleTranslation.mapBranch(&opInst, branch);
+    return success();
+  }
+  if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
+    auto weights = condbrOp.branch_weights();
+    llvm::MDNode *branchWeights = nullptr;
+    if (weights) {
+      // Map weight attributes to LLVM metadata.
+      auto trueWeight =
+          weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
+      auto falseWeight =
+          weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
+      branchWeights =
+          llvm::MDBuilder(moduleTranslation.getLLVMContext())
+              .createBranchWeights(static_cast<uint32_t>(trueWeight),
+                                   static_cast<uint32_t>(falseWeight));
+    }
+    llvm::BranchInst *branch = builder.CreateCondBr(
+        moduleTranslation.lookupValue(condbrOp.getOperand(0)),
+        moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
+        moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
+    moduleTranslation.mapBranch(&opInst, branch);
+    return success();
+  }
+  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
+    llvm::MDNode *branchWeights = nullptr;
+    if (auto weights = switchOp.branch_weights()) {
+      llvm::SmallVector<uint32_t> weightValues;
+      weightValues.reserve(weights->size());
+      for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
+        weightValues.push_back(weight.getLimitedValue());
+      branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext())
+                          .createBranchWeights(weightValues);
+    }
+
+    llvm::SwitchInst *switchInst = builder.CreateSwitch(
+        moduleTranslation.lookupValue(switchOp.value()),
+        moduleTranslation.lookupBlock(switchOp.defaultDestination()),
+        switchOp.caseDestinations().size(), branchWeights);
+
+    auto *ty =
+        llvm::cast<llvm::IntegerType>(convertType(switchOp.value().getType()));
+    for (auto i :
+         llvm::zip(switchOp.case_values()->cast<DenseIntElementsAttr>(),
+                   switchOp.caseDestinations()))
+      switchInst->addCase(
+          llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
+          moduleTranslation.lookupBlock(std::get<1>(i)));
+
+    moduleTranslation.mapBranch(&opInst, switchInst);
+    return success();
+  }
+
+  // Emit addressof.  We need to look up the global value referenced by the
+  // operation and store it in the MLIR-to-LLVM value mapping.  This does not
+  // emit any LLVM instruction.
+  if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
+    LLVM::GlobalOp global = addressOfOp.getGlobal();
+    LLVM::LLVMFuncOp function = addressOfOp.getFunction();
+
+    // The verifier should not have allowed this.
+    assert((global || function) &&
+           "referencing an undefined global or function");
+
+    mapValue(addressOfOp.getResult(),
+             global ? moduleTranslation.lookupGlobal(global)
+                    : moduleTranslation.lookupFunction(function.getName()));
+    return success();
+  }
+
+  return failure();
+}
+
+LogicalResult mlir::LLVMDialectLLVMIRTranslationInterface::convertOperation(
+    Operation *op, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+  return convertOperationImpl(*op, builder, moduleTranslation);
+}
index 52f1792..082504d 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Target/LLVMIR.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Translation.h"
 #include "llvm/IR/IntrinsicsX86.h"
@@ -57,7 +58,8 @@ void registerAVX512ToLLVMIRTranslation() {
         return success();
       },
       [](DialectRegistry &registry) {
-        registry.insert<LLVM::LLVMAVX512Dialect, LLVM::LLVMDialect>();
+        registry.insert<LLVM::LLVMAVX512Dialect>();
+        registerLLVMDialectTranslation(registry);
       });
 }
 } // namespace mlir
index 0bd40ef..14bbd14 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Target/LLVMIR.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Translation.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
@@ -57,7 +58,8 @@ void registerArmNeonToLLVMIRTranslation() {
         return success();
       },
       [](DialectRegistry &registry) {
-        registry.insert<LLVM::LLVMArmNeonDialect, LLVM::LLVMDialect>();
+        registry.insert<LLVM::LLVMArmNeonDialect>();
+        registerLLVMDialectTranslation(registry);
       });
 }
 } // namespace mlir
index 717583a..5ef8b48 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
+#include "mlir/Target/LLVMIR.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Translation.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
@@ -57,7 +58,8 @@ void registerArmSVEToLLVMIRTranslation() {
         return success();
       },
       [](DialectRegistry &registry) {
-        registry.insert<LLVM::LLVMArmSVEDialect, LLVM::LLVMDialect>();
+        registry.insert<LLVM::LLVMArmSVEDialect>();
+        registerLLVMDialectTranslation(registry);
       });
 }
 } // namespace mlir
index dce3846..6728511 100644 (file)
@@ -21,6 +21,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeTranslation.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -183,130 +184,14 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
   return nullptr;
 }
 
-/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
-static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) {
-  switch (p) {
-  case LLVM::ICmpPredicate::eq:
-    return llvm::CmpInst::Predicate::ICMP_EQ;
-  case LLVM::ICmpPredicate::ne:
-    return llvm::CmpInst::Predicate::ICMP_NE;
-  case LLVM::ICmpPredicate::slt:
-    return llvm::CmpInst::Predicate::ICMP_SLT;
-  case LLVM::ICmpPredicate::sle:
-    return llvm::CmpInst::Predicate::ICMP_SLE;
-  case LLVM::ICmpPredicate::sgt:
-    return llvm::CmpInst::Predicate::ICMP_SGT;
-  case LLVM::ICmpPredicate::sge:
-    return llvm::CmpInst::Predicate::ICMP_SGE;
-  case LLVM::ICmpPredicate::ult:
-    return llvm::CmpInst::Predicate::ICMP_ULT;
-  case LLVM::ICmpPredicate::ule:
-    return llvm::CmpInst::Predicate::ICMP_ULE;
-  case LLVM::ICmpPredicate::ugt:
-    return llvm::CmpInst::Predicate::ICMP_UGT;
-  case LLVM::ICmpPredicate::uge:
-    return llvm::CmpInst::Predicate::ICMP_UGE;
-  }
-  llvm_unreachable("incorrect comparison predicate");
-}
-
-static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
-  switch (p) {
-  case LLVM::FCmpPredicate::_false:
-    return llvm::CmpInst::Predicate::FCMP_FALSE;
-  case LLVM::FCmpPredicate::oeq:
-    return llvm::CmpInst::Predicate::FCMP_OEQ;
-  case LLVM::FCmpPredicate::ogt:
-    return llvm::CmpInst::Predicate::FCMP_OGT;
-  case LLVM::FCmpPredicate::oge:
-    return llvm::CmpInst::Predicate::FCMP_OGE;
-  case LLVM::FCmpPredicate::olt:
-    return llvm::CmpInst::Predicate::FCMP_OLT;
-  case LLVM::FCmpPredicate::ole:
-    return llvm::CmpInst::Predicate::FCMP_OLE;
-  case LLVM::FCmpPredicate::one:
-    return llvm::CmpInst::Predicate::FCMP_ONE;
-  case LLVM::FCmpPredicate::ord:
-    return llvm::CmpInst::Predicate::FCMP_ORD;
-  case LLVM::FCmpPredicate::ueq:
-    return llvm::CmpInst::Predicate::FCMP_UEQ;
-  case LLVM::FCmpPredicate::ugt:
-    return llvm::CmpInst::Predicate::FCMP_UGT;
-  case LLVM::FCmpPredicate::uge:
-    return llvm::CmpInst::Predicate::FCMP_UGE;
-  case LLVM::FCmpPredicate::ult:
-    return llvm::CmpInst::Predicate::FCMP_ULT;
-  case LLVM::FCmpPredicate::ule:
-    return llvm::CmpInst::Predicate::FCMP_ULE;
-  case LLVM::FCmpPredicate::une:
-    return llvm::CmpInst::Predicate::FCMP_UNE;
-  case LLVM::FCmpPredicate::uno:
-    return llvm::CmpInst::Predicate::FCMP_UNO;
-  case LLVM::FCmpPredicate::_true:
-    return llvm::CmpInst::Predicate::FCMP_TRUE;
-  }
-  llvm_unreachable("incorrect comparison predicate");
-}
-
-static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
-  switch (op) {
-  case LLVM::AtomicBinOp::xchg:
-    return llvm::AtomicRMWInst::BinOp::Xchg;
-  case LLVM::AtomicBinOp::add:
-    return llvm::AtomicRMWInst::BinOp::Add;
-  case LLVM::AtomicBinOp::sub:
-    return llvm::AtomicRMWInst::BinOp::Sub;
-  case LLVM::AtomicBinOp::_and:
-    return llvm::AtomicRMWInst::BinOp::And;
-  case LLVM::AtomicBinOp::nand:
-    return llvm::AtomicRMWInst::BinOp::Nand;
-  case LLVM::AtomicBinOp::_or:
-    return llvm::AtomicRMWInst::BinOp::Or;
-  case LLVM::AtomicBinOp::_xor:
-    return llvm::AtomicRMWInst::BinOp::Xor;
-  case LLVM::AtomicBinOp::max:
-    return llvm::AtomicRMWInst::BinOp::Max;
-  case LLVM::AtomicBinOp::min:
-    return llvm::AtomicRMWInst::BinOp::Min;
-  case LLVM::AtomicBinOp::umax:
-    return llvm::AtomicRMWInst::BinOp::UMax;
-  case LLVM::AtomicBinOp::umin:
-    return llvm::AtomicRMWInst::BinOp::UMin;
-  case LLVM::AtomicBinOp::fadd:
-    return llvm::AtomicRMWInst::BinOp::FAdd;
-  case LLVM::AtomicBinOp::fsub:
-    return llvm::AtomicRMWInst::BinOp::FSub;
-  }
-  llvm_unreachable("incorrect atomic binary operator");
-}
-
-static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
-  switch (ordering) {
-  case LLVM::AtomicOrdering::not_atomic:
-    return llvm::AtomicOrdering::NotAtomic;
-  case LLVM::AtomicOrdering::unordered:
-    return llvm::AtomicOrdering::Unordered;
-  case LLVM::AtomicOrdering::monotonic:
-    return llvm::AtomicOrdering::Monotonic;
-  case LLVM::AtomicOrdering::acquire:
-    return llvm::AtomicOrdering::Acquire;
-  case LLVM::AtomicOrdering::release:
-    return llvm::AtomicOrdering::Release;
-  case LLVM::AtomicOrdering::acq_rel:
-    return llvm::AtomicOrdering::AcquireRelease;
-  case LLVM::AtomicOrdering::seq_cst:
-    return llvm::AtomicOrdering::SequentiallyConsistent;
-  }
-  llvm_unreachable("incorrect atomic ordering");
-}
-
 ModuleTranslation::ModuleTranslation(Operation *module,
                                      std::unique_ptr<llvm::Module> llvmModule)
     : mlirModule(module), llvmModule(std::move(llvmModule)),
       debugTranslation(
           std::make_unique<DebugTranslation>(module, *this->llvmModule)),
       ompDialect(module->getContext()->getLoadedDialect("omp")),
-      typeTranslator(this->llvmModule->getContext()) {
+      typeTranslator(this->llvmModule->getContext()),
+      iface(module->getContext()) {
   assert(satisfiesLLVMModule(mlirModule) &&
          "mlirModule should honor LLVM's module semantics.");
 }
@@ -658,221 +543,17 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
       });
 }
 
-static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
-  using llvmFMF = llvm::FastMathFlags;
-  using FuncT = void (llvmFMF::*)(bool);
-  const std::pair<FastmathFlags, FuncT> handlers[] = {
-      // clang-format off
-      {FastmathFlags::nnan,     &llvmFMF::setNoNaNs},
-      {FastmathFlags::ninf,     &llvmFMF::setNoInfs},
-      {FastmathFlags::nsz,      &llvmFMF::setNoSignedZeros},
-      {FastmathFlags::arcp,     &llvmFMF::setAllowReciprocal},
-      {FastmathFlags::contract, &llvmFMF::setAllowContract},
-      {FastmathFlags::afn,      &llvmFMF::setApproxFunc},
-      {FastmathFlags::reassoc,  &llvmFMF::setAllowReassoc},
-      {FastmathFlags::fast,     &llvmFMF::setFast},
-      // clang-format on
-  };
-  llvm::FastMathFlags ret;
-  auto fmf = op.fastmathFlags();
-  for (auto it : handlers)
-    if (bitEnumContains(fmf, it.first))
-      (ret.*(it.second))(true);
-  return ret;
-}
-
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.  LLVM IR Builder does not have a generic interface so
 /// this has to be a long chain of `if`s calling different functions with a
 /// different number of arguments.
 LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
                                                   llvm::IRBuilder<> &builder) {
-  auto extractPosition = [](ArrayAttr attr) {
-    SmallVector<unsigned, 4> position;
-    position.reserve(attr.size());
-    for (Attribute v : attr)
-      position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
-    return position;
-  };
-
-  llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
-  if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
-    builder.setFastMathFlags(getFastmathFlags(fmf));
-
-#include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
-
-  // Emit function calls.  If the "callee" attribute is present, this is a
-  // direct function call and we also need to look up the remapped function
-  // itself.  Otherwise, this is an indirect call and the callee is the first
-  // operand, look it up as a normal value.  Return the llvm::Value representing
-  // the function result, which may be of llvm::VoidTy type.
-  auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
-    auto operands = lookupValues(op.getOperands());
-    ArrayRef<llvm::Value *> operandsRef(operands);
-    if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
-      return builder.CreateCall(lookupFunction(attr.getValue()), operandsRef);
-    auto *calleePtrType =
-        cast<llvm::PointerType>(operandsRef.front()->getType());
-    auto *calleeType =
-        cast<llvm::FunctionType>(calleePtrType->getElementType());
-    return builder.CreateCall(calleeType, operandsRef.front(),
-                              operandsRef.drop_front());
-  };
-
-  // Emit calls.  If the called function has a result, remap the corresponding
-  // value.  Note that LLVM IR dialect CallOp has either 0 or 1 result.
-  if (isa<LLVM::CallOp>(opInst)) {
-    llvm::Value *result = convertCall(opInst);
-    if (opInst.getNumResults() != 0) {
-      mapValue(opInst.getResult(0), result);
-      return success();
-    }
-    // Check that LLVM call returns void for 0-result functions.
-    return success(result->getType()->isVoidTy());
-  }
 
-  if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
-    // TODO: refactor function type creation which usually occurs in std-LLVM
-    // conversion.
-    SmallVector<Type, 8> operandTypes;
-    operandTypes.reserve(inlineAsmOp.operands().size());
-    for (auto t : inlineAsmOp.operands().getTypes())
-      operandTypes.push_back(t);
-
-    Type resultType;
-    if (inlineAsmOp.getNumResults() == 0) {
-      resultType = LLVM::LLVMVoidType::get(mlirModule->getContext());
-    } else {
-      assert(inlineAsmOp.getNumResults() == 1);
-      resultType = inlineAsmOp.getResultTypes()[0];
-    }
-    auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
-    llvm::InlineAsm *inlineAsmInst =
-        inlineAsmOp.asm_dialect().hasValue()
-            ? llvm::InlineAsm::get(
-                  static_cast<llvm::FunctionType *>(convertType(ft)),
-                  inlineAsmOp.asm_string(), inlineAsmOp.constraints(),
-                  inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack(),
-                  convertAsmDialectToLLVM(*inlineAsmOp.asm_dialect()))
-            : llvm::InlineAsm::get(
-                  static_cast<llvm::FunctionType *>(convertType(ft)),
-                  inlineAsmOp.asm_string(), inlineAsmOp.constraints(),
-                  inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack());
-    llvm::Value *result =
-        builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands()));
-    if (opInst.getNumResults() != 0)
-      mapValue(opInst.getResult(0), result);
+  // TODO(zinenko): this should be the "main" conversion here, remove the
+  // dispatch below.
+  if (succeeded(iface.convertOperation(&opInst, builder, *this)))
     return success();
-  }
-
-  if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
-    auto operands = lookupValues(opInst.getOperands());
-    ArrayRef<llvm::Value *> operandsRef(operands);
-    if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
-      builder.CreateInvoke(lookupFunction(attr.getValue()),
-                           lookupBlock(invOp.getSuccessor(0)),
-                           lookupBlock(invOp.getSuccessor(1)), operandsRef);
-    } else {
-      auto *calleePtrType =
-          cast<llvm::PointerType>(operandsRef.front()->getType());
-      auto *calleeType =
-          cast<llvm::FunctionType>(calleePtrType->getElementType());
-      builder.CreateInvoke(
-          calleeType, operandsRef.front(), lookupBlock(invOp.getSuccessor(0)),
-          lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front());
-    }
-    return success();
-  }
-
-  if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
-    llvm::Type *ty = convertType(lpOp.getType());
-    llvm::LandingPadInst *lpi =
-        builder.CreateLandingPad(ty, lpOp.getNumOperands());
-
-    // Add clauses
-    for (llvm::Value *operand : lookupValues(lpOp.getOperands())) {
-      // All operands should be constant - checked by verifier
-      if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
-        lpi->addClause(constOperand);
-    }
-    mapValue(lpOp.getResult(), lpi);
-    return success();
-  }
-
-  // Emit branches.  We need to look up the remapped blocks and ignore the block
-  // arguments that were transformed into PHI nodes.
-  if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
-    llvm::BranchInst *branch =
-        builder.CreateBr(lookupBlock(brOp.getSuccessor()));
-    mapBranch(&opInst, branch);
-    return success();
-  }
-  if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
-    auto weights = condbrOp.branch_weights();
-    llvm::MDNode *branchWeights = nullptr;
-    if (weights) {
-      // Map weight attributes to LLVM metadata.
-      auto trueWeight =
-          weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
-      auto falseWeight =
-          weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
-      branchWeights =
-          llvm::MDBuilder(llvmModule->getContext())
-              .createBranchWeights(static_cast<uint32_t>(trueWeight),
-                                   static_cast<uint32_t>(falseWeight));
-    }
-    llvm::BranchInst *branch = builder.CreateCondBr(
-        lookupValue(condbrOp.getOperand(0)),
-        lookupBlock(condbrOp.getSuccessor(0)),
-        lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
-    mapBranch(&opInst, branch);
-    return success();
-  }
-  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
-    llvm::MDNode *branchWeights = nullptr;
-    if (auto weights = switchOp.branch_weights()) {
-      llvm::SmallVector<uint32_t> weightValues;
-      weightValues.reserve(weights->size());
-      for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
-        weightValues.push_back(weight.getLimitedValue());
-      branchWeights = llvm::MDBuilder(llvmModule->getContext())
-                          .createBranchWeights(weightValues);
-    }
-
-    llvm::SwitchInst *switchInst =
-        builder.CreateSwitch(lookupValue(switchOp.value()),
-                             lookupBlock(switchOp.defaultDestination()),
-                             switchOp.caseDestinations().size(), branchWeights);
-
-    auto *ty =
-        llvm::cast<llvm::IntegerType>(convertType(switchOp.value().getType()));
-    for (auto i :
-         llvm::zip(switchOp.case_values()->cast<DenseIntElementsAttr>(),
-                   switchOp.caseDestinations()))
-      switchInst->addCase(
-          llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
-          lookupBlock(std::get<1>(i)));
-
-    mapBranch(&opInst, switchInst);
-    return success();
-  }
-
-  // Emit addressof.  We need to look up the global value referenced by the
-  // operation and store it in the MLIR-to-LLVM value mapping.  This does not
-  // emit any LLVM instruction.
-  if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
-    LLVM::GlobalOp global = addressOfOp.getGlobal();
-    LLVM::LLVMFuncOp function = addressOfOp.getFunction();
-
-    // The verifier should not have allowed this.
-    assert((global || function) &&
-           "referencing an undefined global or function");
-
-    mapValue(addressOfOp.getResult(), global
-                                          ? globalsMapping.lookup(global)
-                                          : lookupFunction(function.getName()));
-    return success();
-  }
 
   if (ompDialect && opInst.getDialect() == ompDialect)
     return convertOmpOperation(opInst, builder);
index e500db3..4389039 100644 (file)
@@ -17,6 +17,8 @@
 #include "mlir/ExecutionEngine/JitRunner.h"
 #include "mlir/ExecutionEngine/OptUtils.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Target/LLVMIR.h"
+
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/TargetSelect.h"
 
@@ -29,5 +31,7 @@ int main(int argc, char **argv) {
 
   mlir::DialectRegistry registry;
   registry.insert<mlir::LLVM::LLVMDialect, mlir::omp::OpenMPDialect>();
+  mlir::registerLLVMDialectTranslation(registry);
+
   return mlir::JitRunnerMain(argc, argv, registry);
 }
index dd605b9..1e6a80b 100644 (file)
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
 #include "mlir/Target/NVVMIR.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
+
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/TargetSelect.h"
 
@@ -154,6 +156,7 @@ int main(int argc, char **argv) {
   registry.insert<mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
                   mlir::async::AsyncDialect, mlir::gpu::GPUDialect,
                   mlir::StandardOpsDialect>();
+  mlir::registerLLVMDialectTranslation(registry);
 
   return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
 }
index 4df7fed..8237545 100644 (file)
@@ -30,6 +30,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/FileUtilities.h"
+#include "mlir/Target/LLVMIR.h"
 #include "mlir/Target/ROCDLIR.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
@@ -340,6 +341,7 @@ int main(int argc, char **argv) {
   mlir::DialectRegistry registry;
   registry.insert<mlir::LLVM::LLVMDialect, mlir::gpu::GPUDialect,
                   mlir::ROCDL::ROCDLDialect, mlir::StandardOpsDialect>();
+  mlir::registerLLVMDialectTranslation(registry);
 
   return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
 }
index 3cfd661..7955dca 100644 (file)
@@ -96,6 +96,7 @@ int main(int argc, char **argv) {
   mlir::DialectRegistry registry;
   registry.insert<mlir::LLVM::LLVMDialect, mlir::gpu::GPUDialect,
                   mlir::spirv::SPIRVDialect, mlir::StandardOpsDialect>();
+  mlir::registerLLVMDialectTranslation(registry);
 
   return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
 }
index e407172..651e59c 100644 (file)
@@ -27,6 +27,7 @@
 #include "mlir/ExecutionEngine/OptUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/TargetSelect.h"
 
@@ -67,6 +68,7 @@ int main(int argc, char **argv) {
   mlir::DialectRegistry registry;
   registry.insert<mlir::LLVM::LLVMDialect, mlir::gpu::GPUDialect,
                   mlir::spirv::SPIRVDialect, mlir::StandardOpsDialect>();
+  mlir::registerLLVMDialectTranslation(registry);
 
   return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
 }
index 9b7450e..230c9b6 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/InitAllDialects.h"
 #include "mlir/Parser.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -51,8 +52,10 @@ TEST(MLIRExecutionEngine, AddInteger) {
     return %res : i32
   }
   )mlir";
-  MLIRContext context;
-  registerAllDialects(context);
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
   OwningModuleRef module = parseSourceString(moduleStr, &context);
   ASSERT_TRUE(!!module);
   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
@@ -74,8 +77,10 @@ TEST(MLIRExecutionEngine, SubtractFloat) {
     return %res : f32
   }
   )mlir";
-  MLIRContext context;
-  registerAllDialects(context);
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
   OwningModuleRef module = parseSourceString(moduleStr, &context);
   ASSERT_TRUE(!!module);
   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
@@ -102,8 +107,10 @@ TEST(NativeMemRefJit, ZeroRankMemref) {
     return
   }
   )mlir";
-  MLIRContext context;
-  registerAllDialects(context);
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
   auto module = parseSourceString(moduleStr, &context);
   ASSERT_TRUE(!!module);
   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
@@ -135,8 +142,10 @@ TEST(NativeMemRefJit, RankOneMemref) {
     return
   }
   )mlir";
-  MLIRContext context;
-  registerAllDialects(context);
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
   auto module = parseSourceString(moduleStr, &context);
   ASSERT_TRUE(!!module);
   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
@@ -187,8 +196,10 @@ TEST(NativeMemRefJit, BasicMemref) {
     return
   }
   )mlir";
-  MLIRContext context;
-  registerAllDialects(context);
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
   OwningModuleRef module = parseSourceString(moduleStr, &context);
   ASSERT_TRUE(!!module);
   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
@@ -227,8 +238,10 @@ TEST(NativeMemRefJit, JITCallback) {
     return
   }
   )mlir";
-  MLIRContext context;
-  registerAllDialects(context);
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
   auto module = parseSourceString(moduleStr, &context);
   ASSERT_TRUE(!!module);
   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));