From: Alex Zinenko Date: Thu, 11 Feb 2021 14:01:33 +0000 (+0100) Subject: [mlir] Introduce dialect interfaces for translation to LLVM IR X-Git-Tag: llvmorg-14-init~15266 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b77bac0572340d4e5b6095a82e0fcbcc01870645;p=platform%2Fupstream%2Fllvm.git [mlir] Introduce dialect interfaces for translation to LLVM IR The existing approach to translation to the LLVM IR relies on a single translation supporting the base LLVM dialect, extensible through inheritance to support intrinsic-based dialects also derived from LLVM IR such as NVVM and AVX512. This approach does not scale well as it requires additional translations to be created for each new intrinsic-based dialect and does not allow them to mix in the same module, contrary to the rest of the MLIR infrastructure. Furthermore, OpenMP translation ingrained itself into the main translation mechanism. Start refactoring the translation to LLVM IR to operate using dialect interfaces. Each dialect that contains ops translatable to LLVM IR can implement the interface for translating them, and the top-level translation driver can operate on interfaces without knowing about specific dialects. Furthermore, the delayed dialect registration mechanism allows one to avoid a dependency on LLVM IR in the dialect that is translated to it by implementing the translation as a separate library and only registering it at the client level. This change introduces the new mechanism and factors out the translation of the "main" LLVM dialect. The remaining dialects will follow suit. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D96503 --- diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp index b800c0a..d717f69 100644 --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -189,6 +189,9 @@ int dumpAST() { } int dumpLLVMIR(mlir::ModuleOp module) { + // Register the translation to LLVM IR with the MLIR context. + mlir::registerLLVMDialectTranslation(*module->getContext()); + // Convert the module to LLVM IR in a new LLVM IR context. llvm::LLVMContext llvmContext; auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); @@ -219,6 +222,10 @@ int runJit(mlir::ModuleOp module) { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + // Register the translation from MLIR to LLVM IR, which must happen before we + // can JIT-compile. + mlir::registerLLVMDialectTranslation(*module->getContext()); + // An optimization pipeline to use within the execution engine. auto optPipeline = mlir::makeOptimizingTransformer( /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp index 4fdb06d..3898e28 100644 --- a/mlir/examples/toy/Ch7/toyc.cpp +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -190,6 +190,9 @@ int dumpAST() { } int dumpLLVMIR(mlir::ModuleOp module) { + // Register the translation to LLVM IR with the MLIR context. + mlir::registerLLVMDialectTranslation(*module->getContext()); + // Convert the module to LLVM IR in a new LLVM IR context. llvm::LLVMContext llvmContext; auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); @@ -220,6 +223,10 @@ int runJit(mlir::ModuleOp module) { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + // Register the translation from MLIR to LLVM IR, which must happen before we + // can JIT-compile. + mlir::registerLLVMDialectTranslation(*module->getContext()); + // An optimization pipeline to use within the execution engine. auto optPipeline = mlir::makeOptimizingTransformer( /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, diff --git a/mlir/include/mlir/Target/LLVMIR.h b/mlir/include/mlir/Target/LLVMIR.h index ffd1a4c..2050c63 100644 --- a/mlir/include/mlir/Target/LLVMIR.h +++ b/mlir/include/mlir/Target/LLVMIR.h @@ -25,6 +25,7 @@ class Module; namespace mlir { +class DialectRegistry; class OwningModuleRef; class MLIRContext; class ModuleOp; @@ -45,6 +46,15 @@ OwningModuleRef translateLLVMIRToModule(std::unique_ptr llvmModule, MLIRContext *context); +/// Register the LLVM dialect and the translation from it to the LLVM IR in the +/// given registry; +void registerLLVMDialectTranslation(DialectRegistry ®istry); + +/// Register the LLVM dialect and the translation from it in the registry +/// associated with the given context. This checks if the interface is already +/// registered and avoids double registation. +void registerLLVMDialectTranslation(MLIRContext &context); + } // namespace mlir #endif // MLIR_TARGET_LLVMIR_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h new file mode 100644 index 0000000..8b72ced --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h @@ -0,0 +1,37 @@ +//===- LLVMToLLVMIRTranslation.h - LLVM Dialect to LLVM IR-------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect interface for translating the LLVM dialect +// to LLVM IR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H + +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" + +namespace mlir { + +/// Implementation of the dialect interface that converts operations beloning to +/// the LLVM dialect to LLVM IR. +class LLVMDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. + LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final; +}; + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h new file mode 100644 index 0000000..0063bea --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h @@ -0,0 +1,68 @@ +//===- LLVMTranslationInterface.h - Translation to LLVM iface ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines dialect interfaces for translation to LLVM IR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H +#define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H + +#include "mlir/IR/DialectInterface.h" +#include "mlir/Support/LogicalResult.h" + +namespace llvm { +class IRBuilderBase; +} + +namespace mlir { +namespace LLVM { +class ModuleTranslation; +} // namespace LLVM + +/// Base class for dialect interfaces providing translation to LLVM IR. +/// Dialects that can be translated should provide an implementation of this +/// interface for the supported operations. The interface may be implemented in +/// a separate library to avoid the "main" dialect library depending on LLVM IR. +/// The interface can be attached using the delayed registration mechanism +/// available in DialectRegistry. +class LLVMTranslationDialectInterface + : public DialectInterface::Base { +public: + LLVMTranslationDialectInterface(Dialect *dialect) : Base(dialect) {} + + /// Hook for derived dialect interface to provide translation of the + /// operations to LLVM IR. + virtual LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const { + return failure(); + } +}; + +/// Interface collection for translation to LLVM IR, dispatches to a concrete +/// interface implementation based on the dialect to which the given op belongs. +class LLVMTranslationInterface + : public DialectInterfaceCollection { +public: + using Base::Base; + + /// Translates the given operation to LLVM IR using the interface implemented + /// by the op's dialect. + virtual LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const { + if (const LLVMTranslationDialectInterface *iface = getInterfaceFor(op)) + return iface->convertOperation(op, builder, moduleTranslation); + return failure(); + } +}; + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index ebe9a7c..b15fcc3 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -19,6 +19,7 @@ #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Value.h" +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/TypeTranslation.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" @@ -134,6 +135,31 @@ public: return branchMapping.lookup(op); } + /// Converts the type from MLIR LLVM dialect to LLVM. + llvm::Type *convertType(Type type); + + /// Looks up remapped a list of remapped values. + SmallVector lookupValues(ValueRange values); + + /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. + /// This currently supports integer, floating point, splat and dense element + /// attributes and combinations thereof. In case of error, report it to `loc` + /// and return nullptr. + llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, + Location loc); + + /// Returns the MLIR context of the module being translated. + MLIRContext &getContext() { return *mlirModule->getContext(); } + + /// Returns the LLVM context in which the IR is being constructed. + llvm::LLVMContext &getLLVMContext() { return llvmModule->getContext(); } + + /// Finds an LLVM IR global value that corresponds to the given MLIR operation + /// defining a global value. + llvm::GlobalValue *lookupGlobal(Operation *op) { + return globalsMapping.lookup(op); + } + protected: /// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an /// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an @@ -158,16 +184,10 @@ protected: virtual LogicalResult convertOmpWsLoop(Operation &opInst, llvm::IRBuilder<> &builder); - /// Converts the type from MLIR LLVM dialect to LLVM. - llvm::Type *convertType(Type type); - static std::unique_ptr prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, StringRef name); - /// A helper to look up remapped operands in the value remapping table. - SmallVector lookupValues(ValueRange values); - private: /// Check whether the module contains only supported ops directly in its body. static LogicalResult checkSupportedModuleOps(Operation *m); @@ -179,9 +199,6 @@ private: LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilder<> &builder); - llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, - Location loc); - /// Original and translated module. Operation *mlirModule; std::unique_ptr llvmModule; @@ -202,7 +219,8 @@ private: /// A stateful object used to translate types. TypeToLLVMIRTranslator typeTranslator; -private: + LLVMTranslationInterface iface; + /// Mappings between original and translated values, used for lookups. llvm::StringMap functionMapping; DenseMap valueMapping; diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt index 51a0e78a..72555ac 100644 --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(SPIRV) +add_subdirectory(LLVMIR) add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation LLVMIR/DebugTranslation.cpp @@ -39,6 +40,7 @@ add_mlir_translation_library(MLIRTargetAVX512 MLIRIR MLIRLLVMAVX512 MLIRLLVMIR + MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation ) @@ -54,6 +56,7 @@ add_mlir_translation_library(MLIRTargetLLVMIR IRReader LINK_LIBS PUBLIC + MLIRLLVMToLLVMIRTranslation MLIRTargetLLVMIRModuleTranslation ) @@ -73,6 +76,7 @@ add_mlir_translation_library(MLIRTargetArmNeon MLIRIR MLIRLLVMArmNeon MLIRLLVMIR + MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation ) @@ -92,6 +96,7 @@ add_mlir_translation_library(MLIRTargetArmSVE MLIRIR MLIRLLVMArmSVE MLIRLLVMIR + MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation ) @@ -112,6 +117,7 @@ add_mlir_translation_library(MLIRTargetNVVMIR MLIRIR MLIRLLVMIR MLIRNVVMIR + MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation ) @@ -132,5 +138,6 @@ add_mlir_translation_library(MLIRTargetROCDLIR MLIRIR MLIRLLVMIR MLIRROCDLIR + MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation ) diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 0000000..0ca0f41 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 476f365..bf8e248 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/LLVMIR.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" @@ -35,6 +36,24 @@ mlir::translateModuleToLLVMIR(ModuleOp m, llvm::LLVMContext &llvmContext, return llvmModule; } +void mlir::registerLLVMDialectTranslation(DialectRegistry ®istry) { + registry.insert(); + registry.addDialectInterface(); +} + +void mlir::registerLLVMDialectTranslation(MLIRContext &context) { + auto *dialect = context.getLoadedDialect(); + if (!dialect || dialect->getRegisteredInterface< + LLVMDialectLLVMIRTranslationInterface>() == nullptr) { + DialectRegistry registry; + registry.insert(); + registry.addDialectInterface(); + context.appendDialectRegistry(registry); + } +} + namespace mlir { void registerToLLVMIRTranslation() { TranslateFromMLIRRegistration registration( @@ -50,7 +69,8 @@ void registerToLLVMIRTranslation() { return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 668d9d9..7aee913 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" @@ -68,6 +69,11 @@ protected: std::unique_ptr mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext, StringRef name) { + // Register the translation to LLVM IR if nobody else did before. This may + // happen if this translation is called inside a pass pipeline that converts + // GPU dialects to binary blobs without translating the rest of the code. + registerLLVMDialectTranslation(*m->getContext()); + auto llvmModule = LLVM::ModuleTranslation::translateModule( m, llvmContext, name); if (!llvmModule) @@ -111,7 +117,8 @@ void registerToNVVMIRTranslation() { return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp index c415787..7ebbd3f 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" @@ -77,11 +78,16 @@ protected: std::unique_ptr mlir::translateModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext, StringRef name) { - // lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics) + // Register the translation to LLVM IR if nobody else did before. This may + // happen if this translation is called inside a pass pipeline that converts + // GPU dialects to binary blobs without translating the rest of the code. + registerLLVMDialectTranslation(*m->getContext()); + + // Lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics). auto llvmModule = LLVM::ModuleTranslation::translateModule( m, llvmContext, name); - // foreach GPU kernel + // Foreach GPU kernel: // 1. Insert AMDGPU_KERNEL calling convention. // 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute. for (auto func : @@ -114,7 +120,8 @@ void registerToROCDLIRTranslation() { return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt new file mode 100644 index 0000000..39d31dc --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt new file mode 100644 index 0000000..2da7e95 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_translation_library(MLIRLLVMToLLVMIRTranslation + LLVMToLLVMIRTranslation.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + MLIRSupport + MLIRTargetLLVMIRModuleTranslation + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp new file mode 100644 index 0000000..25d5294 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -0,0 +1,405 @@ +//===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR LLVM dialect and LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Operator.h" + +using namespace mlir; +using namespace mlir::LLVM; + +#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc" + +/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate. +static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) { + switch (p) { + case LLVM::ICmpPredicate::eq: + return llvm::CmpInst::Predicate::ICMP_EQ; + case LLVM::ICmpPredicate::ne: + return llvm::CmpInst::Predicate::ICMP_NE; + case LLVM::ICmpPredicate::slt: + return llvm::CmpInst::Predicate::ICMP_SLT; + case LLVM::ICmpPredicate::sle: + return llvm::CmpInst::Predicate::ICMP_SLE; + case LLVM::ICmpPredicate::sgt: + return llvm::CmpInst::Predicate::ICMP_SGT; + case LLVM::ICmpPredicate::sge: + return llvm::CmpInst::Predicate::ICMP_SGE; + case LLVM::ICmpPredicate::ult: + return llvm::CmpInst::Predicate::ICMP_ULT; + case LLVM::ICmpPredicate::ule: + return llvm::CmpInst::Predicate::ICMP_ULE; + case LLVM::ICmpPredicate::ugt: + return llvm::CmpInst::Predicate::ICMP_UGT; + case LLVM::ICmpPredicate::uge: + return llvm::CmpInst::Predicate::ICMP_UGE; + } + llvm_unreachable("incorrect comparison predicate"); +} + +static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) { + switch (p) { + case LLVM::FCmpPredicate::_false: + return llvm::CmpInst::Predicate::FCMP_FALSE; + case LLVM::FCmpPredicate::oeq: + return llvm::CmpInst::Predicate::FCMP_OEQ; + case LLVM::FCmpPredicate::ogt: + return llvm::CmpInst::Predicate::FCMP_OGT; + case LLVM::FCmpPredicate::oge: + return llvm::CmpInst::Predicate::FCMP_OGE; + case LLVM::FCmpPredicate::olt: + return llvm::CmpInst::Predicate::FCMP_OLT; + case LLVM::FCmpPredicate::ole: + return llvm::CmpInst::Predicate::FCMP_OLE; + case LLVM::FCmpPredicate::one: + return llvm::CmpInst::Predicate::FCMP_ONE; + case LLVM::FCmpPredicate::ord: + return llvm::CmpInst::Predicate::FCMP_ORD; + case LLVM::FCmpPredicate::ueq: + return llvm::CmpInst::Predicate::FCMP_UEQ; + case LLVM::FCmpPredicate::ugt: + return llvm::CmpInst::Predicate::FCMP_UGT; + case LLVM::FCmpPredicate::uge: + return llvm::CmpInst::Predicate::FCMP_UGE; + case LLVM::FCmpPredicate::ult: + return llvm::CmpInst::Predicate::FCMP_ULT; + case LLVM::FCmpPredicate::ule: + return llvm::CmpInst::Predicate::FCMP_ULE; + case LLVM::FCmpPredicate::une: + return llvm::CmpInst::Predicate::FCMP_UNE; + case LLVM::FCmpPredicate::uno: + return llvm::CmpInst::Predicate::FCMP_UNO; + case LLVM::FCmpPredicate::_true: + return llvm::CmpInst::Predicate::FCMP_TRUE; + } + llvm_unreachable("incorrect comparison predicate"); +} + +static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) { + switch (op) { + case LLVM::AtomicBinOp::xchg: + return llvm::AtomicRMWInst::BinOp::Xchg; + case LLVM::AtomicBinOp::add: + return llvm::AtomicRMWInst::BinOp::Add; + case LLVM::AtomicBinOp::sub: + return llvm::AtomicRMWInst::BinOp::Sub; + case LLVM::AtomicBinOp::_and: + return llvm::AtomicRMWInst::BinOp::And; + case LLVM::AtomicBinOp::nand: + return llvm::AtomicRMWInst::BinOp::Nand; + case LLVM::AtomicBinOp::_or: + return llvm::AtomicRMWInst::BinOp::Or; + case LLVM::AtomicBinOp::_xor: + return llvm::AtomicRMWInst::BinOp::Xor; + case LLVM::AtomicBinOp::max: + return llvm::AtomicRMWInst::BinOp::Max; + case LLVM::AtomicBinOp::min: + return llvm::AtomicRMWInst::BinOp::Min; + case LLVM::AtomicBinOp::umax: + return llvm::AtomicRMWInst::BinOp::UMax; + case LLVM::AtomicBinOp::umin: + return llvm::AtomicRMWInst::BinOp::UMin; + case LLVM::AtomicBinOp::fadd: + return llvm::AtomicRMWInst::BinOp::FAdd; + case LLVM::AtomicBinOp::fsub: + return llvm::AtomicRMWInst::BinOp::FSub; + } + llvm_unreachable("incorrect atomic binary operator"); +} + +static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) { + switch (ordering) { + case LLVM::AtomicOrdering::not_atomic: + return llvm::AtomicOrdering::NotAtomic; + case LLVM::AtomicOrdering::unordered: + return llvm::AtomicOrdering::Unordered; + case LLVM::AtomicOrdering::monotonic: + return llvm::AtomicOrdering::Monotonic; + case LLVM::AtomicOrdering::acquire: + return llvm::AtomicOrdering::Acquire; + case LLVM::AtomicOrdering::release: + return llvm::AtomicOrdering::Release; + case LLVM::AtomicOrdering::acq_rel: + return llvm::AtomicOrdering::AcquireRelease; + case LLVM::AtomicOrdering::seq_cst: + return llvm::AtomicOrdering::SequentiallyConsistent; + } + llvm_unreachable("incorrect atomic ordering"); +} + +static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) { + using llvmFMF = llvm::FastMathFlags; + using FuncT = void (llvmFMF::*)(bool); + const std::pair handlers[] = { + // clang-format off + {FastmathFlags::nnan, &llvmFMF::setNoNaNs}, + {FastmathFlags::ninf, &llvmFMF::setNoInfs}, + {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros}, + {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal}, + {FastmathFlags::contract, &llvmFMF::setAllowContract}, + {FastmathFlags::afn, &llvmFMF::setApproxFunc}, + {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc}, + {FastmathFlags::fast, &llvmFMF::setFast}, + // clang-format on + }; + llvm::FastMathFlags ret; + auto fmf = op.fastmathFlags(); + for (auto it : handlers) + if (bitEnumContains(fmf, it.first)) + (ret.*(it.second))(true); + return ret; +} + +namespace { +/// Dispatcher functional object targeting different overloads of +/// ModuleTranslation::mapValue. +// TODO: this is only necessary for compatibility with the code emitted from +// ODS, remove when ODS is updated (after all dialects have migrated to the new +// translation mechanism). +struct MapValueDispatcher { + explicit MapValueDispatcher(ModuleTranslation &mt) : moduleTranslation(mt) {} + + llvm::Value *&operator()(mlir::Value v) { + return moduleTranslation.mapValue(v); + } + + void operator()(mlir::Value m, llvm::Value *l) { + moduleTranslation.mapValue(m, l); + } + + LLVM::ModuleTranslation &moduleTranslation; +}; +} // end namespace + +static LogicalResult +convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto extractPosition = [](ArrayAttr attr) { + SmallVector position; + position.reserve(attr.size()); + for (Attribute v : attr) + position.push_back(v.cast().getValue().getZExtValue()); + return position; + }; + + llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder); + if (auto fmf = dyn_cast(opInst)) + builder.setFastMathFlags(getFastmathFlags(fmf)); + + // TODO: these are necessary for compatibility with the code emitted from ODS, + // remove them when ODS is updated (after all dialects have migrated to the + // new translation mechanism). + MapValueDispatcher mapValue(moduleTranslation); + auto lookupValue = [&](mlir::Value v) { + return moduleTranslation.lookupValue(v); + }; + auto convertType = [&](Type ty) { return moduleTranslation.convertType(ty); }; + auto lookupValues = [&](ValueRange vs) { + return moduleTranslation.lookupValues(vs); + }; + auto getLLVMConstant = [&](llvm::Type *ty, Attribute attr, Location loc) { + return moduleTranslation.getLLVMConstant(ty, attr, loc); + }; + +#include "mlir/Dialect/LLVMIR/LLVMConversions.inc" + + // Emit function calls. If the "callee" attribute is present, this is a + // direct function call and we also need to look up the remapped function + // itself. Otherwise, this is an indirect call and the callee is the first + // operand, look it up as a normal value. Return the llvm::Value representing + // the function result, which may be of llvm::VoidTy type. + auto convertCall = [&](Operation &op) -> llvm::Value * { + auto operands = moduleTranslation.lookupValues(op.getOperands()); + ArrayRef operandsRef(operands); + if (auto attr = op.getAttrOfType("callee")) + return builder.CreateCall( + moduleTranslation.lookupFunction(attr.getValue()), operandsRef); + auto *calleePtrType = + cast(operandsRef.front()->getType()); + auto *calleeType = + cast(calleePtrType->getElementType()); + return builder.CreateCall(calleeType, operandsRef.front(), + operandsRef.drop_front()); + }; + + // Emit calls. If the called function has a result, remap the corresponding + // value. Note that LLVM IR dialect CallOp has either 0 or 1 result. + if (isa(opInst)) { + llvm::Value *result = convertCall(opInst); + if (opInst.getNumResults() != 0) { + mapValue(opInst.getResult(0), result); + return success(); + } + // Check that LLVM call returns void for 0-result functions. + return success(result->getType()->isVoidTy()); + } + + if (auto inlineAsmOp = dyn_cast(opInst)) { + // TODO: refactor function type creation which usually occurs in std-LLVM + // conversion. + SmallVector operandTypes; + operandTypes.reserve(inlineAsmOp.operands().size()); + for (auto t : inlineAsmOp.operands().getTypes()) + operandTypes.push_back(t); + + Type resultType; + if (inlineAsmOp.getNumResults() == 0) { + resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext()); + } else { + assert(inlineAsmOp.getNumResults() == 1); + resultType = inlineAsmOp.getResultTypes()[0]; + } + auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes); + llvm::InlineAsm *inlineAsmInst = + inlineAsmOp.asm_dialect().hasValue() + ? llvm::InlineAsm::get( + static_cast(convertType(ft)), + inlineAsmOp.asm_string(), inlineAsmOp.constraints(), + inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack(), + convertAsmDialectToLLVM(*inlineAsmOp.asm_dialect())) + : llvm::InlineAsm::get( + static_cast(convertType(ft)), + inlineAsmOp.asm_string(), inlineAsmOp.constraints(), + inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack()); + llvm::Value *result = + builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands())); + if (opInst.getNumResults() != 0) + mapValue(opInst.getResult(0), result); + return success(); + } + + if (auto invOp = dyn_cast(opInst)) { + auto operands = lookupValues(opInst.getOperands()); + ArrayRef operandsRef(operands); + if (auto attr = opInst.getAttrOfType("callee")) { + builder.CreateInvoke(moduleTranslation.lookupFunction(attr.getValue()), + moduleTranslation.lookupBlock(invOp.getSuccessor(0)), + moduleTranslation.lookupBlock(invOp.getSuccessor(1)), + operandsRef); + } else { + auto *calleePtrType = + cast(operandsRef.front()->getType()); + auto *calleeType = + cast(calleePtrType->getElementType()); + builder.CreateInvoke(calleeType, operandsRef.front(), + moduleTranslation.lookupBlock(invOp.getSuccessor(0)), + moduleTranslation.lookupBlock(invOp.getSuccessor(1)), + operandsRef.drop_front()); + } + return success(); + } + + if (auto lpOp = dyn_cast(opInst)) { + llvm::Type *ty = convertType(lpOp.getType()); + llvm::LandingPadInst *lpi = + builder.CreateLandingPad(ty, lpOp.getNumOperands()); + + // Add clauses + for (llvm::Value *operand : lookupValues(lpOp.getOperands())) { + // All operands should be constant - checked by verifier + if (auto *constOperand = dyn_cast(operand)) + lpi->addClause(constOperand); + } + mapValue(lpOp.getResult(), lpi); + return success(); + } + + // Emit branches. We need to look up the remapped blocks and ignore the block + // arguments that were transformed into PHI nodes. + if (auto brOp = dyn_cast(opInst)) { + llvm::BranchInst *branch = + builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor())); + moduleTranslation.mapBranch(&opInst, branch); + return success(); + } + if (auto condbrOp = dyn_cast(opInst)) { + auto weights = condbrOp.branch_weights(); + llvm::MDNode *branchWeights = nullptr; + if (weights) { + // Map weight attributes to LLVM metadata. + auto trueWeight = + weights.getValue().getValue(0).cast().getInt(); + auto falseWeight = + weights.getValue().getValue(1).cast().getInt(); + branchWeights = + llvm::MDBuilder(moduleTranslation.getLLVMContext()) + .createBranchWeights(static_cast(trueWeight), + static_cast(falseWeight)); + } + llvm::BranchInst *branch = builder.CreateCondBr( + moduleTranslation.lookupValue(condbrOp.getOperand(0)), + moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)), + moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights); + moduleTranslation.mapBranch(&opInst, branch); + return success(); + } + if (auto switchOp = dyn_cast(opInst)) { + llvm::MDNode *branchWeights = nullptr; + if (auto weights = switchOp.branch_weights()) { + llvm::SmallVector weightValues; + weightValues.reserve(weights->size()); + for (llvm::APInt weight : weights->cast()) + weightValues.push_back(weight.getLimitedValue()); + branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext()) + .createBranchWeights(weightValues); + } + + llvm::SwitchInst *switchInst = builder.CreateSwitch( + moduleTranslation.lookupValue(switchOp.value()), + moduleTranslation.lookupBlock(switchOp.defaultDestination()), + switchOp.caseDestinations().size(), branchWeights); + + auto *ty = + llvm::cast(convertType(switchOp.value().getType())); + for (auto i : + llvm::zip(switchOp.case_values()->cast(), + switchOp.caseDestinations())) + switchInst->addCase( + llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()), + moduleTranslation.lookupBlock(std::get<1>(i))); + + moduleTranslation.mapBranch(&opInst, switchInst); + return success(); + } + + // Emit addressof. We need to look up the global value referenced by the + // operation and store it in the MLIR-to-LLVM value mapping. This does not + // emit any LLVM instruction. + if (auto addressOfOp = dyn_cast(opInst)) { + LLVM::GlobalOp global = addressOfOp.getGlobal(); + LLVM::LLVMFuncOp function = addressOfOp.getFunction(); + + // The verifier should not have allowed this. + assert((global || function) && + "referencing an undefined global or function"); + + mapValue(addressOfOp.getResult(), + global ? moduleTranslation.lookupGlobal(global) + : moduleTranslation.lookupFunction(function.getName())); + return success(); + } + + return failure(); +} + +LogicalResult mlir::LLVMDialectLLVMIRTranslationInterface::convertOperation( + Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const { + return convertOperationImpl(*op, builder, moduleTranslation); +} diff --git a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp index 52f1792..082504d 100644 --- a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" #include "llvm/IR/IntrinsicsX86.h" @@ -57,7 +58,8 @@ void registerAVX512ToLLVMIRTranslation() { return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp index 0bd40ef..14bbd14 100644 --- a/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" #include "llvm/IR/IntrinsicsAArch64.h" @@ -57,7 +58,8 @@ void registerArmNeonToLLVMIRTranslation() { return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp index 717583a..5ef8b48 100644 --- a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" #include "llvm/IR/IntrinsicsAArch64.h" @@ -57,7 +58,8 @@ void registerArmSVEToLLVMIRTranslation() { return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index dce3846..6728511 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/RegionGraphTraits.h" #include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/TypeTranslation.h" #include "llvm/ADT/TypeSwitch.h" @@ -183,130 +184,14 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, return nullptr; } -/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate. -static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) { - switch (p) { - case LLVM::ICmpPredicate::eq: - return llvm::CmpInst::Predicate::ICMP_EQ; - case LLVM::ICmpPredicate::ne: - return llvm::CmpInst::Predicate::ICMP_NE; - case LLVM::ICmpPredicate::slt: - return llvm::CmpInst::Predicate::ICMP_SLT; - case LLVM::ICmpPredicate::sle: - return llvm::CmpInst::Predicate::ICMP_SLE; - case LLVM::ICmpPredicate::sgt: - return llvm::CmpInst::Predicate::ICMP_SGT; - case LLVM::ICmpPredicate::sge: - return llvm::CmpInst::Predicate::ICMP_SGE; - case LLVM::ICmpPredicate::ult: - return llvm::CmpInst::Predicate::ICMP_ULT; - case LLVM::ICmpPredicate::ule: - return llvm::CmpInst::Predicate::ICMP_ULE; - case LLVM::ICmpPredicate::ugt: - return llvm::CmpInst::Predicate::ICMP_UGT; - case LLVM::ICmpPredicate::uge: - return llvm::CmpInst::Predicate::ICMP_UGE; - } - llvm_unreachable("incorrect comparison predicate"); -} - -static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) { - switch (p) { - case LLVM::FCmpPredicate::_false: - return llvm::CmpInst::Predicate::FCMP_FALSE; - case LLVM::FCmpPredicate::oeq: - return llvm::CmpInst::Predicate::FCMP_OEQ; - case LLVM::FCmpPredicate::ogt: - return llvm::CmpInst::Predicate::FCMP_OGT; - case LLVM::FCmpPredicate::oge: - return llvm::CmpInst::Predicate::FCMP_OGE; - case LLVM::FCmpPredicate::olt: - return llvm::CmpInst::Predicate::FCMP_OLT; - case LLVM::FCmpPredicate::ole: - return llvm::CmpInst::Predicate::FCMP_OLE; - case LLVM::FCmpPredicate::one: - return llvm::CmpInst::Predicate::FCMP_ONE; - case LLVM::FCmpPredicate::ord: - return llvm::CmpInst::Predicate::FCMP_ORD; - case LLVM::FCmpPredicate::ueq: - return llvm::CmpInst::Predicate::FCMP_UEQ; - case LLVM::FCmpPredicate::ugt: - return llvm::CmpInst::Predicate::FCMP_UGT; - case LLVM::FCmpPredicate::uge: - return llvm::CmpInst::Predicate::FCMP_UGE; - case LLVM::FCmpPredicate::ult: - return llvm::CmpInst::Predicate::FCMP_ULT; - case LLVM::FCmpPredicate::ule: - return llvm::CmpInst::Predicate::FCMP_ULE; - case LLVM::FCmpPredicate::une: - return llvm::CmpInst::Predicate::FCMP_UNE; - case LLVM::FCmpPredicate::uno: - return llvm::CmpInst::Predicate::FCMP_UNO; - case LLVM::FCmpPredicate::_true: - return llvm::CmpInst::Predicate::FCMP_TRUE; - } - llvm_unreachable("incorrect comparison predicate"); -} - -static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) { - switch (op) { - case LLVM::AtomicBinOp::xchg: - return llvm::AtomicRMWInst::BinOp::Xchg; - case LLVM::AtomicBinOp::add: - return llvm::AtomicRMWInst::BinOp::Add; - case LLVM::AtomicBinOp::sub: - return llvm::AtomicRMWInst::BinOp::Sub; - case LLVM::AtomicBinOp::_and: - return llvm::AtomicRMWInst::BinOp::And; - case LLVM::AtomicBinOp::nand: - return llvm::AtomicRMWInst::BinOp::Nand; - case LLVM::AtomicBinOp::_or: - return llvm::AtomicRMWInst::BinOp::Or; - case LLVM::AtomicBinOp::_xor: - return llvm::AtomicRMWInst::BinOp::Xor; - case LLVM::AtomicBinOp::max: - return llvm::AtomicRMWInst::BinOp::Max; - case LLVM::AtomicBinOp::min: - return llvm::AtomicRMWInst::BinOp::Min; - case LLVM::AtomicBinOp::umax: - return llvm::AtomicRMWInst::BinOp::UMax; - case LLVM::AtomicBinOp::umin: - return llvm::AtomicRMWInst::BinOp::UMin; - case LLVM::AtomicBinOp::fadd: - return llvm::AtomicRMWInst::BinOp::FAdd; - case LLVM::AtomicBinOp::fsub: - return llvm::AtomicRMWInst::BinOp::FSub; - } - llvm_unreachable("incorrect atomic binary operator"); -} - -static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) { - switch (ordering) { - case LLVM::AtomicOrdering::not_atomic: - return llvm::AtomicOrdering::NotAtomic; - case LLVM::AtomicOrdering::unordered: - return llvm::AtomicOrdering::Unordered; - case LLVM::AtomicOrdering::monotonic: - return llvm::AtomicOrdering::Monotonic; - case LLVM::AtomicOrdering::acquire: - return llvm::AtomicOrdering::Acquire; - case LLVM::AtomicOrdering::release: - return llvm::AtomicOrdering::Release; - case LLVM::AtomicOrdering::acq_rel: - return llvm::AtomicOrdering::AcquireRelease; - case LLVM::AtomicOrdering::seq_cst: - return llvm::AtomicOrdering::SequentiallyConsistent; - } - llvm_unreachable("incorrect atomic ordering"); -} - ModuleTranslation::ModuleTranslation(Operation *module, std::unique_ptr llvmModule) : mlirModule(module), llvmModule(std::move(llvmModule)), debugTranslation( std::make_unique(module, *this->llvmModule)), ompDialect(module->getContext()->getLoadedDialect("omp")), - typeTranslator(this->llvmModule->getContext()) { + typeTranslator(this->llvmModule->getContext()), + iface(module->getContext()) { assert(satisfiesLLVMModule(mlirModule) && "mlirModule should honor LLVM's module semantics."); } @@ -658,221 +543,17 @@ ModuleTranslation::convertOmpOperation(Operation &opInst, }); } -static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) { - using llvmFMF = llvm::FastMathFlags; - using FuncT = void (llvmFMF::*)(bool); - const std::pair handlers[] = { - // clang-format off - {FastmathFlags::nnan, &llvmFMF::setNoNaNs}, - {FastmathFlags::ninf, &llvmFMF::setNoInfs}, - {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros}, - {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal}, - {FastmathFlags::contract, &llvmFMF::setAllowContract}, - {FastmathFlags::afn, &llvmFMF::setApproxFunc}, - {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc}, - {FastmathFlags::fast, &llvmFMF::setFast}, - // clang-format on - }; - llvm::FastMathFlags ret; - auto fmf = op.fastmathFlags(); - for (auto it : handlers) - if (bitEnumContains(fmf, it.first)) - (ret.*(it.second))(true); - return ret; -} - /// Given a single MLIR operation, create the corresponding LLVM IR operation /// using the `builder`. LLVM IR Builder does not have a generic interface so /// this has to be a long chain of `if`s calling different functions with a /// different number of arguments. LogicalResult ModuleTranslation::convertOperation(Operation &opInst, llvm::IRBuilder<> &builder) { - auto extractPosition = [](ArrayAttr attr) { - SmallVector position; - position.reserve(attr.size()); - for (Attribute v : attr) - position.push_back(v.cast().getValue().getZExtValue()); - return position; - }; - - llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder); - if (auto fmf = dyn_cast(opInst)) - builder.setFastMathFlags(getFastmathFlags(fmf)); - -#include "mlir/Dialect/LLVMIR/LLVMConversions.inc" - - // Emit function calls. If the "callee" attribute is present, this is a - // direct function call and we also need to look up the remapped function - // itself. Otherwise, this is an indirect call and the callee is the first - // operand, look it up as a normal value. Return the llvm::Value representing - // the function result, which may be of llvm::VoidTy type. - auto convertCall = [this, &builder](Operation &op) -> llvm::Value * { - auto operands = lookupValues(op.getOperands()); - ArrayRef operandsRef(operands); - if (auto attr = op.getAttrOfType("callee")) - return builder.CreateCall(lookupFunction(attr.getValue()), operandsRef); - auto *calleePtrType = - cast(operandsRef.front()->getType()); - auto *calleeType = - cast(calleePtrType->getElementType()); - return builder.CreateCall(calleeType, operandsRef.front(), - operandsRef.drop_front()); - }; - - // Emit calls. If the called function has a result, remap the corresponding - // value. Note that LLVM IR dialect CallOp has either 0 or 1 result. - if (isa(opInst)) { - llvm::Value *result = convertCall(opInst); - if (opInst.getNumResults() != 0) { - mapValue(opInst.getResult(0), result); - return success(); - } - // Check that LLVM call returns void for 0-result functions. - return success(result->getType()->isVoidTy()); - } - if (auto inlineAsmOp = dyn_cast(opInst)) { - // TODO: refactor function type creation which usually occurs in std-LLVM - // conversion. - SmallVector operandTypes; - operandTypes.reserve(inlineAsmOp.operands().size()); - for (auto t : inlineAsmOp.operands().getTypes()) - operandTypes.push_back(t); - - Type resultType; - if (inlineAsmOp.getNumResults() == 0) { - resultType = LLVM::LLVMVoidType::get(mlirModule->getContext()); - } else { - assert(inlineAsmOp.getNumResults() == 1); - resultType = inlineAsmOp.getResultTypes()[0]; - } - auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes); - llvm::InlineAsm *inlineAsmInst = - inlineAsmOp.asm_dialect().hasValue() - ? llvm::InlineAsm::get( - static_cast(convertType(ft)), - inlineAsmOp.asm_string(), inlineAsmOp.constraints(), - inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack(), - convertAsmDialectToLLVM(*inlineAsmOp.asm_dialect())) - : llvm::InlineAsm::get( - static_cast(convertType(ft)), - inlineAsmOp.asm_string(), inlineAsmOp.constraints(), - inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack()); - llvm::Value *result = - builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands())); - if (opInst.getNumResults() != 0) - mapValue(opInst.getResult(0), result); + // TODO(zinenko): this should be the "main" conversion here, remove the + // dispatch below. + if (succeeded(iface.convertOperation(&opInst, builder, *this))) return success(); - } - - if (auto invOp = dyn_cast(opInst)) { - auto operands = lookupValues(opInst.getOperands()); - ArrayRef operandsRef(operands); - if (auto attr = opInst.getAttrOfType("callee")) { - builder.CreateInvoke(lookupFunction(attr.getValue()), - lookupBlock(invOp.getSuccessor(0)), - lookupBlock(invOp.getSuccessor(1)), operandsRef); - } else { - auto *calleePtrType = - cast(operandsRef.front()->getType()); - auto *calleeType = - cast(calleePtrType->getElementType()); - builder.CreateInvoke( - calleeType, operandsRef.front(), lookupBlock(invOp.getSuccessor(0)), - lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front()); - } - return success(); - } - - if (auto lpOp = dyn_cast(opInst)) { - llvm::Type *ty = convertType(lpOp.getType()); - llvm::LandingPadInst *lpi = - builder.CreateLandingPad(ty, lpOp.getNumOperands()); - - // Add clauses - for (llvm::Value *operand : lookupValues(lpOp.getOperands())) { - // All operands should be constant - checked by verifier - if (auto *constOperand = dyn_cast(operand)) - lpi->addClause(constOperand); - } - mapValue(lpOp.getResult(), lpi); - return success(); - } - - // Emit branches. We need to look up the remapped blocks and ignore the block - // arguments that were transformed into PHI nodes. - if (auto brOp = dyn_cast(opInst)) { - llvm::BranchInst *branch = - builder.CreateBr(lookupBlock(brOp.getSuccessor())); - mapBranch(&opInst, branch); - return success(); - } - if (auto condbrOp = dyn_cast(opInst)) { - auto weights = condbrOp.branch_weights(); - llvm::MDNode *branchWeights = nullptr; - if (weights) { - // Map weight attributes to LLVM metadata. - auto trueWeight = - weights.getValue().getValue(0).cast().getInt(); - auto falseWeight = - weights.getValue().getValue(1).cast().getInt(); - branchWeights = - llvm::MDBuilder(llvmModule->getContext()) - .createBranchWeights(static_cast(trueWeight), - static_cast(falseWeight)); - } - llvm::BranchInst *branch = builder.CreateCondBr( - lookupValue(condbrOp.getOperand(0)), - lookupBlock(condbrOp.getSuccessor(0)), - lookupBlock(condbrOp.getSuccessor(1)), branchWeights); - mapBranch(&opInst, branch); - return success(); - } - if (auto switchOp = dyn_cast(opInst)) { - llvm::MDNode *branchWeights = nullptr; - if (auto weights = switchOp.branch_weights()) { - llvm::SmallVector weightValues; - weightValues.reserve(weights->size()); - for (llvm::APInt weight : weights->cast()) - weightValues.push_back(weight.getLimitedValue()); - branchWeights = llvm::MDBuilder(llvmModule->getContext()) - .createBranchWeights(weightValues); - } - - llvm::SwitchInst *switchInst = - builder.CreateSwitch(lookupValue(switchOp.value()), - lookupBlock(switchOp.defaultDestination()), - switchOp.caseDestinations().size(), branchWeights); - - auto *ty = - llvm::cast(convertType(switchOp.value().getType())); - for (auto i : - llvm::zip(switchOp.case_values()->cast(), - switchOp.caseDestinations())) - switchInst->addCase( - llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()), - lookupBlock(std::get<1>(i))); - - mapBranch(&opInst, switchInst); - return success(); - } - - // Emit addressof. We need to look up the global value referenced by the - // operation and store it in the MLIR-to-LLVM value mapping. This does not - // emit any LLVM instruction. - if (auto addressOfOp = dyn_cast(opInst)) { - LLVM::GlobalOp global = addressOfOp.getGlobal(); - LLVM::LLVMFuncOp function = addressOfOp.getFunction(); - - // The verifier should not have allowed this. - assert((global || function) && - "referencing an undefined global or function"); - - mapValue(addressOfOp.getResult(), global - ? globalsMapping.lookup(global) - : lookupFunction(function.getName())); - return success(); - } if (ompDialect && opInst.getDialect() == ompDialect) return convertOmpOperation(opInst, builder); diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp index e500db3..4389039 100644 --- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp @@ -17,6 +17,8 @@ #include "mlir/ExecutionEngine/JitRunner.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/Dialect.h" +#include "mlir/Target/LLVMIR.h" + #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" @@ -29,5 +31,7 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; registry.insert(); + mlir::registerLLVMDialectTranslation(registry); + return mlir::JitRunnerMain(argc, argv, registry); } diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index dd605b9..1e6a80b 100644 --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -31,9 +31,11 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/NVVMIR.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" + #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" @@ -154,6 +156,7 @@ int main(int argc, char **argv) { registry.insert(); + mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); } diff --git a/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp b/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp index 4df7fed..8237545 100644 --- a/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp +++ b/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp @@ -30,6 +30,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/FileUtilities.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/ROCDLIR.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" @@ -340,6 +341,7 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; registry.insert(); + mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); } diff --git a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp index 3cfd661..7955dca 100644 --- a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp +++ b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp @@ -96,6 +96,7 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; registry.insert(); + mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); } diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp index e407172..651e59c 100644 --- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp +++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -27,6 +27,7 @@ #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" @@ -67,6 +68,7 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; registry.insert(); + mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); } diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp index 9b7450e..230c9b6 100644 --- a/mlir/unittests/ExecutionEngine/Invoke.cpp +++ b/mlir/unittests/ExecutionEngine/Invoke.cpp @@ -19,6 +19,7 @@ #include "mlir/InitAllDialects.h" #include "mlir/Parser.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" @@ -51,8 +52,10 @@ TEST(MLIRExecutionEngine, AddInteger) { return %res : i32 } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); OwningModuleRef module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); @@ -74,8 +77,10 @@ TEST(MLIRExecutionEngine, SubtractFloat) { return %res : f32 } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); OwningModuleRef module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); @@ -102,8 +107,10 @@ TEST(NativeMemRefJit, ZeroRankMemref) { return } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); auto module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); @@ -135,8 +142,10 @@ TEST(NativeMemRefJit, RankOneMemref) { return } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); auto module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); @@ -187,8 +196,10 @@ TEST(NativeMemRefJit, BasicMemref) { return } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); OwningModuleRef module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); @@ -227,8 +238,10 @@ TEST(NativeMemRefJit, JITCallback) { return } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); auto module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));