StringRef value, LLVM::Linkage linkage,
LLVM::LLVMDialect *llvmDialect);
+/// LLVM requires some operations to be inside of a Module operation. This
+/// function confirms that the Operation has the desired properties.
+bool satisfiesLLVMModule(Operation *op);
+
} // end namespace LLVM
} // end namespace mlir
#ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Value.h"
class ModuleTranslation {
public:
template <typename T = ModuleTranslation>
- static std::unique_ptr<llvm::Module> translateModule(ModuleOp m) {
+ static std::unique_ptr<llvm::Module> translateModule(Operation *m) {
+ if (!satisfiesLLVMModule(m))
+ return nullptr;
if (failed(checkSupportedModuleOps(m)))
return nullptr;
auto llvmModule = prepareLLVMModule(m);
return std::move(translator.llvmModule);
}
+ /// A helper method to get the single Block in an operation honoring LLVM's
+ /// module requirements.
+ static Block &getModuleBody(Operation *m) { return m->getRegion(0).front(); }
+
protected:
// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
// LLVMContext, the LLVM IR module will be created in that context.
- explicit ModuleTranslation(ModuleOp module) : mlirModule(module) {}
+ explicit ModuleTranslation(Operation *module) : mlirModule(module) {
+ assert(satisfiesLLVMModule(mlirModule) &&
+ "mlirModule should honor LLVM's module semantics.");
+ }
virtual ~ModuleTranslation() {}
virtual LogicalResult convertOperation(Operation &op,
llvm::IRBuilder<> &builder);
- static std::unique_ptr<llvm::Module> prepareLLVMModule(ModuleOp m);
+ static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
template <typename Range>
SmallVector<llvm::Value *, 8> lookupValues(Range &&values);
private:
/// Check whether the module contains only supported ops directly in its body.
- static LogicalResult checkSupportedModuleOps(ModuleOp m);
+ static LogicalResult checkSupportedModuleOps(Operation *m);
LogicalResult convertFunctions();
void convertGlobals();
Location loc);
// Original and translated module.
- ModuleOp mlirModule;
+ Operation *mlirModule;
std::unique_ptr<llvm::Module> llvmModule;
// Mappings between llvm.mlir.global definitions and corresponding globals.
} // namespace llvm
namespace mlir {
-class ModuleOp;
+class Operation;
-/// Convert the given MLIR module into NVVM IR. This conversion requires the
-/// registration of the LLVM IR dialect and will extract the LLVM context
-/// from the registered LLVM IR dialect. In case of error, report it
-/// to the error handler registered with the MLIR context, if any (obtained from
+/// Convert the given LLVM-module-like operation into NVVM IR. This conversion
+/// requires the registration of the LLVM IR dialect and will extract the LLVM
+/// context from the registered LLVM IR dialect. In case of error, report it to
+/// the error handler registered with the MLIR context, if any (obtained from
/// the MLIR module), and return `nullptr`.
-std::unique_ptr<llvm::Module> translateModuleToNVVMIR(ModuleOp m);
+std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Operation *m);
} // namespace mlir
} // namespace llvm
namespace mlir {
-class ModuleOp;
+class Operation;
-/// Convert the given MLIR module into ROCDL IR. This conversion requires the
-/// registration of the LLVM IR dialect and will extract the LLVM context
-/// from the registered LLVM IR dialect. In case of error, report it
-/// to the error handler registered with the MLIR context, if any (obtained from
+/// Convert the given LLVM-module-like operation into ROCDL IR. This conversion
+/// requires the registration of the LLVM IR dialect and will extract the LLVM
+/// context from the registered LLVM IR dialect. In case of error, report it to
+/// the error handler registered with the MLIR context, if any (obtained from
/// the MLIR module), and return `nullptr`.
-std::unique_ptr<llvm::Module> translateModuleToROCDLIR(ModuleOp m);
+std::unique_ptr<llvm::Module> translateModuleToROCDLIR(Operation *m);
} // namespace mlir
//===----------------------------------------------------------------------===//
GlobalOp AddressOfOp::getGlobal() {
- auto module = getParentOfType<ModuleOp>();
+ Operation *module = getParentOp();
+ while (module && !satisfiesLLVMModule(module))
+ module = module->getParentOp();
assert(module && "unexpected operation outside of a module");
- return module.lookupSymbol<LLVM::GlobalOp>(global_name());
+ return dyn_cast_or_null<LLVM::GlobalOp>(
+ mlir::SymbolTable::lookupSymbolIn(module, global_name()));
}
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
return op.emitOpError(
"expects type to be a valid element type for an LLVM pointer");
- if (op.getParentOp() && !isa<ModuleOp>(op.getParentOp()))
+ if (op.getParentOp() &&
+ !(op.getParentOp()->hasTrait<OpTrait::SymbolTable>() &&
+ op.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()))
return op.emitOpError("must appear at the module level");
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
ArrayRef<Value *>({cst0, cst0}));
}
+
+bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
+ return op->hasTrait<OpTrait::SymbolTable>() &&
+ op->hasTrait<OpTrait::IsIsolatedFromAbove>();
+}
class ModuleTranslation : public LLVM::ModuleTranslation {
public:
- explicit ModuleTranslation(ModuleOp module)
+ explicit ModuleTranslation(Operation *module)
: LLVM::ModuleTranslation(module) {}
~ModuleTranslation() override {}
};
} // namespace
-std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
+std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Operation *m) {
ModuleTranslation translation(m);
auto llvmModule =
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
// function as a kernel.
- for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
+ for (auto func :
+ ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
if (!gpu::GPUDialect::isKernel(func))
continue;
class ModuleTranslation : public LLVM::ModuleTranslation {
public:
- explicit ModuleTranslation(ModuleOp module)
+ explicit ModuleTranslation(Operation *module)
: LLVM::ModuleTranslation(module) {}
~ModuleTranslation() override {}
};
} // namespace
-std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
+std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(Operation *m) {
ModuleTranslation translation(m);
// lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics)
// foreach GPU kernel
// 1. Insert AMDGPU_KERNEL calling convention.
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
- for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
+ for (auto func :
+ ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
continue;
// Create named global variables that correspond to llvm.mlir.global
// definitions.
void ModuleTranslation::convertGlobals() {
- for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
+ for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
llvm::Type *type = op.getType().getUnderlyingType();
llvm::Constant *cst = llvm::UndefValue::get(type);
if (op.getValueOrNull()) {
return success();
}
-LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) {
- for (Operation &o : m.getBody()->getOperations())
+LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
+ for (Operation &o : getModuleBody(m).getOperations())
if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) &&
- !isa<ModuleTerminatorOp>(&o))
+ !o.isKnownTerminator())
return o.emitOpError("unsupported module-level operation");
return success();
}
LogicalResult ModuleTranslation::convertFunctions() {
// Declare all functions first because there may be function calls that form a
// call graph with cycles.
- for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
+ for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
function.getName(),
llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
}
// Convert functions.
- for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
+ for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
// Ignore external functions.
if (function.isExternal())
continue;
return success();
}
-std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(ModuleOp m) {
- auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+std::unique_ptr<llvm::Module>
+ModuleTranslation::prepareLLVMModule(Operation *m) {
+ auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(dialect && "LLVM dialect must be registered");
auto llvmModule = llvm::CloneModule(dialect->getLLVMModule());