}
int dumpLLVMIR(mlir::ModuleOp module) {
+ // Register the translation to LLVM IR with the MLIR context.
+ mlir::registerLLVMDialectTranslation(*module->getContext());
+
// Convert the module to LLVM IR in a new LLVM IR context.
llvm::LLVMContext llvmContext;
auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
+ // Register the translation from MLIR to LLVM IR, which must happen before we
+ // can JIT-compile.
+ mlir::registerLLVMDialectTranslation(*module->getContext());
+
// An optimization pipeline to use within the execution engine.
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
}
int dumpLLVMIR(mlir::ModuleOp module) {
+ // Register the translation to LLVM IR with the MLIR context.
+ mlir::registerLLVMDialectTranslation(*module->getContext());
+
// Convert the module to LLVM IR in a new LLVM IR context.
llvm::LLVMContext llvmContext;
auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
+ // Register the translation from MLIR to LLVM IR, which must happen before we
+ // can JIT-compile.
+ mlir::registerLLVMDialectTranslation(*module->getContext());
+
// An optimization pipeline to use within the execution engine.
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
namespace mlir {
+class DialectRegistry;
class OwningModuleRef;
class MLIRContext;
class ModuleOp;
translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
MLIRContext *context);
+/// Register the LLVM dialect and the translation from it to the LLVM IR in the
+/// given registry;
+void registerLLVMDialectTranslation(DialectRegistry ®istry);
+
+/// Register the LLVM dialect and the translation from it in the registry
+/// associated with the given context. This checks if the interface is already
+/// registered and avoids double registation.
+void registerLLVMDialectTranslation(MLIRContext &context);
+
} // namespace mlir
#endif // MLIR_TARGET_LLVMIR_H
--- /dev/null
+//===- LLVMToLLVMIRTranslation.h - LLVM Dialect to LLVM IR-------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the dialect interface for translating the LLVM dialect
+// to LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H
+
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+
+namespace mlir {
+
+/// Implementation of the dialect interface that converts operations beloning to
+/// the LLVM dialect to LLVM IR.
+class LLVMDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+ /// Translates the given operation to LLVM IR using the provided IR builder
+ /// and saving the state in `moduleTranslation`.
+ LogicalResult
+ convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) const final;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H
--- /dev/null
+//===- LLVMTranslationInterface.h - Translation to LLVM iface ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines dialect interfaces for translation to LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
+#define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace llvm {
+class IRBuilderBase;
+}
+
+namespace mlir {
+namespace LLVM {
+class ModuleTranslation;
+} // namespace LLVM
+
+/// Base class for dialect interfaces providing translation to LLVM IR.
+/// Dialects that can be translated should provide an implementation of this
+/// interface for the supported operations. The interface may be implemented in
+/// a separate library to avoid the "main" dialect library depending on LLVM IR.
+/// The interface can be attached using the delayed registration mechanism
+/// available in DialectRegistry.
+class LLVMTranslationDialectInterface
+ : public DialectInterface::Base<LLVMTranslationDialectInterface> {
+public:
+ LLVMTranslationDialectInterface(Dialect *dialect) : Base(dialect) {}
+
+ /// Hook for derived dialect interface to provide translation of the
+ /// operations to LLVM IR.
+ virtual LogicalResult
+ convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) const {
+ return failure();
+ }
+};
+
+/// Interface collection for translation to LLVM IR, dispatches to a concrete
+/// interface implementation based on the dialect to which the given op belongs.
+class LLVMTranslationInterface
+ : public DialectInterfaceCollection<LLVMTranslationDialectInterface> {
+public:
+ using Base::Base;
+
+ /// Translates the given operation to LLVM IR using the interface implemented
+ /// by the op's dialect.
+ virtual LogicalResult
+ convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) const {
+ if (const LLVMTranslationDialectInterface *iface = getInterfaceFor(op))
+ return iface->convertOperation(op, builder, moduleTranslation);
+ return failure();
+ }
+};
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Value.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
return branchMapping.lookup(op);
}
+ /// Converts the type from MLIR LLVM dialect to LLVM.
+ llvm::Type *convertType(Type type);
+
+ /// Looks up remapped a list of remapped values.
+ SmallVector<llvm::Value *, 8> lookupValues(ValueRange values);
+
+ /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
+ /// This currently supports integer, floating point, splat and dense element
+ /// attributes and combinations thereof. In case of error, report it to `loc`
+ /// and return nullptr.
+ llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
+ Location loc);
+
+ /// Returns the MLIR context of the module being translated.
+ MLIRContext &getContext() { return *mlirModule->getContext(); }
+
+ /// Returns the LLVM context in which the IR is being constructed.
+ llvm::LLVMContext &getLLVMContext() { return llvmModule->getContext(); }
+
+ /// Finds an LLVM IR global value that corresponds to the given MLIR operation
+ /// defining a global value.
+ llvm::GlobalValue *lookupGlobal(Operation *op) {
+ return globalsMapping.lookup(op);
+ }
+
protected:
/// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
/// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
virtual LogicalResult convertOmpWsLoop(Operation &opInst,
llvm::IRBuilder<> &builder);
- /// Converts the type from MLIR LLVM dialect to LLVM.
- llvm::Type *convertType(Type type);
-
static std::unique_ptr<llvm::Module>
prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
StringRef name);
- /// A helper to look up remapped operands in the value remapping table.
- SmallVector<llvm::Value *, 8> lookupValues(ValueRange values);
-
private:
/// Check whether the module contains only supported ops directly in its body.
static LogicalResult checkSupportedModuleOps(Operation *m);
LogicalResult convertBlock(Block &bb, bool ignoreArguments,
llvm::IRBuilder<> &builder);
- llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
- Location loc);
-
/// Original and translated module.
Operation *mlirModule;
std::unique_ptr<llvm::Module> llvmModule;
/// A stateful object used to translate types.
TypeToLLVMIRTranslator typeTranslator;
-private:
+ LLVMTranslationInterface iface;
+
/// Mappings between original and translated values, used for lookups.
llvm::StringMap<llvm::Function *> functionMapping;
DenseMap<Value, llvm::Value *> valueMapping;
add_subdirectory(SPIRV)
+add_subdirectory(LLVMIR)
add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation
LLVMIR/DebugTranslation.cpp
MLIRIR
MLIRLLVMAVX512
MLIRLLVMIR
+ MLIRTargetLLVMIR
MLIRTargetLLVMIRModuleTranslation
)
IRReader
LINK_LIBS PUBLIC
+ MLIRLLVMToLLVMIRTranslation
MLIRTargetLLVMIRModuleTranslation
)
MLIRIR
MLIRLLVMArmNeon
MLIRLLVMIR
+ MLIRTargetLLVMIR
MLIRTargetLLVMIRModuleTranslation
)
MLIRIR
MLIRLLVMArmSVE
MLIRLLVMIR
+ MLIRTargetLLVMIR
MLIRTargetLLVMIRModuleTranslation
)
MLIRIR
MLIRLLVMIR
MLIRNVVMIR
+ MLIRTargetLLVMIR
MLIRTargetLLVMIRModuleTranslation
)
MLIRIR
MLIRLLVMIR
MLIRROCDLIR
+ MLIRTargetLLVMIR
MLIRTargetLLVMIRModuleTranslation
)
--- /dev/null
+add_subdirectory(Dialect)
#include "mlir/Target/LLVMIR.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Translation.h"
return llvmModule;
}
+void mlir::registerLLVMDialectTranslation(DialectRegistry ®istry) {
+ registry.insert<LLVM::LLVMDialect>();
+ registry.addDialectInterface<LLVM::LLVMDialect,
+ LLVMDialectLLVMIRTranslationInterface>();
+}
+
+void mlir::registerLLVMDialectTranslation(MLIRContext &context) {
+ auto *dialect = context.getLoadedDialect<LLVM::LLVMDialect>();
+ if (!dialect || dialect->getRegisteredInterface<
+ LLVMDialectLLVMIRTranslationInterface>() == nullptr) {
+ DialectRegistry registry;
+ registry.insert<LLVM::LLVMDialect>();
+ registry.addDialectInterface<LLVM::LLVMDialect,
+ LLVMDialectLLVMIRTranslationInterface>();
+ context.appendDialectRegistry(registry);
+ }
+}
+
namespace mlir {
void registerToLLVMIRTranslation() {
TranslateFromMLIRRegistration registration(
return success();
},
[](DialectRegistry ®istry) {
- registry.insert<LLVM::LLVMDialect, omp::OpenMPDialect>();
+ registry.insert<omp::OpenMPDialect>();
+ registerLLVMDialectTranslation(registry);
});
}
} // namespace mlir
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Target/LLVMIR.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Translation.h"
std::unique_ptr<llvm::Module>
mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext,
StringRef name) {
+ // Register the translation to LLVM IR if nobody else did before. This may
+ // happen if this translation is called inside a pass pipeline that converts
+ // GPU dialects to binary blobs without translating the rest of the code.
+ registerLLVMDialectTranslation(*m->getContext());
+
auto llvmModule = LLVM::ModuleTranslation::translateModule<ModuleTranslation>(
m, llvmContext, name);
if (!llvmModule)
return success();
},
[](DialectRegistry ®istry) {
- registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
+ registry.insert<NVVM::NVVMDialect>();
+ registerLLVMDialectTranslation(registry);
});
}
} // namespace mlir
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Target/LLVMIR.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Translation.h"
std::unique_ptr<llvm::Module>
mlir::translateModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext,
StringRef name) {
- // lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics)
+ // Register the translation to LLVM IR if nobody else did before. This may
+ // happen if this translation is called inside a pass pipeline that converts
+ // GPU dialects to binary blobs without translating the rest of the code.
+ registerLLVMDialectTranslation(*m->getContext());
+
+ // Lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics).
auto llvmModule = LLVM::ModuleTranslation::translateModule<ModuleTranslation>(
m, llvmContext, name);
- // foreach GPU kernel
+ // Foreach GPU kernel:
// 1. Insert AMDGPU_KERNEL calling convention.
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
for (auto func :
return success();
},
[](DialectRegistry ®istry) {
- registry.insert<ROCDL::ROCDLDialect, LLVM::LLVMDialect>();
+ registry.insert<ROCDL::ROCDLDialect>();
+ registerLLVMDialectTranslation(registry);
});
}
} // namespace mlir
--- /dev/null
+add_subdirectory(LLVMIR)
--- /dev/null
+add_mlir_translation_library(MLIRLLVMToLLVMIRTranslation
+ LLVMToLLVMIRTranslation.cpp
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMIR
+ MLIRSupport
+ MLIRTargetLLVMIRModuleTranslation
+ )
--- /dev/null
+//===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR LLVM dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Operator.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
+
+/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
+static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) {
+ switch (p) {
+ case LLVM::ICmpPredicate::eq:
+ return llvm::CmpInst::Predicate::ICMP_EQ;
+ case LLVM::ICmpPredicate::ne:
+ return llvm::CmpInst::Predicate::ICMP_NE;
+ case LLVM::ICmpPredicate::slt:
+ return llvm::CmpInst::Predicate::ICMP_SLT;
+ case LLVM::ICmpPredicate::sle:
+ return llvm::CmpInst::Predicate::ICMP_SLE;
+ case LLVM::ICmpPredicate::sgt:
+ return llvm::CmpInst::Predicate::ICMP_SGT;
+ case LLVM::ICmpPredicate::sge:
+ return llvm::CmpInst::Predicate::ICMP_SGE;
+ case LLVM::ICmpPredicate::ult:
+ return llvm::CmpInst::Predicate::ICMP_ULT;
+ case LLVM::ICmpPredicate::ule:
+ return llvm::CmpInst::Predicate::ICMP_ULE;
+ case LLVM::ICmpPredicate::ugt:
+ return llvm::CmpInst::Predicate::ICMP_UGT;
+ case LLVM::ICmpPredicate::uge:
+ return llvm::CmpInst::Predicate::ICMP_UGE;
+ }
+ llvm_unreachable("incorrect comparison predicate");
+}
+
+static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
+ switch (p) {
+ case LLVM::FCmpPredicate::_false:
+ return llvm::CmpInst::Predicate::FCMP_FALSE;
+ case LLVM::FCmpPredicate::oeq:
+ return llvm::CmpInst::Predicate::FCMP_OEQ;
+ case LLVM::FCmpPredicate::ogt:
+ return llvm::CmpInst::Predicate::FCMP_OGT;
+ case LLVM::FCmpPredicate::oge:
+ return llvm::CmpInst::Predicate::FCMP_OGE;
+ case LLVM::FCmpPredicate::olt:
+ return llvm::CmpInst::Predicate::FCMP_OLT;
+ case LLVM::FCmpPredicate::ole:
+ return llvm::CmpInst::Predicate::FCMP_OLE;
+ case LLVM::FCmpPredicate::one:
+ return llvm::CmpInst::Predicate::FCMP_ONE;
+ case LLVM::FCmpPredicate::ord:
+ return llvm::CmpInst::Predicate::FCMP_ORD;
+ case LLVM::FCmpPredicate::ueq:
+ return llvm::CmpInst::Predicate::FCMP_UEQ;
+ case LLVM::FCmpPredicate::ugt:
+ return llvm::CmpInst::Predicate::FCMP_UGT;
+ case LLVM::FCmpPredicate::uge:
+ return llvm::CmpInst::Predicate::FCMP_UGE;
+ case LLVM::FCmpPredicate::ult:
+ return llvm::CmpInst::Predicate::FCMP_ULT;
+ case LLVM::FCmpPredicate::ule:
+ return llvm::CmpInst::Predicate::FCMP_ULE;
+ case LLVM::FCmpPredicate::une:
+ return llvm::CmpInst::Predicate::FCMP_UNE;
+ case LLVM::FCmpPredicate::uno:
+ return llvm::CmpInst::Predicate::FCMP_UNO;
+ case LLVM::FCmpPredicate::_true:
+ return llvm::CmpInst::Predicate::FCMP_TRUE;
+ }
+ llvm_unreachable("incorrect comparison predicate");
+}
+
+static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
+ switch (op) {
+ case LLVM::AtomicBinOp::xchg:
+ return llvm::AtomicRMWInst::BinOp::Xchg;
+ case LLVM::AtomicBinOp::add:
+ return llvm::AtomicRMWInst::BinOp::Add;
+ case LLVM::AtomicBinOp::sub:
+ return llvm::AtomicRMWInst::BinOp::Sub;
+ case LLVM::AtomicBinOp::_and:
+ return llvm::AtomicRMWInst::BinOp::And;
+ case LLVM::AtomicBinOp::nand:
+ return llvm::AtomicRMWInst::BinOp::Nand;
+ case LLVM::AtomicBinOp::_or:
+ return llvm::AtomicRMWInst::BinOp::Or;
+ case LLVM::AtomicBinOp::_xor:
+ return llvm::AtomicRMWInst::BinOp::Xor;
+ case LLVM::AtomicBinOp::max:
+ return llvm::AtomicRMWInst::BinOp::Max;
+ case LLVM::AtomicBinOp::min:
+ return llvm::AtomicRMWInst::BinOp::Min;
+ case LLVM::AtomicBinOp::umax:
+ return llvm::AtomicRMWInst::BinOp::UMax;
+ case LLVM::AtomicBinOp::umin:
+ return llvm::AtomicRMWInst::BinOp::UMin;
+ case LLVM::AtomicBinOp::fadd:
+ return llvm::AtomicRMWInst::BinOp::FAdd;
+ case LLVM::AtomicBinOp::fsub:
+ return llvm::AtomicRMWInst::BinOp::FSub;
+ }
+ llvm_unreachable("incorrect atomic binary operator");
+}
+
+static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
+ switch (ordering) {
+ case LLVM::AtomicOrdering::not_atomic:
+ return llvm::AtomicOrdering::NotAtomic;
+ case LLVM::AtomicOrdering::unordered:
+ return llvm::AtomicOrdering::Unordered;
+ case LLVM::AtomicOrdering::monotonic:
+ return llvm::AtomicOrdering::Monotonic;
+ case LLVM::AtomicOrdering::acquire:
+ return llvm::AtomicOrdering::Acquire;
+ case LLVM::AtomicOrdering::release:
+ return llvm::AtomicOrdering::Release;
+ case LLVM::AtomicOrdering::acq_rel:
+ return llvm::AtomicOrdering::AcquireRelease;
+ case LLVM::AtomicOrdering::seq_cst:
+ return llvm::AtomicOrdering::SequentiallyConsistent;
+ }
+ llvm_unreachable("incorrect atomic ordering");
+}
+
+static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
+ using llvmFMF = llvm::FastMathFlags;
+ using FuncT = void (llvmFMF::*)(bool);
+ const std::pair<FastmathFlags, FuncT> handlers[] = {
+ // clang-format off
+ {FastmathFlags::nnan, &llvmFMF::setNoNaNs},
+ {FastmathFlags::ninf, &llvmFMF::setNoInfs},
+ {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
+ {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
+ {FastmathFlags::contract, &llvmFMF::setAllowContract},
+ {FastmathFlags::afn, &llvmFMF::setApproxFunc},
+ {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
+ {FastmathFlags::fast, &llvmFMF::setFast},
+ // clang-format on
+ };
+ llvm::FastMathFlags ret;
+ auto fmf = op.fastmathFlags();
+ for (auto it : handlers)
+ if (bitEnumContains(fmf, it.first))
+ (ret.*(it.second))(true);
+ return ret;
+}
+
+namespace {
+/// Dispatcher functional object targeting different overloads of
+/// ModuleTranslation::mapValue.
+// TODO: this is only necessary for compatibility with the code emitted from
+// ODS, remove when ODS is updated (after all dialects have migrated to the new
+// translation mechanism).
+struct MapValueDispatcher {
+ explicit MapValueDispatcher(ModuleTranslation &mt) : moduleTranslation(mt) {}
+
+ llvm::Value *&operator()(mlir::Value v) {
+ return moduleTranslation.mapValue(v);
+ }
+
+ void operator()(mlir::Value m, llvm::Value *l) {
+ moduleTranslation.mapValue(m, l);
+ }
+
+ LLVM::ModuleTranslation &moduleTranslation;
+};
+} // end namespace
+
+static LogicalResult
+convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ auto extractPosition = [](ArrayAttr attr) {
+ SmallVector<unsigned, 4> position;
+ position.reserve(attr.size());
+ for (Attribute v : attr)
+ position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
+ return position;
+ };
+
+ llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
+ if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
+ builder.setFastMathFlags(getFastmathFlags(fmf));
+
+ // TODO: these are necessary for compatibility with the code emitted from ODS,
+ // remove them when ODS is updated (after all dialects have migrated to the
+ // new translation mechanism).
+ MapValueDispatcher mapValue(moduleTranslation);
+ auto lookupValue = [&](mlir::Value v) {
+ return moduleTranslation.lookupValue(v);
+ };
+ auto convertType = [&](Type ty) { return moduleTranslation.convertType(ty); };
+ auto lookupValues = [&](ValueRange vs) {
+ return moduleTranslation.lookupValues(vs);
+ };
+ auto getLLVMConstant = [&](llvm::Type *ty, Attribute attr, Location loc) {
+ return moduleTranslation.getLLVMConstant(ty, attr, loc);
+ };
+
+#include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
+
+ // Emit function calls. If the "callee" attribute is present, this is a
+ // direct function call and we also need to look up the remapped function
+ // itself. Otherwise, this is an indirect call and the callee is the first
+ // operand, look it up as a normal value. Return the llvm::Value representing
+ // the function result, which may be of llvm::VoidTy type.
+ auto convertCall = [&](Operation &op) -> llvm::Value * {
+ auto operands = moduleTranslation.lookupValues(op.getOperands());
+ ArrayRef<llvm::Value *> operandsRef(operands);
+ if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
+ return builder.CreateCall(
+ moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
+ auto *calleePtrType =
+ cast<llvm::PointerType>(operandsRef.front()->getType());
+ auto *calleeType =
+ cast<llvm::FunctionType>(calleePtrType->getElementType());
+ return builder.CreateCall(calleeType, operandsRef.front(),
+ operandsRef.drop_front());
+ };
+
+ // Emit calls. If the called function has a result, remap the corresponding
+ // value. Note that LLVM IR dialect CallOp has either 0 or 1 result.
+ if (isa<LLVM::CallOp>(opInst)) {
+ llvm::Value *result = convertCall(opInst);
+ if (opInst.getNumResults() != 0) {
+ mapValue(opInst.getResult(0), result);
+ return success();
+ }
+ // Check that LLVM call returns void for 0-result functions.
+ return success(result->getType()->isVoidTy());
+ }
+
+ if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
+ // TODO: refactor function type creation which usually occurs in std-LLVM
+ // conversion.
+ SmallVector<Type, 8> operandTypes;
+ operandTypes.reserve(inlineAsmOp.operands().size());
+ for (auto t : inlineAsmOp.operands().getTypes())
+ operandTypes.push_back(t);
+
+ Type resultType;
+ if (inlineAsmOp.getNumResults() == 0) {
+ resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext());
+ } else {
+ assert(inlineAsmOp.getNumResults() == 1);
+ resultType = inlineAsmOp.getResultTypes()[0];
+ }
+ auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
+ llvm::InlineAsm *inlineAsmInst =
+ inlineAsmOp.asm_dialect().hasValue()
+ ? llvm::InlineAsm::get(
+ static_cast<llvm::FunctionType *>(convertType(ft)),
+ inlineAsmOp.asm_string(), inlineAsmOp.constraints(),
+ inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack(),
+ convertAsmDialectToLLVM(*inlineAsmOp.asm_dialect()))
+ : llvm::InlineAsm::get(
+ static_cast<llvm::FunctionType *>(convertType(ft)),
+ inlineAsmOp.asm_string(), inlineAsmOp.constraints(),
+ inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack());
+ llvm::Value *result =
+ builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands()));
+ if (opInst.getNumResults() != 0)
+ mapValue(opInst.getResult(0), result);
+ return success();
+ }
+
+ if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
+ auto operands = lookupValues(opInst.getOperands());
+ ArrayRef<llvm::Value *> operandsRef(operands);
+ if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
+ builder.CreateInvoke(moduleTranslation.lookupFunction(attr.getValue()),
+ moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
+ moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
+ operandsRef);
+ } else {
+ auto *calleePtrType =
+ cast<llvm::PointerType>(operandsRef.front()->getType());
+ auto *calleeType =
+ cast<llvm::FunctionType>(calleePtrType->getElementType());
+ builder.CreateInvoke(calleeType, operandsRef.front(),
+ moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
+ moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
+ operandsRef.drop_front());
+ }
+ return success();
+ }
+
+ if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
+ llvm::Type *ty = convertType(lpOp.getType());
+ llvm::LandingPadInst *lpi =
+ builder.CreateLandingPad(ty, lpOp.getNumOperands());
+
+ // Add clauses
+ for (llvm::Value *operand : lookupValues(lpOp.getOperands())) {
+ // All operands should be constant - checked by verifier
+ if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
+ lpi->addClause(constOperand);
+ }
+ mapValue(lpOp.getResult(), lpi);
+ return success();
+ }
+
+ // Emit branches. We need to look up the remapped blocks and ignore the block
+ // arguments that were transformed into PHI nodes.
+ if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
+ llvm::BranchInst *branch =
+ builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
+ moduleTranslation.mapBranch(&opInst, branch);
+ return success();
+ }
+ if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
+ auto weights = condbrOp.branch_weights();
+ llvm::MDNode *branchWeights = nullptr;
+ if (weights) {
+ // Map weight attributes to LLVM metadata.
+ auto trueWeight =
+ weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
+ auto falseWeight =
+ weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
+ branchWeights =
+ llvm::MDBuilder(moduleTranslation.getLLVMContext())
+ .createBranchWeights(static_cast<uint32_t>(trueWeight),
+ static_cast<uint32_t>(falseWeight));
+ }
+ llvm::BranchInst *branch = builder.CreateCondBr(
+ moduleTranslation.lookupValue(condbrOp.getOperand(0)),
+ moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
+ moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
+ moduleTranslation.mapBranch(&opInst, branch);
+ return success();
+ }
+ if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
+ llvm::MDNode *branchWeights = nullptr;
+ if (auto weights = switchOp.branch_weights()) {
+ llvm::SmallVector<uint32_t> weightValues;
+ weightValues.reserve(weights->size());
+ for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
+ weightValues.push_back(weight.getLimitedValue());
+ branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext())
+ .createBranchWeights(weightValues);
+ }
+
+ llvm::SwitchInst *switchInst = builder.CreateSwitch(
+ moduleTranslation.lookupValue(switchOp.value()),
+ moduleTranslation.lookupBlock(switchOp.defaultDestination()),
+ switchOp.caseDestinations().size(), branchWeights);
+
+ auto *ty =
+ llvm::cast<llvm::IntegerType>(convertType(switchOp.value().getType()));
+ for (auto i :
+ llvm::zip(switchOp.case_values()->cast<DenseIntElementsAttr>(),
+ switchOp.caseDestinations()))
+ switchInst->addCase(
+ llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
+ moduleTranslation.lookupBlock(std::get<1>(i)));
+
+ moduleTranslation.mapBranch(&opInst, switchInst);
+ return success();
+ }
+
+ // Emit addressof. We need to look up the global value referenced by the
+ // operation and store it in the MLIR-to-LLVM value mapping. This does not
+ // emit any LLVM instruction.
+ if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
+ LLVM::GlobalOp global = addressOfOp.getGlobal();
+ LLVM::LLVMFuncOp function = addressOfOp.getFunction();
+
+ // The verifier should not have allowed this.
+ assert((global || function) &&
+ "referencing an undefined global or function");
+
+ mapValue(addressOfOp.getResult(),
+ global ? moduleTranslation.lookupGlobal(global)
+ : moduleTranslation.lookupFunction(function.getName()));
+ return success();
+ }
+
+ return failure();
+}
+
+LogicalResult mlir::LLVMDialectLLVMIRTranslationInterface::convertOperation(
+ Operation *op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) const {
+ return convertOperationImpl(*op, builder, moduleTranslation);
+}
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Target/LLVMIR.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Translation.h"
#include "llvm/IR/IntrinsicsX86.h"
return success();
},
[](DialectRegistry ®istry) {
- registry.insert<LLVM::LLVMAVX512Dialect, LLVM::LLVMDialect>();
+ registry.insert<LLVM::LLVMAVX512Dialect>();
+ registerLLVMDialectTranslation(registry);
});
}
} // namespace mlir
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Target/LLVMIR.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Translation.h"
#include "llvm/IR/IntrinsicsAArch64.h"
return success();
},
[](DialectRegistry ®istry) {
- registry.insert<LLVM::LLVMArmNeonDialect, LLVM::LLVMDialect>();
+ registry.insert<LLVM::LLVMArmNeonDialect>();
+ registerLLVMDialectTranslation(registry);
});
}
} // namespace mlir
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
+#include "mlir/Target/LLVMIR.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Translation.h"
#include "llvm/IR/IntrinsicsAArch64.h"
return success();
},
[](DialectRegistry ®istry) {
- registry.insert<LLVM::LLVMArmSVEDialect, LLVM::LLVMDialect>();
+ registry.insert<LLVM::LLVMArmSVEDialect>();
+ registerLLVMDialectTranslation(registry);
});
}
} // namespace mlir
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "llvm/ADT/TypeSwitch.h"
return nullptr;
}
-/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
-static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) {
- switch (p) {
- case LLVM::ICmpPredicate::eq:
- return llvm::CmpInst::Predicate::ICMP_EQ;
- case LLVM::ICmpPredicate::ne:
- return llvm::CmpInst::Predicate::ICMP_NE;
- case LLVM::ICmpPredicate::slt:
- return llvm::CmpInst::Predicate::ICMP_SLT;
- case LLVM::ICmpPredicate::sle:
- return llvm::CmpInst::Predicate::ICMP_SLE;
- case LLVM::ICmpPredicate::sgt:
- return llvm::CmpInst::Predicate::ICMP_SGT;
- case LLVM::ICmpPredicate::sge:
- return llvm::CmpInst::Predicate::ICMP_SGE;
- case LLVM::ICmpPredicate::ult:
- return llvm::CmpInst::Predicate::ICMP_ULT;
- case LLVM::ICmpPredicate::ule:
- return llvm::CmpInst::Predicate::ICMP_ULE;
- case LLVM::ICmpPredicate::ugt:
- return llvm::CmpInst::Predicate::ICMP_UGT;
- case LLVM::ICmpPredicate::uge:
- return llvm::CmpInst::Predicate::ICMP_UGE;
- }
- llvm_unreachable("incorrect comparison predicate");
-}
-
-static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
- switch (p) {
- case LLVM::FCmpPredicate::_false:
- return llvm::CmpInst::Predicate::FCMP_FALSE;
- case LLVM::FCmpPredicate::oeq:
- return llvm::CmpInst::Predicate::FCMP_OEQ;
- case LLVM::FCmpPredicate::ogt:
- return llvm::CmpInst::Predicate::FCMP_OGT;
- case LLVM::FCmpPredicate::oge:
- return llvm::CmpInst::Predicate::FCMP_OGE;
- case LLVM::FCmpPredicate::olt:
- return llvm::CmpInst::Predicate::FCMP_OLT;
- case LLVM::FCmpPredicate::ole:
- return llvm::CmpInst::Predicate::FCMP_OLE;
- case LLVM::FCmpPredicate::one:
- return llvm::CmpInst::Predicate::FCMP_ONE;
- case LLVM::FCmpPredicate::ord:
- return llvm::CmpInst::Predicate::FCMP_ORD;
- case LLVM::FCmpPredicate::ueq:
- return llvm::CmpInst::Predicate::FCMP_UEQ;
- case LLVM::FCmpPredicate::ugt:
- return llvm::CmpInst::Predicate::FCMP_UGT;
- case LLVM::FCmpPredicate::uge:
- return llvm::CmpInst::Predicate::FCMP_UGE;
- case LLVM::FCmpPredicate::ult:
- return llvm::CmpInst::Predicate::FCMP_ULT;
- case LLVM::FCmpPredicate::ule:
- return llvm::CmpInst::Predicate::FCMP_ULE;
- case LLVM::FCmpPredicate::une:
- return llvm::CmpInst::Predicate::FCMP_UNE;
- case LLVM::FCmpPredicate::uno:
- return llvm::CmpInst::Predicate::FCMP_UNO;
- case LLVM::FCmpPredicate::_true:
- return llvm::CmpInst::Predicate::FCMP_TRUE;
- }
- llvm_unreachable("incorrect comparison predicate");
-}
-
-static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
- switch (op) {
- case LLVM::AtomicBinOp::xchg:
- return llvm::AtomicRMWInst::BinOp::Xchg;
- case LLVM::AtomicBinOp::add:
- return llvm::AtomicRMWInst::BinOp::Add;
- case LLVM::AtomicBinOp::sub:
- return llvm::AtomicRMWInst::BinOp::Sub;
- case LLVM::AtomicBinOp::_and:
- return llvm::AtomicRMWInst::BinOp::And;
- case LLVM::AtomicBinOp::nand:
- return llvm::AtomicRMWInst::BinOp::Nand;
- case LLVM::AtomicBinOp::_or:
- return llvm::AtomicRMWInst::BinOp::Or;
- case LLVM::AtomicBinOp::_xor:
- return llvm::AtomicRMWInst::BinOp::Xor;
- case LLVM::AtomicBinOp::max:
- return llvm::AtomicRMWInst::BinOp::Max;
- case LLVM::AtomicBinOp::min:
- return llvm::AtomicRMWInst::BinOp::Min;
- case LLVM::AtomicBinOp::umax:
- return llvm::AtomicRMWInst::BinOp::UMax;
- case LLVM::AtomicBinOp::umin:
- return llvm::AtomicRMWInst::BinOp::UMin;
- case LLVM::AtomicBinOp::fadd:
- return llvm::AtomicRMWInst::BinOp::FAdd;
- case LLVM::AtomicBinOp::fsub:
- return llvm::AtomicRMWInst::BinOp::FSub;
- }
- llvm_unreachable("incorrect atomic binary operator");
-}
-
-static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
- switch (ordering) {
- case LLVM::AtomicOrdering::not_atomic:
- return llvm::AtomicOrdering::NotAtomic;
- case LLVM::AtomicOrdering::unordered:
- return llvm::AtomicOrdering::Unordered;
- case LLVM::AtomicOrdering::monotonic:
- return llvm::AtomicOrdering::Monotonic;
- case LLVM::AtomicOrdering::acquire:
- return llvm::AtomicOrdering::Acquire;
- case LLVM::AtomicOrdering::release:
- return llvm::AtomicOrdering::Release;
- case LLVM::AtomicOrdering::acq_rel:
- return llvm::AtomicOrdering::AcquireRelease;
- case LLVM::AtomicOrdering::seq_cst:
- return llvm::AtomicOrdering::SequentiallyConsistent;
- }
- llvm_unreachable("incorrect atomic ordering");
-}
-
ModuleTranslation::ModuleTranslation(Operation *module,
std::unique_ptr<llvm::Module> llvmModule)
: mlirModule(module), llvmModule(std::move(llvmModule)),
debugTranslation(
std::make_unique<DebugTranslation>(module, *this->llvmModule)),
ompDialect(module->getContext()->getLoadedDialect("omp")),
- typeTranslator(this->llvmModule->getContext()) {
+ typeTranslator(this->llvmModule->getContext()),
+ iface(module->getContext()) {
assert(satisfiesLLVMModule(mlirModule) &&
"mlirModule should honor LLVM's module semantics.");
}
});
}
-static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
- using llvmFMF = llvm::FastMathFlags;
- using FuncT = void (llvmFMF::*)(bool);
- const std::pair<FastmathFlags, FuncT> handlers[] = {
- // clang-format off
- {FastmathFlags::nnan, &llvmFMF::setNoNaNs},
- {FastmathFlags::ninf, &llvmFMF::setNoInfs},
- {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
- {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
- {FastmathFlags::contract, &llvmFMF::setAllowContract},
- {FastmathFlags::afn, &llvmFMF::setApproxFunc},
- {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
- {FastmathFlags::fast, &llvmFMF::setFast},
- // clang-format on
- };
- llvm::FastMathFlags ret;
- auto fmf = op.fastmathFlags();
- for (auto it : handlers)
- if (bitEnumContains(fmf, it.first))
- (ret.*(it.second))(true);
- return ret;
-}
-
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`. LLVM IR Builder does not have a generic interface so
/// this has to be a long chain of `if`s calling different functions with a
/// different number of arguments.
LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
llvm::IRBuilder<> &builder) {
- auto extractPosition = [](ArrayAttr attr) {
- SmallVector<unsigned, 4> position;
- position.reserve(attr.size());
- for (Attribute v : attr)
- position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
- return position;
- };
-
- llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
- if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
- builder.setFastMathFlags(getFastmathFlags(fmf));
-
-#include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
-
- // Emit function calls. If the "callee" attribute is present, this is a
- // direct function call and we also need to look up the remapped function
- // itself. Otherwise, this is an indirect call and the callee is the first
- // operand, look it up as a normal value. Return the llvm::Value representing
- // the function result, which may be of llvm::VoidTy type.
- auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
- auto operands = lookupValues(op.getOperands());
- ArrayRef<llvm::Value *> operandsRef(operands);
- if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
- return builder.CreateCall(lookupFunction(attr.getValue()), operandsRef);
- auto *calleePtrType =
- cast<llvm::PointerType>(operandsRef.front()->getType());
- auto *calleeType =
- cast<llvm::FunctionType>(calleePtrType->getElementType());
- return builder.CreateCall(calleeType, operandsRef.front(),
- operandsRef.drop_front());
- };
-
- // Emit calls. If the called function has a result, remap the corresponding
- // value. Note that LLVM IR dialect CallOp has either 0 or 1 result.
- if (isa<LLVM::CallOp>(opInst)) {
- llvm::Value *result = convertCall(opInst);
- if (opInst.getNumResults() != 0) {
- mapValue(opInst.getResult(0), result);
- return success();
- }
- // Check that LLVM call returns void for 0-result functions.
- return success(result->getType()->isVoidTy());
- }
- if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
- // TODO: refactor function type creation which usually occurs in std-LLVM
- // conversion.
- SmallVector<Type, 8> operandTypes;
- operandTypes.reserve(inlineAsmOp.operands().size());
- for (auto t : inlineAsmOp.operands().getTypes())
- operandTypes.push_back(t);
-
- Type resultType;
- if (inlineAsmOp.getNumResults() == 0) {
- resultType = LLVM::LLVMVoidType::get(mlirModule->getContext());
- } else {
- assert(inlineAsmOp.getNumResults() == 1);
- resultType = inlineAsmOp.getResultTypes()[0];
- }
- auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
- llvm::InlineAsm *inlineAsmInst =
- inlineAsmOp.asm_dialect().hasValue()
- ? llvm::InlineAsm::get(
- static_cast<llvm::FunctionType *>(convertType(ft)),
- inlineAsmOp.asm_string(), inlineAsmOp.constraints(),
- inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack(),
- convertAsmDialectToLLVM(*inlineAsmOp.asm_dialect()))
- : llvm::InlineAsm::get(
- static_cast<llvm::FunctionType *>(convertType(ft)),
- inlineAsmOp.asm_string(), inlineAsmOp.constraints(),
- inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack());
- llvm::Value *result =
- builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands()));
- if (opInst.getNumResults() != 0)
- mapValue(opInst.getResult(0), result);
+ // TODO(zinenko): this should be the "main" conversion here, remove the
+ // dispatch below.
+ if (succeeded(iface.convertOperation(&opInst, builder, *this)))
return success();
- }
-
- if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
- auto operands = lookupValues(opInst.getOperands());
- ArrayRef<llvm::Value *> operandsRef(operands);
- if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
- builder.CreateInvoke(lookupFunction(attr.getValue()),
- lookupBlock(invOp.getSuccessor(0)),
- lookupBlock(invOp.getSuccessor(1)), operandsRef);
- } else {
- auto *calleePtrType =
- cast<llvm::PointerType>(operandsRef.front()->getType());
- auto *calleeType =
- cast<llvm::FunctionType>(calleePtrType->getElementType());
- builder.CreateInvoke(
- calleeType, operandsRef.front(), lookupBlock(invOp.getSuccessor(0)),
- lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front());
- }
- return success();
- }
-
- if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
- llvm::Type *ty = convertType(lpOp.getType());
- llvm::LandingPadInst *lpi =
- builder.CreateLandingPad(ty, lpOp.getNumOperands());
-
- // Add clauses
- for (llvm::Value *operand : lookupValues(lpOp.getOperands())) {
- // All operands should be constant - checked by verifier
- if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
- lpi->addClause(constOperand);
- }
- mapValue(lpOp.getResult(), lpi);
- return success();
- }
-
- // Emit branches. We need to look up the remapped blocks and ignore the block
- // arguments that were transformed into PHI nodes.
- if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
- llvm::BranchInst *branch =
- builder.CreateBr(lookupBlock(brOp.getSuccessor()));
- mapBranch(&opInst, branch);
- return success();
- }
- if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
- auto weights = condbrOp.branch_weights();
- llvm::MDNode *branchWeights = nullptr;
- if (weights) {
- // Map weight attributes to LLVM metadata.
- auto trueWeight =
- weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
- auto falseWeight =
- weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
- branchWeights =
- llvm::MDBuilder(llvmModule->getContext())
- .createBranchWeights(static_cast<uint32_t>(trueWeight),
- static_cast<uint32_t>(falseWeight));
- }
- llvm::BranchInst *branch = builder.CreateCondBr(
- lookupValue(condbrOp.getOperand(0)),
- lookupBlock(condbrOp.getSuccessor(0)),
- lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
- mapBranch(&opInst, branch);
- return success();
- }
- if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
- llvm::MDNode *branchWeights = nullptr;
- if (auto weights = switchOp.branch_weights()) {
- llvm::SmallVector<uint32_t> weightValues;
- weightValues.reserve(weights->size());
- for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
- weightValues.push_back(weight.getLimitedValue());
- branchWeights = llvm::MDBuilder(llvmModule->getContext())
- .createBranchWeights(weightValues);
- }
-
- llvm::SwitchInst *switchInst =
- builder.CreateSwitch(lookupValue(switchOp.value()),
- lookupBlock(switchOp.defaultDestination()),
- switchOp.caseDestinations().size(), branchWeights);
-
- auto *ty =
- llvm::cast<llvm::IntegerType>(convertType(switchOp.value().getType()));
- for (auto i :
- llvm::zip(switchOp.case_values()->cast<DenseIntElementsAttr>(),
- switchOp.caseDestinations()))
- switchInst->addCase(
- llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
- lookupBlock(std::get<1>(i)));
-
- mapBranch(&opInst, switchInst);
- return success();
- }
-
- // Emit addressof. We need to look up the global value referenced by the
- // operation and store it in the MLIR-to-LLVM value mapping. This does not
- // emit any LLVM instruction.
- if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
- LLVM::GlobalOp global = addressOfOp.getGlobal();
- LLVM::LLVMFuncOp function = addressOfOp.getFunction();
-
- // The verifier should not have allowed this.
- assert((global || function) &&
- "referencing an undefined global or function");
-
- mapValue(addressOfOp.getResult(), global
- ? globalsMapping.lookup(global)
- : lookupFunction(function.getName()));
- return success();
- }
if (ompDialect && opInst.getDialect() == ompDialect)
return convertOmpOperation(opInst, builder);
#include "mlir/ExecutionEngine/JitRunner.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/Target/LLVMIR.h"
+
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
mlir::DialectRegistry registry;
registry.insert<mlir::LLVM::LLVMDialect, mlir::omp::OpenMPDialect>();
+ mlir::registerLLVMDialectTranslation(registry);
+
return mlir::JitRunnerMain(argc, argv, registry);
}
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
#include "mlir/Target/NVVMIR.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
+
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
registry.insert<mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
mlir::async::AsyncDialect, mlir::gpu::GPUDialect,
mlir::StandardOpsDialect>();
+ mlir::registerLLVMDialectTranslation(registry);
return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
}
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
+#include "mlir/Target/LLVMIR.h"
#include "mlir/Target/ROCDLIR.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
mlir::DialectRegistry registry;
registry.insert<mlir::LLVM::LLVMDialect, mlir::gpu::GPUDialect,
mlir::ROCDL::ROCDLDialect, mlir::StandardOpsDialect>();
+ mlir::registerLLVMDialectTranslation(registry);
return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
}
mlir::DialectRegistry registry;
registry.insert<mlir::LLVM::LLVMDialect, mlir::gpu::GPUDialect,
mlir::spirv::SPIRVDialect, mlir::StandardOpsDialect>();
+ mlir::registerLLVMDialectTranslation(registry);
return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
}
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
mlir::DialectRegistry registry;
registry.insert<mlir::LLVM::LLVMDialect, mlir::gpu::GPUDialect,
mlir::spirv::SPIRVDialect, mlir::StandardOpsDialect>();
+ mlir::registerLLVMDialectTranslation(registry);
return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
}
#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
return %res : i32
}
)mlir";
- MLIRContext context;
- registerAllDialects(context);
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ registerLLVMDialectTranslation(registry);
+ MLIRContext context(registry);
OwningModuleRef module = parseSourceString(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
return %res : f32
}
)mlir";
- MLIRContext context;
- registerAllDialects(context);
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ registerLLVMDialectTranslation(registry);
+ MLIRContext context(registry);
OwningModuleRef module = parseSourceString(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
return
}
)mlir";
- MLIRContext context;
- registerAllDialects(context);
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ registerLLVMDialectTranslation(registry);
+ MLIRContext context(registry);
auto module = parseSourceString(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
return
}
)mlir";
- MLIRContext context;
- registerAllDialects(context);
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ registerLLVMDialectTranslation(registry);
+ MLIRContext context(registry);
auto module = parseSourceString(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
return
}
)mlir";
- MLIRContext context;
- registerAllDialects(context);
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ registerLLVMDialectTranslation(registry);
+ MLIRContext context(registry);
OwningModuleRef module = parseSourceString(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
return
}
)mlir";
- MLIRContext context;
- registerAllDialects(context);
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ registerLLVMDialectTranslation(registry);
+ MLIRContext context(registry);
auto module = parseSourceString(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));