From 44fc7d72b3cb44147394e22f1f21ad36cca7bca8 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Mon, 16 Dec 2019 01:35:03 -0800 Subject: [PATCH] Remove LLVM dependency on mlir::Module and instead check Traits. PiperOrigin-RevId: 285724678 --- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 4 ++++ mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h | 20 +++++++++++++++----- mlir/include/mlir/Target/NVVMIR.h | 12 ++++++------ mlir/include/mlir/Target/ROCDLIR.h | 12 ++++++------ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 16 +++++++++++++--- mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 7 ++++--- mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp | 7 ++++--- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 17 +++++++++-------- 8 files changed, 61 insertions(+), 34 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 83c30e6..5332a74 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -198,6 +198,10 @@ Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, LLVM::Linkage linkage, LLVM::LLVMDialect *llvmDialect); +/// LLVM requires some operations to be inside of a Module operation. This +/// function confirms that the Operation has the desired properties. +bool satisfiesLLVMModule(Operation *op); + } // end namespace LLVM } // end namespace mlir diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index b957c82..2889012 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -23,6 +23,7 @@ #ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H #define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Block.h" #include "mlir/IR/Module.h" #include "mlir/IR/Value.h" @@ -50,7 +51,9 @@ class LLVMFuncOp; class ModuleTranslation { public: template - static std::unique_ptr translateModule(ModuleOp m) { + static std::unique_ptr translateModule(Operation *m) { + if (!satisfiesLLVMModule(m)) + return nullptr; if (failed(checkSupportedModuleOps(m))) return nullptr; auto llvmModule = prepareLLVMModule(m); @@ -66,23 +69,30 @@ public: return std::move(translator.llvmModule); } + /// A helper method to get the single Block in an operation honoring LLVM's + /// module requirements. + static Block &getModuleBody(Operation *m) { return m->getRegion(0).front(); } + protected: // Translate the given MLIR module expressed in MLIR LLVM IR dialect into an // LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an // LLVMContext, the LLVM IR module will be created in that context. - explicit ModuleTranslation(ModuleOp module) : mlirModule(module) {} + explicit ModuleTranslation(Operation *module) : mlirModule(module) { + assert(satisfiesLLVMModule(mlirModule) && + "mlirModule should honor LLVM's module semantics."); + } virtual ~ModuleTranslation() {} virtual LogicalResult convertOperation(Operation &op, llvm::IRBuilder<> &builder); - static std::unique_ptr prepareLLVMModule(ModuleOp m); + static std::unique_ptr prepareLLVMModule(Operation *m); template SmallVector lookupValues(Range &&values); private: /// Check whether the module contains only supported ops directly in its body. - static LogicalResult checkSupportedModuleOps(ModuleOp m); + static LogicalResult checkSupportedModuleOps(Operation *m); LogicalResult convertFunctions(); void convertGlobals(); @@ -94,7 +104,7 @@ private: Location loc); // Original and translated module. - ModuleOp mlirModule; + Operation *mlirModule; std::unique_ptr llvmModule; // Mappings between llvm.mlir.global definitions and corresponding globals. diff --git a/mlir/include/mlir/Target/NVVMIR.h b/mlir/include/mlir/Target/NVVMIR.h index 3a4442e..ec9858e 100644 --- a/mlir/include/mlir/Target/NVVMIR.h +++ b/mlir/include/mlir/Target/NVVMIR.h @@ -30,14 +30,14 @@ class Module; } // namespace llvm namespace mlir { -class ModuleOp; +class Operation; -/// Convert the given MLIR module into NVVM IR. This conversion requires the -/// registration of the LLVM IR dialect and will extract the LLVM context -/// from the registered LLVM IR dialect. In case of error, report it -/// to the error handler registered with the MLIR context, if any (obtained from +/// Convert the given LLVM-module-like operation into NVVM IR. This conversion +/// requires the registration of the LLVM IR dialect and will extract the LLVM +/// context from the registered LLVM IR dialect. In case of error, report it to +/// the error handler registered with the MLIR context, if any (obtained from /// the MLIR module), and return `nullptr`. -std::unique_ptr translateModuleToNVVMIR(ModuleOp m); +std::unique_ptr translateModuleToNVVMIR(Operation *m); } // namespace mlir diff --git a/mlir/include/mlir/Target/ROCDLIR.h b/mlir/include/mlir/Target/ROCDLIR.h index 6295a1b..fd00e94 100644 --- a/mlir/include/mlir/Target/ROCDLIR.h +++ b/mlir/include/mlir/Target/ROCDLIR.h @@ -31,14 +31,14 @@ class Module; } // namespace llvm namespace mlir { -class ModuleOp; +class Operation; -/// Convert the given MLIR module into ROCDL IR. This conversion requires the -/// registration of the LLVM IR dialect and will extract the LLVM context -/// from the registered LLVM IR dialect. In case of error, report it -/// to the error handler registered with the MLIR context, if any (obtained from +/// Convert the given LLVM-module-like operation into ROCDL IR. This conversion +/// requires the registration of the LLVM IR dialect and will extract the LLVM +/// context from the registered LLVM IR dialect. In case of error, report it to +/// the error handler registered with the MLIR context, if any (obtained from /// the MLIR module), and return `nullptr`. -std::unique_ptr translateModuleToROCDLIR(ModuleOp m); +std::unique_ptr translateModuleToROCDLIR(Operation *m); } // namespace mlir diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index d037d2e..9201da2 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -790,9 +790,12 @@ static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &result) { //===----------------------------------------------------------------------===// GlobalOp AddressOfOp::getGlobal() { - auto module = getParentOfType(); + Operation *module = getParentOp(); + while (module && !satisfiesLLVMModule(module)) + module = module->getParentOp(); assert(module && "unexpected operation outside of a module"); - return module.lookupSymbol(global_name()); + return dyn_cast_or_null( + mlir::SymbolTable::lookupSymbolIn(module, global_name())); } static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) { @@ -1030,7 +1033,9 @@ static LogicalResult verify(GlobalOp op) { if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType())) return op.emitOpError( "expects type to be a valid element type for an LLVM pointer"); - if (op.getParentOp() && !isa(op.getParentOp())) + if (op.getParentOp() && + !(op.getParentOp()->hasTrait() && + op.getParentOp()->hasTrait())) return op.emitOpError("must appear at the module level"); if (auto strAttr = op.getValueOrNull().dyn_cast_or_null()) { @@ -1675,3 +1680,8 @@ Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, ArrayRef({cst0, cst0})); } + +bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { + return op->hasTrait() && + op->hasTrait(); +} diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 606e91b..728dc86 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -58,7 +58,7 @@ static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType, class ModuleTranslation : public LLVM::ModuleTranslation { public: - explicit ModuleTranslation(ModuleOp module) + explicit ModuleTranslation(Operation *module) : LLVM::ModuleTranslation(module) {} ~ModuleTranslation() override {} @@ -73,7 +73,7 @@ protected: }; } // namespace -std::unique_ptr mlir::translateModuleToNVVMIR(ModuleOp m) { +std::unique_ptr mlir::translateModuleToNVVMIR(Operation *m) { ModuleTranslation translation(m); auto llvmModule = LLVM::ModuleTranslation::translateModule(m); @@ -82,7 +82,8 @@ std::unique_ptr mlir::translateModuleToNVVMIR(ModuleOp m) { // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the // function as a kernel. - for (auto func : m.getOps()) { + for (auto func : + ModuleTranslation::getModuleBody(m).getOps()) { if (!gpu::GPUDialect::isKernel(func)) continue; diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp index dcd4d6c..7b7c368 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -69,7 +69,7 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilder<> &builder, class ModuleTranslation : public LLVM::ModuleTranslation { public: - explicit ModuleTranslation(ModuleOp module) + explicit ModuleTranslation(Operation *module) : LLVM::ModuleTranslation(module) {} ~ModuleTranslation() override {} @@ -84,7 +84,7 @@ protected: }; } // namespace -std::unique_ptr mlir::translateModuleToROCDLIR(ModuleOp m) { +std::unique_ptr mlir::translateModuleToROCDLIR(Operation *m) { ModuleTranslation translation(m); // lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics) @@ -94,7 +94,8 @@ std::unique_ptr mlir::translateModuleToROCDLIR(ModuleOp m) { // foreach GPU kernel // 1. Insert AMDGPU_KERNEL calling convention. // 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute. - for (auto func : m.getOps()) { + for (auto func : + ModuleTranslation::getModuleBody(m).getOps()) { if (!func.getAttrOfType(gpu::GPUDialect::getKernelFuncAttrName())) continue; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index f985fed..f5f9cca 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -311,7 +311,7 @@ llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) { // Create named global variables that correspond to llvm.mlir.global // definitions. void ModuleTranslation::convertGlobals() { - for (auto op : mlirModule.getOps()) { + for (auto op : getModuleBody(mlirModule).getOps()) { llvm::Type *type = op.getType().getUnderlyingType(); llvm::Constant *cst = llvm::UndefValue::get(type); if (op.getValueOrNull()) { @@ -470,10 +470,10 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { return success(); } -LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) { - for (Operation &o : m.getBody()->getOperations()) +LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) { + for (Operation &o : getModuleBody(m).getOperations()) if (!isa(&o) && !isa(&o) && - !isa(&o)) + !o.isKnownTerminator()) return o.emitOpError("unsupported module-level operation"); return success(); } @@ -481,7 +481,7 @@ LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) { LogicalResult ModuleTranslation::convertFunctions() { // Declare all functions first because there may be function calls that form a // call graph with cycles. - for (auto function : mlirModule.getOps()) { + for (auto function : getModuleBody(mlirModule).getOps()) { llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( function.getName(), llvm::cast(function.getType().getUnderlyingType())); @@ -491,7 +491,7 @@ LogicalResult ModuleTranslation::convertFunctions() { } // Convert functions. - for (auto function : mlirModule.getOps()) { + for (auto function : getModuleBody(mlirModule).getOps()) { // Ignore external functions. if (function.isExternal()) continue; @@ -503,8 +503,9 @@ LogicalResult ModuleTranslation::convertFunctions() { return success(); } -std::unique_ptr ModuleTranslation::prepareLLVMModule(ModuleOp m) { - auto *dialect = m.getContext()->getRegisteredDialect(); +std::unique_ptr +ModuleTranslation::prepareLLVMModule(Operation *m) { + auto *dialect = m->getContext()->getRegisteredDialect(); assert(dialect && "LLVM dialect must be registered"); auto llvmModule = llvm::CloneModule(dialect->getLLVMModule()); -- 2.7.4